Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Llama 3.2 chat template #4

Merged
merged 4 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Sources/Ast.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,7 @@ struct KeywordArgumentExpression: Expression {
var key: Identifier
var value: any Expression
}

struct NullLiteral: Literal {
var value: Any? = nil
}
4 changes: 3 additions & 1 deletion Sources/Environment.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Environment {
args[0] is UndefinedValue
},
"equalto": { _ in
throw JinjaError.syntaxNotSupported
throw JinjaError.syntaxNotSupported("equalto")
},
]

Expand Down Expand Up @@ -165,6 +165,8 @@ class Environment {
}

return ObjectValue(value: object)
case is NullValue:
return NullValue()
default:
throw JinjaError.runtime("Cannot convert to runtime value: \(input) type:\(type(of: input))")
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ enum JinjaError: Error, LocalizedError {
case parser(String)
case runtime(String)
case todo(String)
case syntaxNotSupported
case syntaxNotSupported(String)

var errorDescription: String? {
switch self {
case .syntax(let message): return "Syntax error: \(message)"
case .parser(let message): return "Parser error: \(message)"
case .runtime(let message): return "Runtime error: \(message)"
case .todo(let message): return "Todo error: \(message)"
case .syntaxNotSupported: return "Syntax not supported"
case .syntaxNotSupported(let string): return "Syntax not supported: \(string)"
}
}
}
4 changes: 4 additions & 0 deletions Sources/Lexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum TokenType: String {

case numericLiteral = "NumericLiteral"
case booleanLiteral = "BooleanLiteral"
case nullLiteral = "NullLiteral"
case stringLiteral = "StringLiteral"
case identifier = "Identifier"
case equals = "Equals"
Expand Down Expand Up @@ -69,8 +70,10 @@ let keywords: [String: TokenType] = [
"and": .and,
"or": .or,
"not": .not,
// Literals
"true": .booleanLiteral,
"false": .booleanLiteral,
"none": .nullLiteral,
]

func isWord(char: String) -> Bool {
Expand Down Expand Up @@ -226,6 +229,7 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions()
case .identifier,
.numericLiteral,
.booleanLiteral,
.nullLiteral,
.stringLiteral,
.closeParen,
.closeSquareBracket:
Expand Down
19 changes: 11 additions & 8 deletions Sources/Parser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -187,21 +187,18 @@ func parse(tokens: [Token]) throws -> Program {
while typeof(.is) {
current += 1
let negate = typeof(.not)

if negate {
current += 1
}

var filter = try parsePrimaryExpression()

if let boolLiteralFlter = filter as? BoolLiteral {
filter = Identifier(value: String(boolLiteralFlter.value))
if let boolLiteralFilter = filter as? BoolLiteral {
filter = Identifier(value: String(boolLiteralFilter.value))
} else if filter is NullLiteral {
filter = Identifier(value: "none")
}

if let test = filter as? Identifier {
operand = TestExpression(operand: operand as! Expression, negate: negate, test: test)
}
else {
} else {
throw JinjaError.syntax("Expected identifier for the test")
}
}
Expand Down Expand Up @@ -373,6 +370,9 @@ func parse(tokens: [Token]) throws -> Program {
case .booleanLiteral:
current += 1
return BoolLiteral(value: token.value == "true")
case .nullLiteral:
current += 1
return NullLiteral()
case .identifier:
current += 1
return Identifier(value: token.value)
Expand Down Expand Up @@ -415,6 +415,9 @@ func parse(tokens: [Token]) throws -> Program {
current += 1

return ObjectLiteral(value: values)
case .nullLiteral:
current += 1
return NullLiteral()
default:
throw JinjaError.syntax("Unexpected token: \(token.type)")
}
Expand Down
52 changes: 40 additions & 12 deletions Sources/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ struct Interpreter {
throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))")
}
default:
throw JinjaError.syntaxNotSupported
throw JinjaError.syntaxNotSupported(String(describing: node.loopvar))
}

let evaluated = try self.evaluateBlock(statements: node.body, environment: scope)
Expand Down Expand Up @@ -353,21 +353,21 @@ struct Interpreter {
}
else if let left = left as? NumericValue, let right = right as? NumericValue {
switch node.operation.value {
case "+": throw JinjaError.syntaxNotSupported
case "-": throw JinjaError.syntaxNotSupported
case "*": throw JinjaError.syntaxNotSupported
case "/": throw JinjaError.syntaxNotSupported
case "+": throw JinjaError.syntaxNotSupported("+")
case "-": throw JinjaError.syntaxNotSupported("-")
case "*": throw JinjaError.syntaxNotSupported("*")
case "/": throw JinjaError.syntaxNotSupported("/")
case "%":
switch left.value {
case is Int:
return NumericValue(value: left.value as! Int % (right.value as! Int))
default:
throw JinjaError.runtime("Unknown value type:\(type(of: left.value))")
}
case "<": throw JinjaError.syntaxNotSupported
case ">": throw JinjaError.syntaxNotSupported
case ">=": throw JinjaError.syntaxNotSupported
case "<=": throw JinjaError.syntaxNotSupported
case "<": throw JinjaError.syntaxNotSupported("<")
case ">": throw JinjaError.syntaxNotSupported(">")
case ">=": throw JinjaError.syntaxNotSupported(">=")
case "<=": throw JinjaError.syntaxNotSupported("<=")
default:
throw JinjaError.runtime("Unknown operation type:\(node.operation.value)")
}
Expand All @@ -380,7 +380,7 @@ struct Interpreter {
}
}
else if right is ArrayValue {
throw JinjaError.syntaxNotSupported
throw JinjaError.syntaxNotSupported("right is ArrayValue")
}

if left is StringValue || right is StringValue {
Expand Down Expand Up @@ -428,7 +428,20 @@ struct Interpreter {
}

if left is StringValue, right is ObjectValue {
throw JinjaError.syntaxNotSupported
switch node.operation.value {
case "in":
if let leftString = (left as? StringValue)?.value,
let rightObject = right as? ObjectValue {
return BooleanValue(value: rightObject.value.keys.contains(leftString))
}
case "not in":
if let leftString = (left as? StringValue)?.value,
let rightObject = right as? ObjectValue {
return BooleanValue(value: !rightObject.value.keys.contains(leftString))
}
default:
throw JinjaError.runtime("Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue")
}
}

throw JinjaError.syntax(
Expand Down Expand Up @@ -664,6 +677,17 @@ struct Interpreter {
throw JinjaError.runtime("Unknown filter: \(node.filter)")
}

func evaluateTestExpression(node: TestExpression, environment: Environment) throws -> any RuntimeValue {
let operand = try self.evaluate(statement: node.operand, environment: environment)

if let testFunction = environment.tests[node.test.value] {
let result = try testFunction(operand)
return BooleanValue(value: node.negate ? !result : result)
} else {
throw JinjaError.runtime("Unknown test: \(node.test.value)")
}
}

func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue {
if let statement {
switch statement {
Expand Down Expand Up @@ -693,8 +717,12 @@ struct Interpreter {
return BooleanValue(value: statement.value)
case let statement as FilterExpression:
return try self.evaluateFilterExpression(node: statement, environment: environment)
case let statement as TestExpression:
return try self.evaluateTestExpression(node: statement, environment: environment)
case is NullLiteral:
return NullValue()
default:
throw JinjaError.runtime("Unknown node type: \(type(of:statement))")
throw JinjaError.runtime("Unknown node type: \(type(of:statement)), statement: \(String(describing: statement))")
}
}
else {
Expand Down
1 change: 1 addition & 0 deletions Sources/Template.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public struct Template {

try env.set(name: "false", value: false)
try env.set(name: "true", value: true)
try env.set(name: "none", value: NullValue())
try env.set(
name: "raise_exception",
value: { (args: String) throws in
Expand Down
44 changes: 43 additions & 1 deletion Tests/LexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ final class LexerTests: XCTestCase {
"UNDEFINED_VARIABLES": "{{ undefined_variable }}",
"UNDEFINED_ACCESS": "{{ object.undefined_attribute }}",

// Null
"NULL_VARIABLE": "{% if not null_val is defined %}{% set null_val = none %}{% endif %}{% if null_val is not none %}{{ 'fail' }}{% else %}{{ 'pass' }}{% endif %}",

// Ternary operator
"TERNARY_OPERATOR":
"|{{ 'a' if true else 'b' }}|{{ 'a' if false else 'b' }}|{{ 'a' if 1 + 1 == 2 else 'b' }}|{{ 'a' if 1 + 1 == 3 or 1 * 2 == 3 else 'b' }}|",
Expand Down Expand Up @@ -2032,7 +2035,7 @@ final class LexerTests: XCTestCase {
Token(value: "unknown", type: .stringLiteral),
Token(value: ")", type: .closeParen),
Token(value: "is", type: .is),
Token(value: "none", type: .identifier),
Token(value: "none", type: .nullLiteral),
Token(value: "}}", type: .closeExpression),
Token(value: "|", type: .text),
Token(value: "{{", type: .openExpression),
Expand Down Expand Up @@ -2177,6 +2180,45 @@ final class LexerTests: XCTestCase {
Token(value: "}}", type: .closeExpression),
],

// Null
"NULL_VARIABLE": [
Token(value: "{%", type: .openStatement),
Token(value: "if", type: .if),
Token(value: "not", type: .not),
Token(value: "null_val", type: .identifier),
Token(value: "is", type: .is),
Token(value: "defined", type: .identifier),
Token(value: "%}", type: .closeStatement),
Token(value: "{%", type: .openStatement),
Token(value: "set", type: .set),
Token(value: "null_val", type: .identifier),
Token(value: "=", type: .equals),
Token(value: "none", type: .nullLiteral),
Token(value: "%}", type: .closeStatement),
Token(value: "{%", type: .openStatement),
Token(value: "endif", type: .endIf),
Token(value: "%}", type: .closeStatement),
Token(value: "{%", type: .openStatement),
Token(value: "if", type: .if),
Token(value: "null_val", type: .identifier),
Token(value: "is", type: .is),
Token(value: "not", type: .not),
Token(value: "none", type: .nullLiteral),
Token(value: "%}", type: .closeStatement),
Token(value: "{{", type: .openExpression),
Token(value: "fail", type: .stringLiteral),
Token(value: "}}", type: .closeExpression),
Token(value: "{%", type: .openStatement),
Token(value: "else", type: .else),
Token(value: "%}", type: .closeStatement),
Token(value: "{{", type: .openExpression),
Token(value: "pass", type: .stringLiteral),
Token(value: "}}", type: .closeExpression),
Token(value: "{%", type: .openStatement),
Token(value: "endif", type: .endIf),
Token(value: "%}", type: .closeStatement),
],

// Ternary operator
"TERNARY_OPERATOR": [
Token(value: "|", type: .text),
Expand Down