Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add JWT decode function #59

Merged
merged 6 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang/protobuf v1.5.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
Expand Down
22 changes: 22 additions & 0 deletions pkg/functions/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package functions

import (
"fmt"
)

const (
errorPrefix = "JMESPath function '%s': "
invalidArgumentTypeError = errorPrefix + "argument #%d is not of type %s"
genericError = errorPrefix + "%s"
argOutOfBoundsError = errorPrefix + "%d argument is out of bounds (%d)"
zeroDivisionError = errorPrefix + "Zero divisor passed"
nonIntModuloError = errorPrefix + "Non-integer argument(s) passed for modulo"
typeMismatchError = errorPrefix + "Types mismatch"
nonIntRoundError = errorPrefix + "Non-integer argument(s) passed for round off"
)

func formatError(format string, function string, values ...any) error {
args := []any{function}
args = append(args, values...)
return fmt.Errorf(format, args...)
}
82 changes: 82 additions & 0 deletions pkg/functions/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package functions

import (
"encoding/base64"
"fmt"
"reflect"

"github.com/golang-jwt/jwt"
"github.com/jmespath-community/go-jmespath/pkg/functions"
)

func GetFunctions() []functions.FunctionEntry {
return []functions.FunctionEntry{{
Name: "jwt_decode",
Arguments: []functions.ArgSpec{
{Types: []functions.JpType{functions.JpString}},
{Types: []functions.JpType{functions.JpString}},
},
Handler: jwt_decode,
}}
}

func jwt_decode(arguments []any) (any, error) {

// Validate argument
tokenString, err := validateArg(" ", arguments, 0, reflect.String)
if err != nil {
return nil, fmt.Errorf("invalidArgumentTypeError: %w", err)
}
tokenStringVal := tokenString.String()

secretkey, err := validateArg(" ", arguments, 1, reflect.String)
if err != nil {
return nil, fmt.Errorf("invalidArgumentTypeError: %w", err)
}

// Attempt to decode the base64 encoded secret key
decodedKey, err := base64.StdEncoding.DecodeString(secretkey.String())
if err != nil {
// If decoding fails, assume the secret key is not base64 encoded
decodedKey = []byte(secretkey.String())
}

token, err := jwt.Parse(tokenStringVal, func(token *jwt.Token) (interface{}, error) {
return decodedKey, nil
})
if err != nil {
return nil, fmt.Errorf("invalid JWT token: %w", err)
}

// Convert header and payload to regular maps
headerMap := make(map[string]interface{})
for k, v := range jwt.MapClaims(token.Header) {
headerMap[k] = v
}

payloadMap := make(map[string]interface{})
for k, v := range jwt.MapClaims(token.Claims.(jwt.MapClaims)) {
payloadMap[k] = v
}

result := map[string]any{
"header": headerMap,
"payload": payloadMap,
"sig": fmt.Sprintf("%x", token.Signature),
}
return result, nil
}

func validateArg(f string, arguments []any, index int, expectedType reflect.Kind) (reflect.Value, error) {
if index >= len(arguments) {
return reflect.Value{}, formatError(argOutOfBoundsError, f, index+1, len(arguments))
}
if arguments[index] == nil {
return reflect.Value{}, formatError(invalidArgumentTypeError, f, index+1, expectedType.String())
}
arg := reflect.ValueOf(arguments[index])
if arg.Type().Kind() != expectedType {
return reflect.Value{}, formatError(invalidArgumentTypeError, f, index+1, expectedType.String())
}
return arg, nil
}
82 changes: 82 additions & 0 deletions pkg/functions/functions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package functions

import (
"fmt"
"reflect"

"testing"
)

func Test_jwt_decode(t *testing.T) {

token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIyNDEwODE1MzksIm5iZiI6MTUxNDg1MTEzOSwicm9sZSI6Imd1ZXN0Iiwic3ViIjoiWVd4cFkyVT0ifQ.ja1bgvIt47393ba_WbSBm35NrUhdxM4mOVQN8iXz8lk"
secret := "c2VjcmV0"
type args struct {
arguments []any
}
tests := []struct {
name string
args args
want map[string]any
wantErr bool
}{
{
name: "Positive case , function returns what we expected",
args: args{[]any{token, secret}},
want: map[string]interface{}{
"header": map[string]interface{}{
"alg": "HS256",
"typ": "JWT",
},
"payload": map[string]interface{}{
"exp": 2.241081539e+09,
"nbf": 1.514851139e+09,
"role": "guest",
"sub": "YWxpY2U=",
},
"sig": fmt.Sprintf("%x", []byte{0x6a, 0x61, 0x31, 0x62, 0x67, 0x76, 0x49, 0x74, 0x34, 0x37, 0x33, 0x39, 0x33, 0x62, 0x61, 0x5f, 0x57, 0x62, 0x53, 0x42, 0x6d, 0x33, 0x35, 0x4e, 0x72, 0x55, 0x68, 0x64, 0x78, 0x4d, 0x34, 0x6d, 0x4f, 0x56, 0x51, 0x4e, 0x38, 0x69, 0x58, 0x7a, 0x38, 0x6c, 0x6b}),
},
wantErr: false,
},
// Negative test case: passing incorrect arguments (invalid token)
{
name: "negative case - invalid token",
args: args{[]any{"invalid_jwt_token", secret}},
want: map[string]interface{}{
"header": nil,
"payload": nil,
"sig": nil,
},
wantErr: true,
},
// Negative test case: passing incorrect arguments (invalid secret)
{
name: "negative case - invalid secret",
args: args{[]any{token, "invalid_secret"}},
want: map[string]interface{}{
"header": nil,
"payload": nil,
"sig": nil,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := jwt_decode(tt.args.arguments)
if (err != nil) != tt.wantErr {
t.Errorf("jwt_decode() error = %v, wantErr %v", err, tt.wantErr)
return
}

if !tt.wantErr {
gotValue := reflect.ValueOf(got)
wantValue := reflect.ValueOf(tt.want)

if !reflect.DeepEqual(gotValue.Interface(), wantValue.Interface()) {
t.Errorf("jwt_decode() = %v, want %v", gotValue.Interface(), wantValue.Interface())
}
}
})
}
}
43 changes: 43 additions & 0 deletions pkg/scratch/scratch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package scratch

import (
"context"
"fmt"

jpfunctions "github.com/jmespath-community/go-jmespath/pkg/functions"
"github.com/jmespath-community/go-jmespath/pkg/interpreter"
"github.com/jmespath-community/go-jmespath/pkg/parsing"
function "github.com/kyverno/kyverno-envoy-plugin/pkg/functions"
"github.com/kyverno/kyverno-json/pkg/engine/template"
)

var Caller = func() interpreter.FunctionCaller {
var funcs []jpfunctions.FunctionEntry
funcs = append(funcs, template.GetFunctions(context.Background())...)
funcs = append(funcs, function.GetFunctions()...)
return interpreter.NewFunctionCaller(funcs...)
}()

Expand All @@ -29,3 +32,43 @@ func GetUser(authorisation string) (string, error) {
}
return out.(string), nil
}

func GetFormJWTToken(arguments []any) (map[string]interface{}, error) {
vm := interpreter.NewInterpreter(nil, nil)
parser := parsing.NewParser()

// Construct JMESPath expression with arguments
arg1 := fmt.Sprintf("'%s'", arguments[0])
arg2 := fmt.Sprintf("'%s'", arguments[1])
statement := fmt.Sprintf("jwt_decode(%s, %s)", arg1, arg2)

compiled, err := parser.Parse(statement)
if err != nil {
return nil, fmt.Errorf("error on compiling , %w", err)
}
out, err := vm.Execute(compiled, arguments, interpreter.WithFunctionCaller(Caller))
if err != nil {
return nil, fmt.Errorf("error on execute , %w", err)
}
return out.(map[string]interface{}), nil
}

func GetFormJWTTokenPayload(arguments []any) (map[string]interface{}, error) {
vm := interpreter.NewInterpreter(nil, nil)
parser := parsing.NewParser()

// Construct JMESPath expression with arguments
arg1 := fmt.Sprintf("'%s'", arguments[0])
arg2 := fmt.Sprintf("'%s'", arguments[1])
statement := fmt.Sprintf("jwt_decode(%s, %s).payload", arg1, arg2)

compiled, err := parser.Parse(statement)
if err != nil {
return nil, fmt.Errorf("error on compiling , %w", err)
}
out, err := vm.Execute(compiled, arguments, interpreter.WithFunctionCaller(Caller))
if err != nil {
return nil, fmt.Errorf("error on execute , %w", err)
}
return out.(map[string]interface{}), nil
}
102 changes: 101 additions & 1 deletion pkg/scratch/scratch_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package scratch

import "testing"
import (
"reflect"
"testing"
)

func TestGetUser(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -28,3 +31,100 @@ func TestGetUser(t *testing.T) {
})
}
}

func TestGetFormJWTToken(t *testing.T) {

type args struct {
arguments []any
}

tests := []struct {
name string
args args
want map[string]interface{}
wantErr bool
}{
{
name: "positive case - passing correct arguement",
args: args{[]any{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIyNDEwODE1MzksIm5iZiI6MTUxNDg1MTEzOSwicm9sZSI6Imd1ZXN0Iiwic3ViIjoiWVd4cFkyVT0ifQ.ja1bgvIt47393ba_WbSBm35NrUhdxM4mOVQN8iXz8lk", "c2VjcmV0"}},
want: map[string]interface{}{
"header": map[string]interface{}{
"alg": "HS256",
"typ": "JWT",
},
"payload": map[string]interface{}{
"exp": 2.241081539e+09,
"nbf": 1.514851139e+09,
"role": "guest",
"sub": "YWxpY2U=",
},
"sig": "6a61316267764974343733393362615f576253426d33354e72556864784d346d4f56514e3869587a386c6b",
},
wantErr: false,
},
// Negative test case: passing incorrect arguments
{
name: "negative case - incorrect arguments",
args: args{[]any{"invalid_jwt_token", "c2VjcmV0"}},
want: nil,
// Expecting an error because of the invalid JWT token
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GetFormJWTToken(tt.args.arguments)
if (err != nil) != tt.wantErr {
t.Errorf("GetFormJWTToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetFormJWTToken() = %v, want %v", got, tt.want)
}
})
}
}

func TestGetFormJWTTokenPayload(t *testing.T) {
type args struct {
arguments []any
}
tests := []struct {
name string
args args
want map[string]interface{}
wantErr bool
}{
{
name: "Positive case",
args: args{[]any{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIyNDEwODE1MzksIm5iZiI6MTUxNDg1MTEzOSwicm9sZSI6Imd1ZXN0Iiwic3ViIjoiWVd4cFkyVT0ifQ.ja1bgvIt47393ba_WbSBm35NrUhdxM4mOVQN8iXz8lk", "c2VjcmV0"}},
Sanskarzz marked this conversation as resolved.
Show resolved Hide resolved
want: map[string]interface{}{
"exp": 2.241081539e+09,
"nbf": 1.514851139e+09,
"role": "guest",
"sub": "YWxpY2U=",
},
wantErr: false,
},
// Negative test case: passing incorrect arguments
{
name: "negative case - incorrect arguments",
args: args{[]any{"invalid_jwt_token", "c2VjcmV0"}},
want: nil,
// Expecting an error because of the invalid JWT token
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GetFormJWTTokenPayload(tt.args.arguments)
if (err != nil) != tt.wantErr {
t.Errorf("GetFormJWTTokenPayload() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetFormJWTTokenPayload() = %v, want %v", got, tt.want)
}
})
}
}