Skip to content

Commit

Permalink
fix: comment (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmyj authored Sep 27, 2020
1 parent b6521e1 commit 4eb0aa4
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
12 changes: 8 additions & 4 deletions apiaccessor/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ import (
"errors"
)

// Accessor is the interface use to check the availability of the request's arguments
type Accessor interface {
CheckSignature() error
CheckTimestamp() error
CheckNonce() error
}

var (
ErrArgLack = errors.New("arg lack")
ErrSignatureUnmatched = errors.New("signature is unmatched")
ErrTimestampTimeout = errors.New("timestamp time out")
ErrNonceUsed = errors.New("nonce is used")
errArgLack = errors.New("arg lack")
errSignatureUnmatched = errors.New("signature is unmatched")
errTimestampTimeout = errors.New("timestamp time out")
errNonceUsed = errors.New("nonce is used")
)

const (
Expand Down Expand Up @@ -46,8 +47,11 @@ func (a *args) append(k, v string) {
a.l = append(a.l, &arg{k: k, v: v})
}

// EvalSignature evaluating the signature of the request's arguments
type EvalSignature func(origin string) (signature string)

// TimestampChecker checking the availability of the request's timestamp argument
type TimestampChecker func(timestamp int64) error

// NonceChecker checking the availability of the request's nonce argument
type NonceChecker func(nonce string) error
7 changes: 5 additions & 2 deletions apiaccessor/base_accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func defTimestampChecker(timestamp int64) error {
const sec = 5
dt := time.Now().Unix() - timestamp
if dt > sec || dt < -sec {
return ErrTimestampTimeout
return errTimestampTimeout
}
return nil
}
Expand All @@ -49,6 +49,7 @@ func newBaseAccessor() baseAccessor {
}
}

// CheckSignature implements the Accessor CheckSignature interface
func (a *baseAccessor) CheckSignature() error {
// 参数排序
sort.Slice(a.args.l, func(i, j int) bool {
Expand All @@ -75,11 +76,12 @@ func (a *baseAccessor) CheckSignature() error {
signature := a.evalSignatureFunc(argText)
argSignature := a.args.kv[signatureTag]
if signature != argSignature {
return fmt.Errorf("%w: want %s, get %s", ErrSignatureUnmatched, signature, argSignature)
return fmt.Errorf("%w: want %s, get %s", errSignatureUnmatched, signature, argSignature)
}
return nil
}

// CheckTimestamp implements the Accessor CheckTimestamp interface
func (a *baseAccessor) CheckTimestamp() error {
timestampStr := a.args.kv[timestampTag]
timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
Expand All @@ -89,6 +91,7 @@ func (a *baseAccessor) CheckTimestamp() error {
return a.timestampChecker(timestamp)
}

// CheckNonce implements the Accessor CheckNonce interface
func (a *baseAccessor) CheckNonce() error {
return a.nonceChecker(a.args.kv[nonceTag])
}
10 changes: 6 additions & 4 deletions apiaccessor/query_accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,32 @@ import (
"net/url"
)

// QueryAccessor use to check the availability of the request's query arguments
type QueryAccessor struct {
baseAccessor
}

// NewQueryAccessor creates a new QueryAccessor.
func NewQueryAccessor(query url.Values, secretKey string, setters ...Setter) (*QueryAccessor, error) {
qa := &QueryAccessor{
baseAccessor: newBaseAccessor(),
}
for key, vs := range query {
v := vs[0]
if len(v) == 0 {
return nil, fmt.Errorf("%w: %s", ErrArgLack, key)
return nil, fmt.Errorf("%w: %s", errArgLack, key)
}
qa.args.append(key, v)
}
qa.args.append(secretKeyTag, secretKey)
if len(qa.args.kv[nonceTag]) == 0 {
return nil, fmt.Errorf("%w: %s", ErrArgLack, nonceTag)
return nil, fmt.Errorf("%w: %s", errArgLack, nonceTag)
}
if len(qa.args.kv[timestampTag]) == 0 {
return nil, fmt.Errorf("%w: %s", ErrArgLack, timestampTag)
return nil, fmt.Errorf("%w: %s", errArgLack, timestampTag)
}
if len(qa.args.kv[signatureTag]) == 0 {
return nil, fmt.Errorf("%w: %s", ErrArgLack, signatureTag)
return nil, fmt.Errorf("%w: %s", errArgLack, signatureTag)
}

for _, setter := range setters {
Expand Down
12 changes: 6 additions & 6 deletions apiaccessor/query_accessor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ import (
func TestNewQueryAccessor(t *testing.T) {
query := url.Values{}
_, err := NewQueryAccessor(query, "123")
assert.Equal(t, errors.Is(err, ErrArgLack), true)
assert.Equal(t, errors.Is(err, errArgLack), true)

query = url.Values{
nonceTag: []string{"12345"},
}
_, err = NewQueryAccessor(query, "123")
assert.Equal(t, errors.Is(err, ErrArgLack), true)
assert.Equal(t, errors.Is(err, errArgLack), true)

query = url.Values{
nonceTag: []string{"12345"},
Expand Down Expand Up @@ -52,7 +52,7 @@ func TestCheckSignature(t *testing.T) {
qa, err := NewQueryAccessor(query, "123")
assert.Equal(t, err, nil)
err = qa.CheckSignature()
assert.Equal(t, errors.Is(err, ErrSignatureUnmatched), true)
assert.Equal(t, errors.Is(err, errSignatureUnmatched), true)

query = url.Values{
nonceTag: []string{"12345"},
Expand All @@ -78,7 +78,7 @@ func TestCheckTimestamp(t *testing.T) {
qa, err := NewQueryAccessor(query, "123")
assert.Equal(t, err, nil)
err = qa.CheckTimestamp()
assert.Equal(t, errors.Is(err, ErrTimestampTimeout), true)
assert.Equal(t, errors.Is(err, errTimestampTimeout), true)

query = url.Values{
nonceTag: []string{"12345"},
Expand All @@ -97,7 +97,7 @@ func TestCheckNonce(t *testing.T) {
nonceMap := make(map[string]bool)
mockNonceChecker := func(nonce string) error {
if _, ok := nonceMap[nonce]; ok {
return ErrNonceUsed
return errNonceUsed
}
nonceMap[nonce] = true
return nil
Expand All @@ -115,5 +115,5 @@ func TestCheckNonce(t *testing.T) {
err = qa.CheckNonce()
assert.Equal(t, err, nil)
err = qa.CheckNonce()
assert.Equal(t, errors.Is(err, ErrNonceUsed), true)
assert.Equal(t, errors.Is(err, errNonceUsed), true)
}
7 changes: 6 additions & 1 deletion apiaccessor/setter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package apiaccessor

import "github.com/go-redis/redis/v7"

// Setter is the option of creating the Accessor
type Setter func(b *baseAccessor) error

// WithEvalSignatureFunc set a custom EvalSignature for the Accessor
func WithEvalSignatureFunc(e EvalSignature) Setter {
return func(b *baseAccessor) error {
b.evalSignatureFunc = e
Expand Down Expand Up @@ -31,8 +33,10 @@ end
return no
`)

// KeyGen use to generate a redis key which is using in the WithGeneralRedisNonceChecker
type KeyGen func(nonce string) (key string)

// WithGeneralRedisNonceChecker set a redis-base NonceChecker for the Accessor
func WithGeneralRedisNonceChecker(client redis.Cmdable, sec int64, keyGenFunc KeyGen) Setter {
return func(b *baseAccessor) error {
b.nonceChecker = func(nonce string) error {
Expand All @@ -42,14 +46,15 @@ func WithGeneralRedisNonceChecker(client redis.Cmdable, sec int64, keyGenFunc Ke
return err
}
if re == 1 {
return ErrNonceUsed
return errNonceUsed
}
return nil
}
return nil
}
}

// WithNonceChecker set a custom NonceChecker for the Accessor
func WithNonceChecker(nc NonceChecker) Setter {
return func(b *baseAccessor) error {
b.nonceChecker = nc
Expand Down

0 comments on commit 4eb0aa4

Please sign in to comment.