Skip to content

Commit c183ab4

Browse files
committed
adjustment code
1 parent 07e9f65 commit c183ab4

File tree

11 files changed

+86
-40
lines changed

11 files changed

+86
-40
lines changed

internal/dao/userExample.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ type UserExampleDao interface {
3131
GetByColumns(ctx context.Context, params *query.Params) ([]*model.UserExample, int64, error)
3232

3333
CreateByTx(ctx context.Context, tx *gorm.DB, table *model.UserExample) (uint64, error)
34-
UpdateByTx(ctx context.Context, tx *gorm.DB, table *model.UserExample) error
3534
DeleteByTx(ctx context.Context, tx *gorm.DB, id uint64) error
35+
UpdateByTx(ctx context.Context, tx *gorm.DB, table *model.UserExample) error
3636
}
3737

3838
type userExampleDao struct {

internal/routers/userExample.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ func init() {
1414

1515
func userExampleRouter(group *gin.RouterGroup, h handler.UserExampleHandler) {
1616
//group.Use(middleware.Auth()) // all of the following routes use jwt authentication
17+
// or group.Use(middleware.Auth(middleware.WithVerify(verify))) // token authentication
18+
1719
group.POST("/userExample", h.Create)
1820
group.DELETE("/userExample/:id", h.DeleteByID)
1921
group.POST("/userExample/delete/ids", h.DeleteByIDs)

internal/service/userExample.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ func (s *userExample) Create(ctx context.Context, req *serverNameExampleV1.Creat
5050
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
5151
return nil, ecode.StatusInvalidParams.Err()
5252
}
53+
ctx = interceptor.WrapServerCtx(ctx)
5354

5455
record := &model.UserExample{}
5556
err = copier.Copy(record, req)
5657
if err != nil {
5758
return nil, ecode.StatusCreateUserExample.Err()
5859
}
5960

60-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
6161
err = s.iDao.Create(ctx, record)
6262
if err != nil {
6363
logger.Error("Create error", logger.Err(err), logger.Any("userExample", record), interceptor.ServerCtxRequestIDField(ctx))
@@ -74,8 +74,8 @@ func (s *userExample) DeleteByID(ctx context.Context, req *serverNameExampleV1.D
7474
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
7575
return nil, ecode.StatusInvalidParams.Err()
7676
}
77+
ctx = interceptor.WrapServerCtx(ctx)
7778

78-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
7979
err = s.iDao.DeleteByID(ctx, req.Id)
8080
if err != nil {
8181
logger.Error("DeleteByID error", logger.Err(err), logger.Any("id", req.Id), interceptor.ServerCtxRequestIDField(ctx))
@@ -92,8 +92,8 @@ func (s *userExample) DeleteByIDs(ctx context.Context, req *serverNameExampleV1.
9292
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
9393
return nil, ecode.StatusInvalidParams.Err()
9494
}
95+
ctx = interceptor.WrapServerCtx(ctx)
9596

96-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
9797
err = s.iDao.DeleteByIDs(ctx, req.Ids)
9898
if err != nil {
9999
logger.Error("DeleteByID error", logger.Err(err), logger.Any("ids", req.Ids), interceptor.ServerCtxRequestIDField(ctx))
@@ -110,6 +110,7 @@ func (s *userExample) UpdateByID(ctx context.Context, req *serverNameExampleV1.U
110110
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
111111
return nil, ecode.StatusInvalidParams.Err()
112112
}
113+
ctx = interceptor.WrapServerCtx(ctx)
113114

114115
record := &model.UserExample{}
115116
err = copier.Copy(record, req)
@@ -118,7 +119,6 @@ func (s *userExample) UpdateByID(ctx context.Context, req *serverNameExampleV1.U
118119
}
119120
record.ID = req.Id
120121

121-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
122122
err = s.iDao.UpdateByID(ctx, record)
123123
if err != nil {
124124
logger.Error("UpdateByID error", logger.Err(err), logger.Any("userExample", record), interceptor.ServerCtxRequestIDField(ctx))
@@ -135,8 +135,8 @@ func (s *userExample) GetByID(ctx context.Context, req *serverNameExampleV1.GetU
135135
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
136136
return nil, ecode.StatusInvalidParams.Err()
137137
}
138+
ctx = interceptor.WrapServerCtx(ctx)
138139

139-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
140140
record, err := s.iDao.GetByID(ctx, req.Id)
141141
if err != nil {
142142
if errors.Is(err, query.ErrNotFound) {
@@ -163,8 +163,8 @@ func (s *userExample) GetByCondition(ctx context.Context, req *serverNameExample
163163
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
164164
return nil, ecode.StatusInvalidParams.Err()
165165
}
166+
ctx = interceptor.WrapServerCtx(ctx)
166167

167-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
168168
conditions := &query.Conditions{}
169169
for _, v := range req.Conditions.GetColumns() {
170170
column := query.Column{}
@@ -205,8 +205,8 @@ func (s *userExample) ListByIDs(ctx context.Context, req *serverNameExampleV1.Li
205205
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
206206
return nil, ecode.StatusInvalidParams.Err()
207207
}
208+
ctx = interceptor.WrapServerCtx(ctx)
208209

209-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
210210
userExampleMap, err := s.iDao.GetByIDs(ctx, req.Ids)
211211
if err != nil {
212212
logger.Error("GetByIDs error", logger.Err(err), logger.Any("ids", req.Ids), interceptor.ServerCtxRequestIDField(ctx))
@@ -235,6 +235,7 @@ func (s *userExample) List(ctx context.Context, req *serverNameExampleV1.ListUse
235235
logger.Warn("req.Validate error", logger.Err(err), logger.Any("req", req), interceptor.ServerCtxRequestIDField(ctx))
236236
return nil, ecode.StatusInvalidParams.Err()
237237
}
238+
ctx = interceptor.WrapServerCtx(ctx)
238239

239240
params := &query.Params{}
240241
err = copier.Copy(params, req.Params)
@@ -243,7 +244,6 @@ func (s *userExample) List(ctx context.Context, req *serverNameExampleV1.ListUse
243244
}
244245
params.Size = int(req.Params.Limit)
245246

246-
ctx = context.WithValue(ctx, interceptor.ContextRequestIDKey, interceptor.ServerCtxRequestID(ctx)) //nolint
247247
records, total, err := s.iDao.GetByColumns(ctx, params)
248248
if err != nil {
249249
if strings.Contains(err.Error(), "query params error:") {

pkg/gin/middleware/README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,14 @@ func main() {
8585
r.Run(serverAddr)
8686
}
8787

88-
func adminVerify(claims *jwt.Claims) error {
88+
func adminVerify(claims *jwt.Claims, token string) error {
8989
if claims.Role != "admin" {
9090
return errors.New("verify failed")
9191
}
92+
93+
// token := getToken(claims.UID)
94+
// if tokenTail10 != token[len(token)-10:] { return err }
95+
9296
return nil
9397
}
9498

@@ -116,9 +120,12 @@ func main() {
116120
r.Run(serverAddr)
117121
}
118122

119-
func verify(claims *jwt.CustomClaims) error {
123+
func verify(claims *jwt.CustomClaims, tokenTail10 string) error {
120124
err := errors.New("verify failed")
121125

126+
// token := getToken(id)
127+
// if tokenTail10 != token[len(token)-10:] { return err }
128+
122129
id, exist := claims.Get("id")
123130
if !exist {
124131
return err

pkg/gin/middleware/auth.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ func responseUnauthorized(c *gin.Context, isSwitchHTTPCode bool) {
6060

6161
// -------------------------------------------------------------------------------------------
6262

63-
// VerifyFn verify function
64-
type VerifyFn func(claims *jwt.Claims) error
63+
// VerifyFn verify function, tokenTail10 is a string that intercepts the last 10 characters of the token.
64+
type VerifyFn func(claims *jwt.Claims, tokenTail10 string) error
6565

6666
// Auth authorization
6767
func Auth(opts ...JwtOption) gin.HandlerFunc {
@@ -77,7 +77,8 @@ func Auth(opts ...JwtOption) gin.HandlerFunc {
7777
return
7878
}
7979

80-
claims, err := jwt.ParseToken(authorization[7:]) // token=authorization[7:], remove Bearer prefix
80+
token := authorization[7:] // remove Bearer prefix
81+
claims, err := jwt.ParseToken(token)
8182
if err != nil {
8283
logger.Warn("ParseToken error", logger.Err(err))
8384
responseUnauthorized(c, o.isSwitchHTTPCode)
@@ -86,7 +87,8 @@ func Auth(opts ...JwtOption) gin.HandlerFunc {
8687
}
8788

8889
if o.verify != nil {
89-
if err = o.verify(claims); err != nil {
90+
tokenTail10 := token[len(token)-10:]
91+
if err = o.verify(claims, tokenTail10); err != nil {
9092
logger.Warn("verify error", logger.Err(err), logger.String("uid", claims.UID), logger.String("role", claims.Role))
9193
responseUnauthorized(c, o.isSwitchHTTPCode)
9294
c.Abort()
@@ -103,8 +105,8 @@ func Auth(opts ...JwtOption) gin.HandlerFunc {
103105

104106
// -------------------------------------------------------------------------------------------
105107

106-
// VerifyCustomFn verify custom function
107-
type VerifyCustomFn func(claims *jwt.CustomClaims) error
108+
// VerifyCustomFn verify custom function, tokenTail10 is a string that intercepts the last 10 characters of the token.
109+
type VerifyCustomFn func(claims *jwt.CustomClaims, tokenTail10 string) error
108110

109111
// AuthCustom custom authentication
110112
func AuthCustom(verify VerifyCustomFn, opts ...JwtOption) gin.HandlerFunc {
@@ -120,15 +122,17 @@ func AuthCustom(verify VerifyCustomFn, opts ...JwtOption) gin.HandlerFunc {
120122
return
121123
}
122124

123-
claims, err := jwt.ParseCustomToken(authorization[7:]) // token=authorization[7:], remove Bearer prefix
125+
token := authorization[7:] // remove Bearer prefix
126+
claims, err := jwt.ParseCustomToken(token)
124127
if err != nil {
125128
logger.Warn("ParseToken error", logger.Err(err))
126129
responseUnauthorized(c, o.isSwitchHTTPCode)
127130
c.Abort()
128131
return
129132
}
130133

131-
if err = verify(claims); err != nil {
134+
tokenTail10 := token[len(token)-10:]
135+
if err = verify(claims, tokenTail10); err != nil {
132136
logger.Warn("verify error", logger.Err(err), logger.Any("fields", claims.Fields))
133137
responseUnauthorized(c, o.isSwitchHTTPCode)
134138
c.Abort()

pkg/gin/middleware/auth_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@ var (
2222
fields = jwt.KV{"id": 1, "foo": "bar"}
2323
)
2424

25-
func verify(claims *jwt.Claims) error {
25+
func verify(claims *jwt.Claims, tokenTail10 string) error {
2626
if claims.UID != uid || claims.Role != role {
2727
return errors.New("verify failed")
2828
}
29+
30+
// token := getToken(claims.UID)
31+
// if token[len(token)-10:] != tokenTail10 { return err }
32+
2933
return nil
3034
}
3135

32-
func verifyCustom(claims *jwt.CustomClaims) error {
36+
func verifyCustom(claims *jwt.CustomClaims, tokenTail10 string) error {
3337
err := errors.New("verify failed")
3438

3539
id, exist := claims.Get("id")
@@ -44,6 +48,9 @@ func verifyCustom(claims *jwt.CustomClaims) error {
4448
return err
4549
}
4650

51+
// token := getToken(id)
52+
// if token[len(token)-10:] != tokenTail10 { return err }
53+
4754
return nil
4855
}
4956

pkg/grpc/interceptor/requstid.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,21 @@ type CtxKeyString string
3434
// RequestIDKey request_id
3535
var RequestIDKey = CtxKeyString(ContextRequestIDKey)
3636

37+
// ---------------------------------- client interceptor ----------------------------------
38+
3739
// CtxRequestIDField get request id field from context.Context
3840
func CtxRequestIDField(ctx context.Context) zap.Field {
3941
return zap.String(ContextRequestIDKey, metautils.ExtractOutgoing(ctx).Get(ContextRequestIDKey))
4042
}
4143

42-
// ---------------------------------- client interceptor ----------------------------------
43-
4444
// ClientCtxRequestID get request id from rpc client context.Context
4545
func ClientCtxRequestID(ctx context.Context) string {
4646
return metautils.ExtractOutgoing(ctx).Get(ContextRequestIDKey)
4747
}
4848

4949
// ClientCtxRequestIDField get request id field from rpc client context.Context
5050
func ClientCtxRequestIDField(ctx context.Context) zap.Field {
51-
return zap.String(ContextRequestIDKey, ClientCtxRequestID(ctx))
51+
return zap.String(ContextRequestIDKey, metautils.ExtractOutgoing(ctx).Get(ContextRequestIDKey))
5252
}
5353

5454
// UnaryClientRequestID client-side request_id unary interceptor
@@ -79,14 +79,29 @@ func StreamClientRequestID() grpc.StreamClientInterceptor {
7979

8080
// ---------------------------------- server interceptor ----------------------------------
8181

82+
// KV key value
83+
type KV struct {
84+
Key string
85+
Val interface{}
86+
}
87+
88+
// WrapServerCtx wrap context, used in grpc server-side
89+
func WrapServerCtx(ctx context.Context, kvs ...KV) context.Context {
90+
ctx = context.WithValue(ctx, ContextRequestIDKey, metautils.ExtractIncoming(ctx).Get(ContextRequestIDKey)) //nolint
91+
for _, kv := range kvs {
92+
ctx = context.WithValue(ctx, kv.Key, kv.Val) //nolint
93+
}
94+
return ctx
95+
}
96+
8297
// ServerCtxRequestID get request id from rpc server context.Context
8398
func ServerCtxRequestID(ctx context.Context) string {
8499
return metautils.ExtractIncoming(ctx).Get(ContextRequestIDKey)
85100
}
86101

87102
// ServerCtxRequestIDField get request id field from rpc server context.Context
88103
func ServerCtxRequestIDField(ctx context.Context) zap.Field {
89-
return zap.String(ContextRequestIDKey, ServerCtxRequestID(ctx))
104+
return zap.String(ContextRequestIDKey, metautils.ExtractIncoming(ctx).Get(ContextRequestIDKey))
90105
}
91106

92107
// UnaryServerRequestID server-side request_id unary interceptor

pkg/grpc/interceptor/requstid_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,12 +554,19 @@ func TestStreamServerRequestID(t *testing.T) {
554554
time.Sleep(time.Millisecond)
555555
}
556556

557-
func TestCtxRequestIDField(t *testing.T) {
558-
field := ClientCtxRequestIDField(context.Background())
557+
func TestCtxRequestID(t *testing.T) {
558+
_ = ClientCtxRequestID(context.Background())
559+
field := CtxRequestIDField(context.Background())
559560
assert.NotNil(t, field)
560-
field = ServerCtxRequestIDField(context.Background())
561+
field = ClientCtxRequestIDField(context.Background())
561562
assert.NotNil(t, field)
562-
field = CtxRequestIDField(context.Background())
563+
564+
ctx := WrapServerCtx(context.Background())
565+
assert.NotNil(t, ctx)
566+
ctx = WrapServerCtx(context.Background(), KV{Key: "foo", Val: "bar"})
567+
assert.NotNil(t, ctx)
568+
_ = ServerCtxRequestID(context.Background())
569+
field = ServerCtxRequestIDField(context.Background())
563570
assert.NotNil(t, field)
564571
}
565572

pkg/jwt/jwt.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ import (
77
"github.com/golang-jwt/jwt/v5"
88
)
99

10+
// ErrTokenExpired expired
11+
var ErrTokenExpired = jwt.ErrTokenExpired
12+
13+
var opt *options
14+
15+
// Init initialize jwt
16+
func Init(opts ...Option) {
17+
o := defaultOptions()
18+
o.apply(opts...)
19+
opt = o
20+
}
21+
1022
// Claims my custom claims
1123
type Claims struct {
1224
UID string `json:"uid"`

pkg/jwt/jwt_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package jwt
22

33
import (
4+
"errors"
45
"fmt"
56
"testing"
67
"time"
@@ -68,7 +69,7 @@ func TestParseToken(t *testing.T) {
6869
}
6970
time.Sleep(time.Second * 2)
7071
v, err = ParseToken(token)
71-
assert.Error(t, err)
72+
assert.True(t, errors.Is(err, ErrTokenExpired))
7273
}
7374

7475
func TestGenerateCustomToken(t *testing.T) {
@@ -132,5 +133,5 @@ func TestParseCustomToken(t *testing.T) {
132133
}
133134
time.Sleep(time.Second * 2)
134135
v, err = ParseCustomToken(token)
135-
assert.Error(t, err)
136+
assert.True(t, errors.Is(err, ErrTokenExpired))
136137
}

0 commit comments

Comments
 (0)