diff --git a/rules/engine.go b/rules/engine.go index 2f0e2f4..3f3dd69 100644 --- a/rules/engine.go +++ b/rules/engine.go @@ -3,6 +3,7 @@ package rules import ( "fmt" "os" + "regexp" "strings" "time" @@ -69,7 +70,7 @@ type V3Engine interface { AddRule(rule DynamicRule, lockPattern string, callback V3RuleTaskCallback, - options ...RuleOption) + options ...RuleOption) error AddPolling(namespacePattern string, preconditions DynamicRule, ttl int, @@ -169,8 +170,13 @@ func (e *v3Engine) SetWatcherWrapper(watcherWrapper WrapWatcher) { func (e *v3Engine) AddRule(rule DynamicRule, lockPattern string, callback V3RuleTaskCallback, - options ...RuleOption) { + options ...RuleOption) error { + validPath := regexp.MustCompile(`^[[:alnum:] \:\/\"\'\_\.\,\*\=\-]*$`) + if !validPath.MatchString(lockPattern) { + return fmt.Errorf("Path contains an invalid character") + } e.addRuleWithIface(rule, lockPattern, callback, options...) + return nil } func (e *baseEngine) Stop() { diff --git a/rules/engine_test.go b/rules/engine_test.go index 8335bc2..8863856 100644 --- a/rules/engine_test.go +++ b/rules/engine_test.go @@ -1,6 +1,7 @@ package rules import ( + "fmt" "testing" "time" @@ -30,14 +31,18 @@ func TestV3EngineConstructor(t *testing.T) { eng := NewV3Engine(cfg, getTestLogger()) value := "val" rule, _ := NewEqualsLiteralRule("/key", &value) - eng.AddRule(rule, "/lock", v3DummyCallback, RuleID("test")) - assert.PanicsWithValue(t, "Rule ID option missing", func() { eng.AddRule(rule, "/lock", v3DummyCallback) }) - err := eng.AddPolling("/polling", rule, 30, v3DummyCallback) + err := eng.AddRule(rule, "/lock?@", v3DummyCallback, RuleID("test")) + assert.Equal(t, err, fmt.Errorf("Path contains an invalid character")) + err = eng.AddRule(rule, "/lock", v3DummyCallback, RuleID("test")) + assert.NoError(t, err) + assert.PanicsWithValue(t, "Rule ID option missing", func() { assert.NoError(t, eng.AddRule(rule, "/lock", v3DummyCallback)) }) + err = eng.AddPolling("/polling", rule, 30, v3DummyCallback) assert.NoError(t, err) assertEngineRunStop(t, eng) eng = NewV3Engine(cfg, getTestLogger(), KeyExpansion(map[string][]string{"a:": {"b"}})) - eng.AddRule(rule, "/lock", v3DummyCallback, RuleLockTimeout(30), RuleID("test")) + err = eng.AddRule(rule, "/lock", v3DummyCallback, RuleLockTimeout(30), RuleID("test")) + assert.NoError(t, err) err = eng.AddPolling("/polling", rule, 30, v3DummyCallback) assert.NoError(t, err) err = eng.AddPolling("/polling[", rule, 30, v3DummyCallback) diff --git a/rules/lock/lock.go b/rules/lock/lock.go index 36bb547..c6b4ed4 100644 --- a/rules/lock/lock.go +++ b/rules/lock/lock.go @@ -2,8 +2,6 @@ package lock import ( "errors" - "fmt" - "regexp" "time" "golang.org/x/net/context" @@ -56,10 +54,6 @@ type v3Locker struct { } func (v3l *v3Locker) Lock(key string, options ...Option) (RuleLock, error) { - validPath := regexp.MustCompile(`^[[:alnum:] \/\"\'\_\.\,\*\=\-]+$`) - if !validPath.MatchString(key) { - return nil, fmt.Errorf("Path variable contains an invalid character") - } return v3l.lockWithTimeout(key, v3l.lockTimeout) } func (v3l *v3Locker) lockWithTimeout(key string, timeout int) (RuleLock, error) { diff --git a/rules/lock/lock_test.go b/rules/lock/lock_test.go index 3a41922..1e5962e 100644 --- a/rules/lock/lock_test.go +++ b/rules/lock/lock_test.go @@ -1,7 +1,6 @@ package lock import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -59,39 +58,3 @@ func Test_V3Locker(t *testing.T) { }) } } - -func Test_V3LockerRegex(t *testing.T) { - cfg, cl := teststore.InitV3Etcd(t) - _, err := v3.New(cfg) - require.NoError(t, err) - newSession := func(_ context.Context) (*v3c.Session, error) { - return v3c.NewSession(cl, v3c.WithTTL(30)) - } - - testcases := []struct { - name string - lockKey string - err error - }{ - { - name: "bad regex", - lockKey: "/test?/", - err: fmt.Errorf("Path variable contains an invalid character"), - }, - { - name: "good regex", - lockKey: "/test/", - }, - } - - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - rlckr := v3Locker{ - newSession: newSession, - lockTimeout: 5, - } - _, err := rlckr.Lock(tc.lockKey) - assert.Equal(t, err, tc.err) - }) - } -}