monwork - readd stringToShellFields

This function was needed as bytes.Fields had some problems with the
quoting.
Now there are test cases too so that errors can be found more easily.
This commit is contained in:
Gibheer 2018-12-14 14:02:58 +01:00
parent ff79584084
commit 10fb89a017
2 changed files with 75 additions and 1 deletions

View File

@ -167,7 +167,7 @@ func startConfigGen(db *sql.DB, checkInterval time.Duration) {
time.Sleep(checkInterval) time.Sleep(checkInterval)
continue continue
} }
if _, err := tx.Exec(SQLRefreshActiveCheck, check_id, pq.Array(bytes.Fields(cmd.Bytes()))); err != nil { if _, err := tx.Exec(SQLRefreshActiveCheck, check_id, pq.Array(stringToShellFields(cmd.Bytes()))); err != nil {
tx.Rollback() tx.Rollback()
log.Printf("could not refresh check '%d': %s", check_id, err) log.Printf("could not refresh check '%d': %s", check_id, err)
continue continue
@ -184,6 +184,41 @@ func startConfigGen(db *sql.DB, checkInterval time.Duration) {
} }
} }
func stringToShellFields(in []byte) [][]byte {
if len(in) == 0 {
return [][]byte{}
}
fields := bytes.Fields(in)
result := [][]byte{}
var quote byte
for _, field := range fields {
if quote == 0 && (field[0] != '\'' && field[0] != '"') {
result = append(result, field)
continue
}
if quote == 0 && (field[0] == '\'' || field[0] == '"') {
quote = field[0]
if field[len(field)-1] == quote {
result = append(result, field[1:len(field)-1])
quote = 0
continue
}
result = append(result, field[1:])
continue
}
idx := len(result) - 1
if bytes.HasSuffix(field, []byte{quote}) {
result[idx] = append(result[idx], append([]byte(" "), field[:len(field)-1]...)...)
quote = 0
continue
}
result[idx] = append(result[idx], append([]byte(" "), field...)...)
}
return result
}
var ( var (
SQLGetConfigUpdates = `select c.id, co.command, c.options SQLGetConfigUpdates = `select c.id, co.command, c.options
from checks c from checks c

View File

@ -0,0 +1,39 @@
package main
import (
"bytes"
"fmt"
"testing"
)
func TestStringToShellFields(t *testing.T) {
type S struct {
source string
target []string
}
for i, e := range []S{
S{"foo", []string{"foo"}},
S{"foo bar", []string{"foo", "bar"}},
S{`foo "bar"`, []string{"foo", `bar`}},
S{`foo "bar baz"`, []string{"foo", `bar baz`}},
S{`foo "bar" "baz"`, []string{"foo", `bar`, `baz`}},
S{`foo "bar" "baz"`, []string{"foo", `bar`, `baz`}},
} {
result := stringToShellFields([]byte(e.source))
if err := compare(e.target, result); err != nil {
t.Errorf("test %d did not match: %s", i, err)
}
}
}
func compare(source []string, target [][]byte) error {
if len(source) != len(target) {
return fmt.Errorf("length mismatch %d vs %d", len(source), len(target))
}
for i, e := range source {
if bytes.Compare([]byte(e), target[i]) != 0 {
return fmt.Errorf("mismatch in content field %d: %s vs %s", i, e, target[i])
}
}
return nil
}