diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9f11b75 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c6e3699 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2018 Lars Wilhelmsen + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..dcd4ec7 --- /dev/null +++ b/Makefile @@ -0,0 +1,2 @@ +test: + go test -v \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..728ec27 --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# Accumulo Access Expressions for Go + +## Introduction + +This package provides a simple way to parse and evaluate Accumulo access expressions in Go, based on the [AccessExpression specification](https://github.com/apache/accumulo-access/blob/main/SPECIFICATION.md). 
+ +## Usage + +```go +package main + +import ( + "fmt" + accumulo "github.com/larsw/accumulo-access-go/pkg" +) + +func main() { + res, err := accumulo.CheckAuthorization("A & B & (C | D)", "A,B,C") + if err != nil { + fmt.Printf("err: %v\n", err) + return + } + // Print the result + fmt.Printf("%v\n", res) +} +``` + +* Lars Wilhelmsen (https://github.com/larsw/) + +## License + +Licensed under the MIT License [LICENSE_MIT](LICENSE). + diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2723e8f --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/larsw/accumulo-access-go + +go 1.18 diff --git a/pkg/check_authorization.go b/pkg/check_authorization.go new file mode 100644 index 0000000..3fbcc02 --- /dev/null +++ b/pkg/check_authorization.go @@ -0,0 +1,50 @@ +// Package pkg Copyright 2024 Lars Wilhelmsen . All rights reserved. +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. +package pkg + +import "strings" + +// CheckAuthorization checks if the given authorizations are allowed to perform the given expression. +// Arguments: +// +// expression: The expression to check. +// authorizations: A comma-separated list of authorizations. +// +// Returns: +// +// True if the authorizations are allowed to perform the expression, false otherwise. +func CheckAuthorization(expression string, authorizations string) (bool, error) { + parser := NewParser(newLexer(expression)) + ast, err := parser.Parse() + if err != nil { + return false, err + } + authorizationMap := make(map[string]bool) + for _, authorization := range strings.Split(authorizations, ",") { + authorizationMap[authorization] = true + } + return ast.Evaluate(authorizationMap), nil +} + +// PrepareAuthorizationCheck returns a function that can be used to check if the given authorizations are allowed to perform the given expression. +// Arguments: +// +// authorizations: A comma-separated list of authorizations. 
+// +// Returns: +// +// A function that can be used to check if the given authorizations are allowed to perform the given expression. +func PrepareAuthorizationCheck(authorizations string) func(string) (bool, error) { + authorizationMap := make(map[string]bool) + for _, authorization := range strings.Split(authorizations, ",") { + authorizationMap[authorization] = true + } + return func(expression string) (bool, error) { + parser := NewParser(newLexer(expression)) + ast, err := parser.Parse() + if err != nil { + return false, err + } + return ast.Evaluate(authorizationMap), nil + } +} diff --git a/pkg/check_authorization_test.go b/pkg/check_authorization_test.go new file mode 100644 index 0000000..cdc6e58 --- /dev/null +++ b/pkg/check_authorization_test.go @@ -0,0 +1,47 @@ +package pkg + +import ( + "fmt" + "testing" +) + +type testCase struct { + expression string + authorizations string + expected bool +} + +func TestCheckAuthorization(t *testing.T) { + testCases := []testCase{ + {"label1", "label1", true}, + {"label1|label2", "label1", true}, + {"label1&label2", "label1", false}, + {"label1&label2", "label1,label2", true}, + {"label1&(label2 | label3)", "label1", false}, + {"label1&(label2 | label3)", "label1,label3", true}, + {"label1&(label2 | label3)", "label1,label2", true}, + {"(label2 | label3)", "label1", false}, + {"(label2 | label3)", "label2", true}, + {"(label2 & label3)", "label2", false}, + {"((label2 | label3))", "label2", true}, + {"((label2 & label3))", "label2", false}, + {"(((((label2 & label3)))))", "label2", false}, + {"(a & b) & (c & d)", "a,b,c,d", true}, + {"(a & b) & (c & d)", "a,b,c", false}, + {"(a & b) | (c & d)", "a,b,d", true}, + {"(a | b) & (c | d)", "a,d", true}, + {"\"a b c\"", "\"a b c\"", true}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("\"%v\" + \"%v\" -> %v", tc.expression, tc.authorizations, tc.expected), func(t *testing.T) { + result, err := CheckAuthorization(tc.expression, tc.authorizations) + if err != 
nil { + t.Fatal(err) + } + if result != tc.expected { + t.Fatalf("expected %v for %s with %s", tc.expected, tc.expression, tc.authorizations) + } + }) + } +} diff --git a/pkg/lexer.go b/pkg/lexer.go new file mode 100644 index 0000000..d5a625f --- /dev/null +++ b/pkg/lexer.go @@ -0,0 +1,160 @@ +// Package pkg Copyright 2024 Lars Wilhelmsen . All rights reserved. +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. +package pkg + +import ( + "fmt" + "unicode" +) + +// Token represents the different types of tokens. +type Token int + +// Define token constants. +const ( + AccessToken Token = iota + OpenParen + CloseParen + And + Or + None + Error +) + +//func (t Token) String() string { +// switch t { +// case AccessToken: +// return "AccessToken" +// case OpenParen: +// return "(" +// case CloseParen: +// return ")" +// case And: +// return "&" +// case Or: +// return "|" +// case None: +// return "None" +// default: +// return "Unknown" +// } +//} + +// Lexer represents a lexer for tokenizing strings. +type Lexer struct { + input string + pos int + readPos int + ch byte +} + +// newLexer creates a new Lexer instance. +func newLexer(input string) *Lexer { + l := &Lexer{input: input} + l.readChar() + return l +} + +// LexerError represents an error that occurred during lexing. 
+type LexerError struct { + Char byte + Position int +} + +func (e LexerError) Error() string { + return fmt.Sprintf("unexpected character '%c' at position %d", e.Char, e.Position) +} + +func (l *Lexer) readChar() { + if l.readPos >= len(l.input) { + l.ch = 0 + } else { + l.ch = l.input[l.readPos] + } + l.pos = l.readPos + l.readPos++ +} + +func (l *Lexer) peekChar() byte { + if l.readPos >= len(l.input) { + return 0 + } + return l.input[l.readPos] +} + +func (l *Lexer) nextToken() (Token, string, error) { + l.skipWhitespace() + + switch l.ch { + case '(': + l.readChar() + return OpenParen, "", nil + case ')': + l.readChar() + return CloseParen, "", nil + case '&': + l.readChar() + return And, "", nil + case '|': + l.readChar() + return Or, "", nil + case 0: + return None, "", nil + case '"': + strLiteral := l.readString() + return AccessToken, strLiteral, nil + default: + if isLegalTokenLetter(l.ch) { + val := l.readIdentifier() + return AccessToken, val, nil + } else { + return Error, "", LexerError{Char: l.ch, Position: l.pos} + } + } +} + +func (l *Lexer) readIdentifier() string { + startPos := l.pos + for isLegalTokenLetter(l.ch) { + l.readChar() + } + return l.input[startPos:l.pos] +} + +func (l *Lexer) readString() string { + startPos := l.pos + 1 // Skip initial double quote + for { + l.readChar() + if l.ch == '"' || l.ch == 0 { + break // End of string or end of input + } + + // Handle escape sequences + if l.ch == '\\' { + l.readChar() + if l.ch != '"' && l.ch != '\\' { + // Handle invalid escape sequence + return l.input[startPos : l.pos-1] // Return string up to invalid escape + } + } + } + str := l.input[startPos:l.pos] + l.readChar() // Skip closing double quote + return str +} + +func (l *Lexer) skipWhitespace() { + for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' { + l.readChar() + } +} + +func isLegalTokenLetter(ch byte) bool { + + return unicode.IsLetter(rune(ch)) || + unicode.IsDigit(rune(ch)) || + ch == '_' || + ch == '-' || + ch 
== '.' || + ch == ':' +} diff --git a/pkg/lexer_test.go b/pkg/lexer_test.go new file mode 100644 index 0000000..8c00106 --- /dev/null +++ b/pkg/lexer_test.go @@ -0,0 +1,48 @@ +package pkg + +import "testing" + +func TestLexer(t *testing.T) { + lexer := newLexer("(a & b) | c") + token, _, _ := lexer.nextToken() + if token != OpenParen { + t.Fatal("expected OpenParen") + } + token, val, _ := lexer.nextToken() + if token != AccessToken { + t.Fatal("expected AccessToken") + } + if val != "a" { + t.Fatal("expected a") + } + token, _, _ = lexer.nextToken() + if token != And { + t.Fatal("expected And") + } + token, val, _ = lexer.nextToken() + if token != AccessToken { + t.Fatal("expected AccessToken") + } + if val != "b" { + t.Fatal("expected b") + } + token, _, _ = lexer.nextToken() + if token != CloseParen { + t.Fatal("expected CloseParen") + } + token, _, _ = lexer.nextToken() + if token != Or { + t.Fatal("expected Or") + } + token, val, _ = lexer.nextToken() + if token != AccessToken { + t.Fatal("expected AccessToken") + } + if val != "c" { + t.Fatal("expected c") + } + token, _, _ = lexer.nextToken() + if token != None { + t.Fatal("expected end of input") + } +} diff --git a/pkg/parser.go b/pkg/parser.go new file mode 100644 index 0000000..9766986 --- /dev/null +++ b/pkg/parser.go @@ -0,0 +1,182 @@ +// Package pkg Copyright 2024 Lars Wilhelmsen . All rights reserved. +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. +package pkg + +import "fmt" + +// Assuming Token and Lexer are already defined as in the previous lexer translation + +// ParserError represents errors that can occur during parsing. +type ParserError struct { + Message string +} + +func (e ParserError) Error() string { + return e.Message +} + +// AuthorizationExpression is an interface for different expression types. 
type AuthorizationExpression interface {
	Evaluate(authorizations map[string]bool) bool
}

// AndExpression is the conjunction of its child expressions.
type AndExpression struct {
	Nodes []AuthorizationExpression
}

// Evaluate reports whether every child expression is satisfied by the
// given authorization set. An AndExpression with no children is true.
func (e AndExpression) Evaluate(authorizations map[string]bool) bool {
	for i := range e.Nodes {
		if !e.Nodes[i].Evaluate(authorizations) {
			return false
		}
	}
	return true
}

// OrExpression is the disjunction of its child expressions.
type OrExpression struct {
	Nodes []AuthorizationExpression
}

// Evaluate reports whether at least one child expression is satisfied by
// the given authorization set. An OrExpression with no children is false.
func (e OrExpression) Evaluate(authorizations map[string]bool) bool {
	satisfied := false
	for _, child := range e.Nodes {
		if child.Evaluate(authorizations) {
			satisfied = true
			break
		}
	}
	return satisfied
}

// AccessTokenExpression is a leaf node naming a single authorization label.
type AccessTokenExpression struct {
	Token string
}

// Evaluate reports whether the token is present in the authorization set.
func (e AccessTokenExpression) Evaluate(authorizations map[string]bool) bool {
	return authorizations[e.Token]
}

// Scope is used during parsing to build up expressions.
+type Scope struct { + Nodes []AuthorizationExpression + Labels []AccessTokenExpression + Operator Token +} + +func newScope() *Scope { + return &Scope{ + Nodes: make([]AuthorizationExpression, 0), + Labels: make([]AccessTokenExpression, 0), + Operator: None, // Assuming None is a defined Token value + } +} + +func (s *Scope) addNode(node AuthorizationExpression) { + s.Nodes = append(s.Nodes, node) +} + +func (s *Scope) addLabel(label string) { + s.Labels = append(s.Labels, AccessTokenExpression{Token: label}) +} + +func (s *Scope) setOperator(operator Token) error { + if s.Operator != None { + return ParserError{Message: "unexpected operator"} + } + s.Operator = operator + return nil +} + +func (s *Scope) Build() (AuthorizationExpression, error) { + if len(s.Labels) == 1 && len(s.Nodes) == 0 { + return s.Labels[0], nil + } + + if len(s.Nodes) == 1 && len(s.Labels) == 0 { + return s.Nodes[0], nil + } + + if s.Operator == None { + return nil, ParserError{Message: "missing operator"} + } + // combine nodes and labels into one slice + combined := make([]AuthorizationExpression, 0, len(s.Nodes)+len(s.Labels)) + for _, node := range s.Nodes { + combined = append(combined, node) + } + for _, label := range s.Labels { + combined = append(combined, label) + } + if s.Operator == And { + return AndExpression{Nodes: combined}, nil + } + if s.Operator == Or { + return OrExpression{Nodes: combined}, nil + } + return nil, ParserError{Message: fmt.Sprintf("unexpected operator: %v", s.Operator)} +} + +// Parser is used to parse an expression and return an AuthorizationExpression tree. 
+type Parser struct { + Lexer *Lexer +} + +func NewParser(lexer *Lexer) *Parser { + return &Parser{Lexer: lexer} +} + +func (p *Parser) Parse() (AuthorizationExpression, error) { + scopeStack := []*Scope{newScope()} + + for { + tok, val, err := p.Lexer.nextToken() + if err != nil { + return nil, ParserError{Message: fmt.Sprintf("Lexer error: %v", err)} + } + + if tok == None { // Assuming 0 represents the end of input + break + } + + currentScope := scopeStack[len(scopeStack)-1] + + switch tok { + case AccessToken: + currentScope.addLabel(val) + case OpenParen: + newScope := newScope() + scopeStack = append(scopeStack, newScope) + case And, Or: + if err := currentScope.setOperator(tok); err != nil { + return nil, err + } + case CloseParen: + if len(scopeStack) == 1 { + return nil, ParserError{Message: "unmatched closing parenthesis"} + } + finishedScope := scopeStack[len(scopeStack)-1] + scopeStack = scopeStack[:len(scopeStack)-1] + expression, err := finishedScope.Build() + if err != nil { + return nil, err + } + currentScope = scopeStack[len(scopeStack)-1] + currentScope.addNode(expression) + default: + return nil, ParserError{Message: fmt.Sprintf("unexpected token: %v", tok)} + } + } + + if len(scopeStack) != 1 { + return nil, ParserError{Message: "mismatched parentheses"} + } + + return scopeStack[0].Build() +}