diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..cb9509a --- /dev/null +++ b/Package.resolved @@ -0,0 +1,14 @@ +{ + "pins" : [ + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", + "version" : "1.1.4" + } + } + ], + "version" : 2 +} diff --git a/Package.swift b/Package.swift index e7074aa..80ada54 100644 --- a/Package.swift +++ b/Package.swift @@ -13,17 +13,25 @@ let package = Package( targets: ["Jinja"] ) ], + dependencies: [ + .package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4") + ], targets: [ // Targets are the basic building blocks of a package, defining a module or a test suite. // Targets can depend on other targets in this package and products from dependencies. .target( name: "Jinja", + dependencies: [ + .product(name: "OrderedCollections", package: "swift-collections") + ], path: "Sources", swiftSettings: [.enableUpcomingFeature("BareSlashRegexLiterals")] ), .testTarget( name: "JinjaTests", - dependencies: ["Jinja"], + dependencies: [ + "Jinja" + ], path: "Tests", swiftSettings: [.enableUpcomingFeature("BareSlashRegexLiterals")] ), diff --git a/Sources/Ast.swift b/Sources/Ast.swift index 7460284..45357e7 100644 --- a/Sources/Ast.swift +++ b/Sources/Ast.swift @@ -6,6 +6,7 @@ // import Foundation +import OrderedCollections protocol Statement {} @@ -41,7 +42,7 @@ struct TupleLiteral: Literal { } struct ObjectLiteral: Literal { - var value: [(Expression, Expression)] + var value: OrderedDictionary } struct Set: Statement { @@ -49,7 +50,7 @@ struct Set: Statement { var value: Expression } -struct If: Statement { +struct If: Statement, Expression { var test: Expression var body: [Statement] var alternate: [Statement] @@ -59,14 +60,14 @@ struct Identifier: Expression { var value: String } -protocol Loopvar {} -extension Identifier: Loopvar {} 
-extension TupleLiteral: Loopvar {} +typealias Loopvar = Expression struct For: Statement { var loopvar: Loopvar var iterable: Expression var body: [Statement] + var defaultBlock: [Statement] + var ifCondition: Expression? } struct MemberExpression: Expression { @@ -92,7 +93,11 @@ extension CallExpression: Filter {} struct FilterExpression: Expression { var operand: Expression - var filter: Filter + var filter: Identifier + var args: [Expression] + var kwargs: [KeywordArgumentExpression] + var dyn_args: Expression? + var dyn_kwargs: Expression? } struct TestExpression: Expression { @@ -124,3 +129,23 @@ struct KeywordArgumentExpression: Expression { struct NullLiteral: Literal { var value: Any? = nil } + +struct SelectExpression: Expression { + var iterable: Expression + var test: Expression +} + +struct Macro: Statement { + var name: Identifier + var args: [Expression] + var body: [Statement] +} + +struct KeywordArgumentsValue: RuntimeValue { + var value: [String: any RuntimeValue] + var builtins: [String: any RuntimeValue] = [:] + + func bool() -> Bool { + !value.isEmpty + } +} diff --git a/Sources/Environment.swift b/Sources/Environment.swift index c845068..10819b1 100644 --- a/Sources/Environment.swift +++ b/Sources/Environment.swift @@ -6,49 +6,71 @@ // import Foundation +import OrderedCollections class Environment { var parent: Environment? var variables: [String: any RuntimeValue] = [ "namespace": FunctionValue(value: { args, _ in - if args.count == 0 { + if args.isEmpty { return ObjectValue(value: [:]) } - - if args.count != 1 || !(args[0] is ObjectValue) { + guard args.count == 1, let objectArg = args[0] as? ObjectValue else { throw JinjaError.runtime("`namespace` expects either zero arguments or a single object argument") } - - return args[0] + return objectArg }) ] - var tests: [String: (any RuntimeValue...) 
throws -> Bool] = [ - "boolean": { - args in - args[0] is BooleanValue - }, - - "callable": { - args in - args[0] is FunctionValue - }, - - "odd": { - args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 != 0 + lazy var tests: [String: (any RuntimeValue...) throws -> Bool] = [ + "odd": { args in + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 != 0 } else { - throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of: args.first))") } }, "even": { args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 == 0 + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 == 0 } else { - throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of: args.first))") + } + }, + "divisibleby": { args in + guard let value = args[0] as? NumericValue, + let num = args[1] as? NumericValue, + let intValue = value.value as? Int, + let intNum = num.value as? Int + else { + throw JinjaError.runtime("divisibleby test requires two integers") + } + return intValue % intNum == 0 + }, + "defined": { args in + return !(args[0] is UndefinedValue) + }, + "undefined": { args in + return args[0] is UndefinedValue + }, + "filter": { [weak self] (args: any RuntimeValue...) throws -> Bool in + guard let name = args[0] as? StringValue else { + throw JinjaError.runtime("filter test requires a string") + } + return self?.filters.keys.contains(name.value) ?? false + }, + "test": { [weak self] (args: any RuntimeValue...) throws -> Bool in + guard let name = args[0] as? StringValue else { + throw JinjaError.runtime("test test requires a string") } + return self?.tests.keys.contains(name.value) ?? 
false + }, + "none": { args in + return args[0] is NullValue + }, + "boolean": { args in + return args[0] is BooleanValue }, "false": { args in if let arg = args[0] as? BooleanValue { @@ -62,24 +84,22 @@ class Environment { } return false }, - "number": { args in - args[0] is NumericValue - }, "integer": { args in if let arg = args[0] as? NumericValue { return arg.value is Int } - return false }, - "iterable": { args in - args[0] is ArrayValue || args[0] is StringValue + "float": { args in + if let numericValue = args[0] as? NumericValue { + return numericValue.value is Double + } + return false }, "lower": { args in if let arg = args[0] as? StringValue { return arg.value == arg.value.lowercased() } - return false }, "upper": { args in @@ -88,17 +108,991 @@ class Environment { } return false }, - "none": { args in - args[0] is NullValue + "string": { args in + return args[0] is StringValue }, - "defined": { args in - !(args[0] is UndefinedValue) + "mapping": { args in + return args[0] is ObjectValue }, - "undefined": { args in - args[0] is UndefinedValue + "number": { args in + return args[0] is NumericValue + }, + "sequence": { args in + let value = args[0] + if value is ArrayValue || value is StringValue { + return true + } + return false + }, + "iterable": { args in + return args[0] is ArrayValue || args[0] is StringValue || args[0] is ObjectValue + }, + "callable": { args in + return args[0] is FunctionValue + }, + // TODO: Implement "sameas" + // TODO: Implement "escaped" + "in": { args in + guard let seq = args[1] as? 
ArrayValue else { + throw JinjaError.runtime("in test requires a sequence") + } + return seq.value.contains { item in + self.doEqualTo([args[0], item]) + } + }, + "==": { args in self.doEqualTo(args) }, + "eq": { args in self.doEqualTo(args) }, + "equalto": { args in self.doEqualTo(args) }, + "!=": { args in + guard args.count == 2 else { + throw JinjaError.runtime("!= test requires two arguments") + } + return !self.doEqualTo(args) + }, + "ne": { args in + guard args.count == 2 else { + throw JinjaError.runtime("ne test requires two arguments") + } + return !self.doEqualTo(args) + }, + ">": { args in + guard args.count == 2 else { + throw JinjaError.runtime("> test requires two arguments") + } + return try self.doGreaterThan(args) + }, + "gt": { args in + guard args.count == 2 else { + throw JinjaError.runtime("gt test requires two arguments") + } + return try self.doGreaterThan(args) + }, + "greaterthan": { args in + guard args.count == 2 else { + throw JinjaError.runtime("greaterthan test requires two arguments") + } + return try self.doGreaterThan(args) + }, + ">=": { args in + guard args.count == 2 else { + throw JinjaError.runtime(">= test requires two arguments") + } + return try self.doGreaterThanOrEqual(args) + }, + "ge": { args in + guard args.count == 2 else { + throw JinjaError.runtime("ge test requires two arguments") + } + return try self.doGreaterThanOrEqual(args) + }, + "<": { args in + guard args.count == 2 else { + throw JinjaError.runtime("< test requires two arguments") + } + return try self.doLessThan(args) + }, + "lt": { args in + guard args.count == 2 else { + throw JinjaError.runtime("lt test requires two arguments") + } + return try self.doLessThan(args) + }, + "lessthan": { args in + guard args.count == 2 else { + throw JinjaError.runtime("lessthan test requires two arguments") + } + return try self.doLessThan(args) + }, + "<=": { args in + guard args.count == 2 else { + throw JinjaError.runtime("<= test requires two arguments") + } + 
return try self.doLessThanOrEqual(args) }, - "equalto": { _ in - throw JinjaError.syntaxNotSupported("equalto") + "le": { args in + guard args.count == 2 else { + throw JinjaError.runtime("le test requires two arguments") + } + return try self.doLessThanOrEqual(args) + }, + ] + + lazy var filters: [String: ([any RuntimeValue], Environment) throws -> any RuntimeValue] = [ + "abs": { args, env in + guard let numericValue = args[0] as? NumericValue else { + throw JinjaError.runtime("abs filter requires a number") + } + if let intValue = numericValue.value as? Int { + return NumericValue(value: abs(intValue)) + } else if let doubleValue = numericValue.value as? Double { + return NumericValue(value: abs(doubleValue)) + } else { + throw JinjaError.runtime("Unsupported numeric type for abs filter") + } + }, + "attr": { args, env in + guard let name = args[1] as? StringValue else { + throw JinjaError.runtime("attr filter requires an object and attribute name") + } + let obj = args[0] + if let objValue = obj as? ObjectValue { + return objValue.value[name.value] ?? UndefinedValue() + } + return UndefinedValue() + }, + "batch": { args, env in + guard let arrayValue = args[0] as? ArrayValue, + let linecount = args[1] as? NumericValue, + let count = linecount.value as? Int + else { + throw JinjaError.runtime("batch filter requires an array and line count") + } + let fillWith = args.count > 2 ? args[2] : nil + var result: [[any RuntimeValue]] = [] + var temp: [any RuntimeValue] = [] + for item in arrayValue.value { + if temp.count == count { + result.append(temp) + temp = [] + } + temp.append(item) + } + if !temp.isEmpty { + if let fill = fillWith, temp.count < count { + temp += Array(repeating: fill, count: count - temp.count) + } + result.append(temp) + } + return ArrayValue(value: result.map { ArrayValue(value: $0) }) + }, + "capitalize": { args, env in + guard let stringValue = args[0] as? 
StringValue else { + throw JinjaError.runtime("capitalize filter requires a string") + } + return StringValue(value: stringValue.value.capitalized) + }, + "center": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("center filter requires a string") + } + let width = (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 80 + let padding = max(0, width - stringValue.value.count) + let leftPadding = padding / 2 + let rightPadding = padding - leftPadding + return StringValue( + value: String(repeating: " ", count: leftPadding) + stringValue.value + + String(repeating: " ", count: rightPadding) + ) + }, + "count": { args, env in + let value = args[0] + if let arrayValue = value as? ArrayValue { + return NumericValue(value: arrayValue.value.count) + } else if let stringValue = value as? StringValue { + return NumericValue(value: stringValue.value.count) + } else if let objectValue = value as? ObjectValue { + return NumericValue(value: objectValue.value.count) + } + throw JinjaError.runtime("Cannot count value of type \(type(of: value))") + }, + "d": { args, env in try self.doDefault(args, env) }, + "default": { args, env in try self.doDefault(args, env) }, + "dictsort": { args, env in + guard let dict = args[0] as? ObjectValue else { + throw JinjaError.runtime("dictsort filter requires a dictionary") + } + let caseSensitive = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let by = args.count > 2 ? (args[2] as? StringValue)?.value ?? "key" : "key" + let reverse = args.count > 3 ? (args[3] as? BooleanValue)?.value ?? false : false + let sortedDict = try dict.storage.sorted { (item1, item2) in + let a: Any, b: Any + if by == "key" { + a = item1.key + b = item2.key + } else if by == "value" { + a = item1.value + b = item2.value + } else { + throw JinjaError.runtime("Invalid 'by' argument for dictsort filter") + } + let result: Bool + if let aString = a as? 
String, let bString = b as? String { + result = caseSensitive ? aString < bString : aString.lowercased() < bString.lowercased() + } else if let aNumeric = a as? NumericValue, let bNumeric = b as? NumericValue { + if let aInt = aNumeric.value as? Int, let bInt = bNumeric.value as? Int { + result = aInt < bInt + } else if let aDouble = aNumeric.value as? Double, let bDouble = bNumeric.value as? Double { + result = aDouble < bDouble + } else { + throw JinjaError.runtime("Cannot compare values in dictsort filter") + } + } else { + throw JinjaError.runtime("Cannot compare values in dictsort filter") + } + return reverse ? !result : result + } + return ArrayValue( + value: sortedDict.map { (key, value) in + return ArrayValue(value: [StringValue(value: key), value]) + } + ) + }, + "e": { args, env in try self.doEscape(args, env) }, + "escape": { args, env in try self.doEscape(args, env) }, + "filesizeformat": { args, env in + guard let value = args[0] as? NumericValue, let size = value.value as? Double else { + throw JinjaError.runtime("filesizeformat filter requires a numeric value") + } + let binary = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let units = + binary + ? [" KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"] + : [" kB", " MB", " GB", " TB", " PB", " EB", " ZB", " YB"] + let base: Double = binary ? 1024.0 : 1000.0 + if size < 1.0 { + return StringValue(value: "\(Int(size)) Byte") // Fixed: Wrap String in StringValue + } + let i = Int(floor(log(size) / log(base))) + let unit = units[min(i, units.count - 1)] + let num = size / pow(base, Double(i)) + return StringValue(value: String(format: "%.1f%@", num, unit)) // Fixed: Wrap String in StringValue + }, + "first": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("first filter requires an array") + } + return arrayValue.value.first ?? UndefinedValue() + }, + "float": { args, env in + guard let value = args[0] as? 
NumericValue else { + return NumericValue(value: 0.0) + } + if let doubleValue = value.value as? Double { + return NumericValue(value: doubleValue) + } else if let intValue = value.value as? Int { + return NumericValue(value: Double(intValue)) + } else { + return NumericValue(value: 0.0) + } + }, + "forceescape": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("forceescape filter requires a string") + } + return StringValue( + value: stringValue.value.replacingOccurrences(of: "&", with: "&") + .replacingOccurrences(of: "<", with: "<") + .replacingOccurrences(of: ">", with: ">") + .replacingOccurrences(of: "\"", with: """) + .replacingOccurrences(of: "'", with: "'") + ) + }, + "format": { args, env in + guard let formatString = args[0] as? StringValue else { + throw JinjaError.runtime("format filter requires a format string") + } + let values = args.dropFirst().map { $0 as? StringValue } + let formattedString = String(format: formatString.value, arguments: values.map { $0?.value ?? "" }) + return StringValue(value: formattedString) + }, + "groupby": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("groupby filter requires an array") + } + guard let attribute = args[1] as? StringValue else { + throw JinjaError.runtime("groupby filter requires an attribute name") + } + let caseSensitive = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + var groups: [String: [any RuntimeValue]] = [:] + for item in arrayValue.value { + guard let objectValue = item as? ObjectValue, + let groupKey = objectValue.value[attribute.value] as? StringValue + else { + continue + } + let key = caseSensitive ? 
groupKey.value : groupKey.value.lowercased() + groups[key, default: []].append(item) + } + return ArrayValue( + value: groups.map { (key, value) in + return ObjectValue(value: ["grouper": StringValue(value: key), "list": ArrayValue(value: value)]) + } + ) + }, + "indent": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("indent filter requires a string") + } + let width = (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 4 + let indent = String(repeating: " ", count: width) + let first = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + let blank = args.count > 3 ? (args[3] as? BooleanValue)?.value ?? false : false + var lines = stringValue.value.split(separator: "\n", omittingEmptySubsequences: false) + for i in lines.indices { + if (first || i > 0) && (blank || !lines[i].isEmpty) { + lines[i] = Substring(indent + lines[i]) + } + } + return StringValue(value: lines.joined(separator: "\n")) + }, + "int": { args, env in + guard let value = args[0] as? NumericValue else { + return NumericValue(value: 0) + } + if let intValue = value.value as? Int { + return NumericValue(value: intValue) + } else if let doubleValue = value.value as? Double { + return NumericValue(value: Int(doubleValue)) + } else { + return NumericValue(value: 0) + } + }, + "items": { args, env in + guard let iterable = args.first else { + throw JinjaError.runtime("items filter requires an iterable") + } + if let arrayValue = iterable as? ArrayValue { + return ArrayValue( + value: arrayValue.value.map { + ArrayValue(value: [$0]) + } + ) + } else if let objectValue = iterable as? ObjectValue { + return ArrayValue( + value: objectValue.storage.map { (key, value) in + ArrayValue(value: [StringValue(value: key), value]) + } + ) + } else { + throw JinjaError.runtime("items filter can only be applied to arrays and objects") + } + }, + "join": { args, env in + guard let arrayValue = args[0] as? 
ArrayValue else { + throw JinjaError.runtime("join filter requires an array") + } + let separator = (args.count > 1 && args[1] is StringValue) ? (args[1] as! StringValue).value : "" + let stringValues = arrayValue.value.compactMap { $0 as? StringValue } + return StringValue(value: stringValues.map { $0.value }.joined(separator: separator)) + }, + "last": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("last filter requires an array") + } + return arrayValue.value.last ?? UndefinedValue() + }, + "length": { args, env in + guard let arg = args.first else { + throw JinjaError.runtime("length filter expects one argument") + } + + if let arrayValue = arg as? ArrayValue { + return NumericValue(value: arrayValue.value.count) + } else if let stringValue = arg as? StringValue { + return NumericValue(value: stringValue.value.count) + } else if let objectValue = arg as? ObjectValue { + return NumericValue(value: objectValue.value.count) + } else { + throw JinjaError.runtime("Cannot get length of type: \(type(of: arg))") + } + }, + "list": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("list filter requires an array") + } + return arrayValue + }, + "lower": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("lower filter requires a string") + } + return StringValue(value: stringValue.value.lowercased()) + }, + "map": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("map filter requires an array") + } + // If no attribute is provided, return the array as is + if args.count == 1 { + return arrayValue + } + // Handle attribute mapping + if let attribute = args[1] as? StringValue { + let values = arrayValue.value.compactMap { item -> (any RuntimeValue)? in + if let objectValue = item as? 
ObjectValue { + return objectValue.value[attribute.value] + } + return nil + } + return ArrayValue(value: values) + } + // Handle function mapping + if let function = args[1] as? FunctionValue { + let values = try arrayValue.value.map { item in + try function.value([item], env) + } + return ArrayValue(value: values) + } + throw JinjaError.runtime("map filter requires either an attribute name or a function") + }, + "min": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("min filter requires an array") + } + if arrayValue.value.isEmpty { + return UndefinedValue() + } + if let numericValues = arrayValue.value as? [NumericValue] { + let numbers = numericValues.compactMap { $0.value as? Double } + if numbers.count != numericValues.count { + throw JinjaError.runtime("min filter requires all array elements to be numbers") + } + return NumericValue(value: numbers.min() ?? 0) + } else if let stringValues = arrayValue.value as? [StringValue] { + return StringValue(value: stringValues.map { $0.value }.min() ?? "") + } else { + throw JinjaError.runtime("min filter requires an array of numbers or strings") + } + }, + "max": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("max filter requires an array") + } + if arrayValue.value.isEmpty { + return UndefinedValue() + } + if let numericValues = arrayValue.value as? [NumericValue] { + let numbers = numericValues.compactMap { $0.value as? Double } + if numbers.count != numericValues.count { + throw JinjaError.runtime("max filter requires all array elements to be numbers") + } + return NumericValue(value: numbers.max() ?? 0) + } else if let stringValues = arrayValue.value as? [StringValue] { + return StringValue(value: stringValues.map { $0.value }.max() ?? 
"") + } else { + throw JinjaError.runtime("max filter requires an array of numbers or strings") + } + }, + "pprint": { args, env in + guard let value = args.first else { + throw JinjaError.runtime("pprint filter expects one argument") + } + return StringValue(value: String(describing: value)) + }, + "random": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("random filter requires an array") + } + if let randomIndex = arrayValue.value.indices.randomElement() { + return arrayValue.value[randomIndex] + } else { + return UndefinedValue() + } + }, + "reject": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("reject filter requires an array") + } + guard let testName = args[1] as? StringValue else { + throw JinjaError.runtime("reject filter requires a test name") + } + guard let test = env.tests[testName.value] else { + throw JinjaError.runtime("Unknown test '\(testName.value)'") + } + var result: [any RuntimeValue] = [] + for item in arrayValue.value { + // Correctly pass arguments to the test function + if try !test(item) { // Negate the result for 'reject' + result.append(item) + } + } + return ArrayValue(value: result) + }, + "rejectattr": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("rejectattr filter requires an array") + } + guard let attribute = args[1] as? StringValue else { + throw JinjaError.runtime("rejectattr filter requires an attribute name") + } + var result: [any RuntimeValue] = [] + for item in arrayValue.value { + guard let objectValue = item as? ObjectValue, + let attrValue = objectValue.value[attribute.value] + else { + continue + } + if args.count == 2 { + if !attrValue.bool() { + result.append(item) + } + } else { + let testName = (args[2] as? StringValue)?.value ?? 
"defined" + guard let test = env.tests[testName] else { + throw JinjaError.runtime("Unknown test '\(testName)'") + } + // Correctly pass arguments to the test function + if try !test(attrValue) { // Note the negation (!) for rejectattr + result.append(item) + } + } + } + return ArrayValue(value: result) + }, + "replace": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("replace filter requires a string") + } + guard let oldValue = args[1] as? StringValue else { + throw JinjaError.runtime("replace filter requires an old value string") + } + guard let newValue = args[2] as? StringValue else { + throw JinjaError.runtime("replace filter requires a new value string") + } + let count = (args.count > 3 && args[3] is NumericValue) ? (args[3] as! NumericValue).value as! Int : Int.max + return StringValue( + value: stringValue.value.replacingOccurrences( + of: oldValue.value, + with: newValue.value, + options: [], + range: nil + ) + ) + }, + "reverse": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("reverse filter requires an array") + } + return ArrayValue(value: arrayValue.value.reversed()) + }, + "round": { args, env in + guard let value = args[0] as? NumericValue, let number = value.value as? Double else { + throw JinjaError.runtime("round filter requires a number") + } + let precision = (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 0 + let method = (args.count > 2 && args[2] is StringValue) ? (args[2] as! 
StringValue).value : "common" + let factor = pow(10, Double(precision)) + let roundedNumber: Double + if method == "common" { + roundedNumber = round(number * factor) / factor + } else if method == "ceil" { + roundedNumber = ceil(number * factor) / factor + } else if method == "floor" { + roundedNumber = floor(number * factor) / factor + } else { + throw JinjaError.runtime("Invalid method for round filter") + } + return NumericValue(value: roundedNumber) + }, + "safe": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("safe filter requires a string") + } + return stringValue // In this minimal example, we don't handle marking strings as safe + }, + "select": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("select filter requires an array") + } + guard let testName = args[1] as? StringValue else { + throw JinjaError.runtime("select filter requires a test name") + } + guard let test = env.tests[testName.value] else { + throw JinjaError.runtime("Unknown test '\(testName.value)'") + } + var result: [any RuntimeValue] = [] + for item in arrayValue.value { + if try test(item) { + result.append(item) + } + } + return ArrayValue(value: result) + }, + "selectattr": { args, env in + guard let array = args[0] as? ArrayValue else { + throw JinjaError.runtime("selectattr filter requires an array") + } + guard let attribute = args[1] as? StringValue else { + throw JinjaError.runtime("selectattr filter requires an attribute name") + } + guard args.count > 2 else { + throw JinjaError.runtime("selectattr filter requires a test") + } + var result: [any RuntimeValue] = [] + for item in array.value { + if let obj = item as? 
ObjectValue, + let attrValue = obj.value[attribute.value] + { + if args[2] is StringValue && args[2].bool() { + // Simple boolean test + if attrValue.bool() { + result.append(item) + } + } else if args.count > 3 { + // Test with comparison value + if let testName = (args[2] as? StringValue)?.value { + let testValue = args[3] + if testName == "equalto" { + // Handle equality test + if let strAttr = attrValue as? StringValue, + let strTest = testValue as? StringValue + { + if strAttr.value == strTest.value { + result.append(item) + } + } + } + } + } + } + } + return ArrayValue(value: result) + }, + "slice": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("slice filter requires an array") + } + guard let slices = args[1] as? NumericValue, let numSlices = slices.value as? Int else { + throw JinjaError.runtime("slice filter requires a number of slices") + } + let fillWith = args.count > 2 ? args[2] : nil + let itemsPerSlice = arrayValue.value.count / numSlices + let slicesWithExtra = arrayValue.value.count % numSlices + var result: [[any RuntimeValue]] = [] + var startIndex = 0 + for i in 0 ..< numSlices { + let count = itemsPerSlice + (i < slicesWithExtra ? 1 : 0) + var slice = Array(arrayValue.value[startIndex ..< startIndex + count]) + if let fillWithValue = fillWith, i >= slicesWithExtra, slice.count < itemsPerSlice { + slice.append(fillWithValue) + } + result.append(slice) + startIndex += count + } + return ArrayValue(value: result.map { ArrayValue(value: $0) }) + }, + "sort": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("sort filter requires an array") + } + let reverse = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let caseSensitive = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + let attribute = args.count > 3 ? (args[3] as? 
StringValue)?.value : nil + let sortedArray = try arrayValue.value.sorted { (a, b) in + let aValue: Any + let bValue: Any + if let attribute = attribute { + guard let aObject = a as? ObjectValue, let bObject = b as? ObjectValue else { + throw JinjaError.runtime("sort filter with attribute requires an array of objects") + } + guard let aAttr = aObject.value[attribute], let bAttr = bObject.value[attribute] else { + throw JinjaError.runtime("sort filter could not get attribute from both objects") + } + aValue = aAttr + bValue = bAttr + } else { + aValue = a + bValue = b + } + let result: Bool + if let aString = aValue as? StringValue, let bString = bValue as? StringValue { + result = + caseSensitive + ? aString.value < bString.value : aString.value.lowercased() < bString.value.lowercased() + } else if let aNumeric = aValue as? NumericValue, let bNumeric = bValue as? NumericValue { + if let aInt = aNumeric.value as? Int, let bInt = bNumeric.value as? Int { + result = aInt < bInt + } else if let aDouble = aNumeric.value as? Double, let bDouble = bNumeric.value as? Double { + result = aDouble < bDouble + } else { + throw JinjaError.runtime("Cannot compare values in sort filter") + } + } else { + throw JinjaError.runtime("Cannot compare values in sort filter") + } + return reverse ? !result : result + } + return ArrayValue(value: sortedArray) + }, + "string": { args, env in + guard let arg = args.first else { + throw JinjaError.runtime("string filter expects one argument") + } + // In Jinja2 in Python, the `string` filter calls Python's `str` function on dicts, which which uses single quotes for strings. Here we're using double quotes in `tojson`, which is probably better for LLMs anyway, but this will result in differences with output from Jinja2. + return try StringValue(value: stringify(arg, whitespaceControl: true)) + }, + "striptags": { args, env in + guard let stringValue = args[0] as? 
StringValue else { + throw JinjaError.runtime("striptags filter requires a string") + } + // A very basic implementation to remove HTML tags + let tagPattern = #"<[^>]+>"# + let noTagsString = stringValue.value.replacingOccurrences( + of: tagPattern, + with: "", + options: .regularExpression + ) + return StringValue(value: noTagsString) + }, + "sum": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("sum filter requires an array") + } + let attribute = (args.count > 1 && args[1] is StringValue) ? (args[1] as! StringValue).value : nil + let start = (args.count > 2 && args[2] is NumericValue) ? (args[2] as! NumericValue).value as! Double : 0.0 + + var sum: Double = start + for item in arrayValue.value { + if let attribute = attribute, let objectValue = item as? ObjectValue, + let attrValue = objectValue.value[attribute] as? NumericValue + { + if let intValue = attrValue.value as? Int { + sum += Double(intValue) + } else if let doubleValue = attrValue.value as? Double { + sum += doubleValue + } + } else if let numericValue = item as? NumericValue { + if let intValue = numericValue.value as? Int { + sum += Double(intValue) + } else if let doubleValue = numericValue.value as? Double { + sum += doubleValue + } + } + } + return NumericValue(value: sum) + }, + "title": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("title filter requires a string") + } + return StringValue(value: stringValue.value.capitalized) + }, + "trim": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("trim filter requires a string") + } + return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) + }, + "truncate": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("truncate filter requires a string") + } + let length = (args.count > 1 && args[1] is NumericValue) ? 
(args[1] as! NumericValue).value as! Int : 255 + let killwords = (args.count > 2 && args[2] is BooleanValue) ? (args[2] as! BooleanValue).value : false + let end = (args.count > 3 && args[3] is StringValue) ? (args[3] as! StringValue).value : "..." + if stringValue.value.count <= length { + return stringValue + } + if killwords { + return StringValue(value: String(stringValue.value.prefix(length - end.count)) + end) + } else { + let truncated = String(stringValue.value.prefix(length - end.count)) + if let lastSpace = truncated.lastIndex(of: " ") { + return StringValue(value: String(truncated[.. [any RuntimeValue] { + switch value { + case let arrayValue as ArrayValue: + return arrayValue.value + case let stringValue as StringValue: + return stringValue.value.map { StringValue(value: String($0)) } + case let objectValue as ObjectValue: + return objectValue.storage.map { key, value in + ArrayValue(value: [StringValue(value: key), value]) + } + default: + throw JinjaError.runtime("Value must be iterable (array, string, or object)") + } + } + // Get the input iterable + guard let input = args.first else { + throw JinjaError.runtime("unique filter requires an iterable") + } + let caseSensitive = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let attribute = args.count > 2 ? args[2] : nil + // Enhanced getter function to handle both string and integer attributes + func getter(_ item: any RuntimeValue) throws -> String { + if let attribute = attribute { + // Handle string attribute + if let strAttr = attribute as? StringValue, + let objectValue = item as? ObjectValue, + let attrValue = objectValue.value[strAttr.value] + { + return caseSensitive ? try stringify(attrValue) : try stringify(attrValue).lowercased() + } + // Handle integer attribute + else if let numAttr = attribute as? NumericValue, + let index = numAttr.value as? Int + { + if let arrayValue = item as? 
ArrayValue { + guard index >= 0 && index < arrayValue.value.count else { + throw JinjaError.runtime("Index \(index) out of range") + } + let value = arrayValue.value[index] + return caseSensitive ? try stringify(value) : try stringify(value).lowercased() + } else if let stringValue = item as? StringValue { + guard index >= 0 && index < stringValue.value.count else { + throw JinjaError.runtime("Index \(index) out of range") + } + let value = StringValue( + value: String( + stringValue.value[ + stringValue.value.index(stringValue.value.startIndex, offsetBy: index) + ] + ) + ) + return caseSensitive ? try stringify(value) : try stringify(value).lowercased() + } + } + throw JinjaError.runtime("Cannot get attribute '\(try stringify(attribute))' from item") + } + return caseSensitive ? try stringify(item) : try stringify(item).lowercased() + } + var seen: [String: Bool] = [:] + var result: [any RuntimeValue] = [] + // Process all items from the iterable + let items = try getIterableItems(input) + for item in items { + let key = try getter(item) + if seen[key] == nil { + seen[key] = true + result.append(item) + } + } + return ArrayValue(value: result) + }, + "upper": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("upper filter requires a string") + } + return StringValue(value: stringValue.value.uppercased()) + }, + "urlencode": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("urlencode filter requires a string") + } + + let encodedString = stringValue.value.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? "" + return StringValue(value: encodedString) + }, + "urlize": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("urlize filter requires a string") + } + let trimUrlLimit = + (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as? 
Int : nil + let nofollow = (args.count > 2 && args[2] is BooleanValue) ? (args[2] as! BooleanValue).value : false + let target = (args.count > 3 && args[3] is StringValue) ? (args[3] as! StringValue).value : nil + let urlPattern = + #"(https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s]{2,}|www\.[a-zA-Z0-9]+\.[^\s]{2,})"# + var urlizedString = stringValue.value + if let regex = try? NSRegularExpression(pattern: urlPattern, options: []) { + let nsRange = NSRange( + stringValue.value.startIndex ..< stringValue.value.endIndex, + in: stringValue.value + ) + let matches = regex.matches(in: stringValue.value, options: [], range: nsRange) + + for match in matches.reversed() { + let urlRange = Range(match.range, in: stringValue.value)! + let url = String(stringValue.value[urlRange]) + var trimmedUrl = url + if let limit = trimUrlLimit, url.count > limit { + trimmedUrl = String(url.prefix(limit)) + "..." + } + var link = " 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 79 + let breakLongWords = (args.count > 2 && args[2] is BooleanValue) ? (args[2] as! BooleanValue).value : true + let wrapString = (args.count > 3 && args[3] is StringValue) ? (args[3] as! StringValue).value : "\n" + var result = "" + var currentLineLength = 0 + for word in stringValue.value.split(separator: " ", omittingEmptySubsequences: false) { + if currentLineLength + word.count > width { + if currentLineLength > 0 { + result += wrapString + currentLineLength = 0 + } + if word.count > width && breakLongWords { + while word.count > width { + result += word.prefix(width) + wrapString + let index = word.index(word.startIndex, offsetBy: width) + let remainder = word[index...] 
+ currentLineLength = remainder.count + } + } + } + if !result.isEmpty { + result += " " + currentLineLength += 1 + } + result += word + currentLineLength += word.count + } + return StringValue(value: result) + }, + "xmlattr": { args, env in + guard let dict = args[0] as? ObjectValue else { + throw JinjaError.runtime("xmlattr filter requires a dictionary") + } + let autospace = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? true : true + var result = "" + for (key, value) in dict.storage { + if !(value is UndefinedValue) && !(value is NullValue) { + if autospace { + result += " " + } + if let stringValue = value as? StringValue { + result += + "\(key)=\"\(stringValue.value.replacingOccurrences(of: "&", with: "&").replacingOccurrences(of: "\"", with: """))\"" + } else { + result += "\(key)=\"\(value)\"" + } + } + } + return StringValue(value: result) + }, + "tojson": { args, env in + guard let firstArg = args.first else { + throw JinjaError.runtime("tojson filter expects at least one argument") + } + var indent: Int? = nil + if args.count > 1, let kwargs = args.last as? ObjectValue, + let indentArg = kwargs.value["indent"] as? NumericValue, + let indentInt = indentArg.value as? Int + { + indent = indentInt + } + return try StringValue(value: toJSON(firstArg, indent: indent, whitespaceControl: false)) }, ] @@ -106,82 +1100,132 @@ class Environment { self.parent = parent } - func isFunction(_ value: Any, functionType: T.Type) -> Bool { - value is T - } + // func isFunction(_ value: Any, functionType: T.Type) -> Bool { + // return value is T + // } + + private func convertToRuntimeValues(input: Any?) throws -> any RuntimeValue { + // Handle already converted RuntimeValue + if let runtimeValue = input as? 
any RuntimeValue { + return runtimeValue + } + // Handle nil values + if input == nil { + return NullValue() + } + if case Optional.none = input { + return NullValue() + } + // Helper function to handle any OrderedDictionary type + func convertOrderedDictionary(_ dict: OrderedDictionary) throws -> ObjectValue { + var object: [String: any RuntimeValue] = [:] + var keyOrder: [String] = [] - func convertToRuntimeValues(input: Any) throws -> any RuntimeValue { + for (key, value) in dict { + // Crucial: Convert Optional to T, using NullValue if nil + let convertedValue = (value as Any?) ?? NullValue() + object[key] = try self.convertToRuntimeValues(input: convertedValue) + keyOrder.append(key) + } + return ObjectValue(value: object, keyOrder: keyOrder) + } + // Handle other values switch input { case let value as Bool: return BooleanValue(value: value) - case let values as [any Numeric]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) - } - return ArrayValue(value: items) - case let value as any Numeric: + case let value as Int: + return NumericValue(value: value) + case let value as Double: + return NumericValue(value: value) + case let value as Float: return NumericValue(value: value) case let value as String: return StringValue(value: value) + case let data as Data: + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert data to string") + } + return StringValue(value: string) case let fn as (String) throws -> Void: return FunctionValue { args, _ in - var arg = "" - switch args[0].value { - case let value as String: - arg = value - case let value as Bool: - arg = String(value) - default: - throw JinjaError.runtime("Unknown arg type:\(type(of: args[0].value))") + guard let stringArg = args[0] as? 
StringValue else { + throw JinjaError.runtime("Argument must be a StringValue") } - - try fn(arg) + try fn(stringArg.value) return NullValue() } case let fn as (Bool) throws -> Void: return FunctionValue { args, _ in - try fn(args[0].value as! Bool) + guard let boolArg = args[0] as? BooleanValue else { + throw JinjaError.runtime("Argument must be a BooleanValue") + } + try fn(boolArg.value) return NullValue() } case let fn as (Int, Int?, Int) -> [Int]: return FunctionValue { args, _ in - let result = fn(args[0].value as! Int, args[1].value as? Int, args[2].value as! Int) - return try self.convertToRuntimeValues(input: result) - } - case let values as [Any]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) + guard args.count > 0, let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + var int1: Int? = nil + if args.count > 1 { + if let numericValue = args[1] as? NumericValue, let tempInt1 = numericValue.value as? Int { + int1 = tempInt1 + } else if !(args[1] is NullValue) { // Accept NullValue for optional second argument + throw JinjaError.runtime("Second argument must be an Int or nil") + } + } + var int2: Int = 1 + if args.count > 2 { + if let numericValue = args[2] as? NumericValue, let tempInt2 = numericValue.value as? 
Int { + int2 = tempInt2 + } else { + throw JinjaError.runtime("Third argument must be an Int") + } + } + let result = fn(int0, int1, int2) + return ArrayValue(value: result.map { NumericValue(value: $0) }) } + case let values as [Any?]: + let items = try values.map { try self.convertToRuntimeValues(input: $0) } return ArrayValue(value: items) - case let dictionary as [String: String]: + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary>: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary>: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let dictionary as [String: Any?]: var object: [String: any RuntimeValue] = [:] - + var keyOrder: [String] = [] for (key, value) in dictionary { - object[key] = StringValue(value: value) + object[key] = try self.convertToRuntimeValues(input: value) + keyOrder.append(key) } - - return ObjectValue(value: object) - case is NullValue: - return NullValue() + return ObjectValue(value: object, keyOrder: keyOrder) default: - throw JinjaError.runtime("Cannot convert to runtime value: \(input) type:\(type(of: input))") + throw JinjaError.runtime( + "Cannot convert to runtime value: \(String(describing: input)) type:\(type(of: input))" + ) } } @discardableResult func set(name: String, value: Any) throws -> any RuntimeValue { - try self.declareVariable(name: name, value: self.convertToRuntimeValues(input: value)) + let runtimeValue = try self.convertToRuntimeValues(input: value) + return try self.declareVariable(name: name, value: runtimeValue) } - func declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { - if self.variables.contains(where: { $0.0 == name }) { + private func 
declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { + if self.variables.keys.contains(name) { throw JinjaError.syntax("Variable already declared: \(name)") } self.variables[name] = value - return value } @@ -191,13 +1235,13 @@ class Environment { return value } - func resolve(name: String) throws -> Self { - if self.variables.contains(where: { $0.0 == name }) { + private func resolve(name: String) throws -> Environment { + if self.variables.keys.contains(name) { return self } - if let parent { - return try parent.resolve(name: name) as! Self + if let parent = self.parent { + return try parent.resolve(name: name) } throw JinjaError.runtime("Unknown variable: \(name)") @@ -205,13 +1249,96 @@ class Environment { func lookupVariable(name: String) -> any RuntimeValue { do { - if let value = try self.resolve(name: name).variables[name] { - return value - } else { - return UndefinedValue() - } + return try self.resolve(name: name).variables[name] ?? UndefinedValue() } catch { return UndefinedValue() } } + + // Filters + + private func doDefault(_ args: [any RuntimeValue], _ env: Environment) throws -> any RuntimeValue { + let value = args[0] + let defaultValue = args.count > 1 ? args[1] : StringValue(value: "") + let boolean = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + if value is UndefinedValue || (boolean && !value.bool()) { + return defaultValue + } + return value + } + + private func doEscape(_ args: [any RuntimeValue], _ env: Environment) throws -> any RuntimeValue { + guard let stringValue = args[0] as? 
StringValue else { + throw JinjaError.runtime("escape filter requires a string") + } + return StringValue( + value: stringValue.value.replacingOccurrences(of: "&", with: "&") + .replacingOccurrences(of: "<", with: "<") + .replacingOccurrences(of: ">", with: ">") + .replacingOccurrences(of: "\"", with: """) + .replacingOccurrences(of: "'", with: "'") + ) + } + + private func doEqualTo(_ args: [any RuntimeValue]) -> Bool { + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } + } + + // Tests + + private func doGreaterThan(_ args: [any RuntimeValue]) throws -> Bool { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value > right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return leftInt > rightInt + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return leftDouble > rightDouble + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return Double(leftInt) > rightDouble + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? 
Int { + return leftDouble > Double(rightInt) + } + } + throw JinjaError.runtime("Cannot compare values of different types") + } + + private func doGreaterThanOrEqual(_ args: [any RuntimeValue]) throws -> Bool { + return try doGreaterThan(args) || doEqualTo(args) + } + + private func doLessThan(_ args: [any RuntimeValue]) throws -> Bool { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value < right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return leftInt < rightInt + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return leftDouble < rightDouble + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return Double(leftInt) < rightDouble + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return leftDouble < Double(rightInt) + } + } + throw JinjaError.runtime("Cannot compare values of different types") + } + + private func doLessThanOrEqual(_ args: [any RuntimeValue]) throws -> Bool { + return try doLessThan(args) || doEqualTo(args) + } } diff --git a/Sources/Lexer.swift b/Sources/Lexer.swift index 1093960..3c9849d 100644 --- a/Sources/Lexer.swift +++ b/Sources/Lexer.swift @@ -50,6 +50,8 @@ enum TokenType: String { case and = "And" case or = "Or" case not = "Not" + case macro = "Macro" + case endMacro = "EndMacro" } struct Token: Equatable { @@ -70,6 +72,8 @@ let keywords: [String: TokenType] = [ "and": .and, "or": .or, "not": .not, + "macro": .macro, + "endmacro": .endMacro, // Literals "true": .booleanLiteral, "false": .booleanLiteral, @@ -81,7 +85,7 @@ func isWord(char: String) -> Bool { } func isInteger(char: String) -> Bool { - char.range(of: #"[0-9]"#, options: .regularExpression) != nil + char.range(of: #"^[0-9]+$"#, options: .regularExpression) != nil } func 
isWhile(char: String) -> Bool { @@ -136,21 +140,16 @@ struct PreprocessOptions { func preprocess(template: String, options: PreprocessOptions = PreprocessOptions()) -> String { var template = template - if template.hasSuffix("\n") { template.removeLast() } - template = template.replacing(#/{#.*?#}/#, with: "{##}") - if options.lstripBlocks == true { template = template.replacing(#/(?m)^[ \t]*({[#%])/#, with: { $0.output.1 }) } - if options.trimBlocks == true { template = template.replacing(#/([#%]})\n/#, with: { $0.output.1 }) } - return template .replacing(#/{##}/#, with: "") @@ -163,7 +162,6 @@ func preprocess(template: String, options: PreprocessOptions = PreprocessOptions func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions()) throws -> [Token] { var tokens: [Token] = [] let src = preprocess(template: source, options: options) - var cursorPosition = 0 @discardableResult @@ -175,17 +173,14 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() if cursorPosition >= src.count { throw JinjaError.syntax("Unexpected end of input") } - let escaped = String(src[cursorPosition]) cursorPosition += 1 - guard let unescaped = escapeCharacters[escaped] else { throw JinjaError.syntax("Unexpected escaped character: \(escaped)") } str.append(unescaped) continue } - str.append(String(src[cursorPosition])) cursorPosition += 1 if cursorPosition >= src.count { @@ -197,7 +192,6 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() main: while cursorPosition < src.count { let lastTokenType = tokens.last?.type - if lastTokenType == nil || lastTokenType == .closeStatement || lastTokenType == .closeExpression { var text = "" @@ -213,18 +207,13 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() continue } } - try consumeWhile(predicate: isWhile) - let char = String(src[cursorPosition]) - if char == "-" || char == "+" { let lastTokenType = tokens.last?.type - if 
lastTokenType == .text || lastTokenType == nil { throw JinjaError.syntax("Unexpected character: \(char)") } - switch lastTokenType { case .identifier, .numericLiteral, @@ -234,18 +223,13 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() .closeParen, .closeSquareBracket: break - default: cursorPosition += 1 - let num = try consumeWhile(predicate: isInteger) - tokens.append(Token(value: "\(char)\(num)", type: num.isEmpty ? .unaryOperator : .numericLiteral)) - continue } } - for (char, token) in orderedMappingTable { let slice = src.slice(start: cursorPosition, end: cursorPosition + char.count) if slice == char { @@ -254,7 +238,6 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() continue main } } - if char == "'" || char == "\"" { cursorPosition += 1 let str = try consumeWhile { str in @@ -264,30 +247,23 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() cursorPosition += 1 continue } - if isInteger(char: char) { let num = try consumeWhile(predicate: isInteger) tokens.append(Token(value: num, type: .numericLiteral)) continue } - if isWord(char: char) { let word = try consumeWhile(predicate: isWord) - let type: TokenType = keywords.contains(where: { $0.key == word }) ? keywords[word]! 
: .identifier - if type == .in, tokens.last?.type == .not { _ = tokens.popLast() tokens.append(Token(value: "not in", type: .notIn)) } else { tokens.append(Token(value: word, type: type)) } - continue } - throw JinjaError.syntax("Unexpected character: \(char)") } - return tokens } diff --git a/Sources/Parser.swift b/Sources/Parser.swift index 648a025..ad748be 100644 --- a/Sources/Parser.swift +++ b/Sources/Parser.swift @@ -6,6 +6,7 @@ // import Foundation +import OrderedCollections func parse(tokens: [Token]) throws -> Program { var program = Program() @@ -22,40 +23,31 @@ func parse(tokens: [Token]) throws -> Program { return prev } - func parseArgumentsList() throws -> [Statement] { + func parseArgumentsList() throws -> [Expression] { var args: [Expression] = [] - while !typeof(.closeParen) { var argument = try parseExpression() - if typeof(.equals) { - current += 1 - + current += 1 // consume equals if let identifier = argument as? Identifier { let value = try parseExpression() - argument = KeywordArgumentExpression(key: identifier, value: value as! Expression) + argument = KeywordArgumentExpression(key: identifier, value: value) } else { throw JinjaError.syntax("Expected identifier for keyword argument") } } - - args.append(argument as! 
Expression) - + args.append(argument) if typeof(.comma) { - current += 1 + current += 1 // consume comma } } - return args } - func parseArgs() throws -> [Statement] { + func parseArgs() throws -> [Expression] { try expect(type: .openParen, error: "Expected opening parenthesis for arguments list") - let args = try parseArgumentsList() - try expect(type: .closeParen, error: "Expected closing parenthesis for arguments list") - return args } @@ -63,69 +55,54 @@ func parse(tokens: [Token]) throws -> Program { try StringLiteral(value: expect(type: .text, error: "Expected text token").value) } - func parseCallExpression(callee: Statement) throws -> CallExpression { - var args: [Expression] = [] - - for arg in try parseArgs() { - args.append(arg as! Expression) - } - - var callExpression = CallExpression(callee: callee as! Expression, args: args) - + func parseCallExpression(callee: Expression) throws -> CallExpression { + let args = try parseArgs() + var callExpression = CallExpression(callee: callee, args: args) if typeof(.openParen) { callExpression = try parseCallExpression(callee: callExpression) } - return callExpression } - func parseMemberExpressionArgumentsList() throws -> Statement { - var slices: [Statement?] = [] + func parseMemberExpressionArgumentsList() throws -> Expression { + var slices: [Expression?] = [] var isSlice = false - while !typeof(.closeSquareBracket) { if typeof(.colon) { slices.append(nil) - current += 1 + current += 1 // consume colon isSlice = true } else { - try slices.append(parseExpression()) + slices.append(try parseExpression()) if typeof(.colon) { - current += 1 + current += 1 // consume colon isSlice = true } } } - if slices.isEmpty { throw JinjaError.syntax("Expected at least one argument for member/slice expression") } - if isSlice { if slices.count > 3 { throw JinjaError.syntax("Expected 0-3 arguments for slice expression") } - return SliceExpression( - start: slices[0] as? Expression, - stop: slices.count > 1 ? slices[1] as? 
Expression : nil, - step: slices.count > 2 ? slices[2] as? Expression : nil + start: slices[0], + stop: slices.count > 1 ? slices[1] : nil, + step: slices.count > 2 ? slices[2] : nil ) } - - return slices[0]! + return slices[0]! // normal member expression } - func parseMemberExpression() throws -> Statement { + func parseMemberExpression() throws -> Expression { var object = try parsePrimaryExpression() - while typeof(.dot) || typeof(.openSquareBracket) { let operation = tokens[current] current += 1 - var property: Statement - + var property: Expression let computed = operation.type != .dot - if computed { property = try parseMemberExpressionArgumentsList() try expect(type: .closeSquareBracket, error: "Expected closing square bracket") @@ -135,52 +112,109 @@ func parse(tokens: [Token]) throws -> Program { throw JinjaError.syntax("Expected identifier following dot operator") } } - object = MemberExpression( - object: object as! Expression, - property: property as! Expression, + object: object, + property: property, computed: computed ) } - return object } - func parseCallMemberExpression() throws -> Statement { + func parseCallMemberExpression() throws -> Expression { let member = try parseMemberExpression() - if typeof(.openParen) { return try parseCallExpression(callee: member) } - return member } - func parseFilterExpression() throws -> Statement { + func parseFilterExpression() throws -> Expression { var operand = try parseCallMemberExpression() - while typeof(.pipe) { - current += 1 - var filter = try parsePrimaryExpression() - if !(filter is Identifier) { - throw JinjaError.syntax("Expected identifier for the test") + current += 1 // consume pipe + guard let filterName = try parsePrimaryExpression() as? Identifier else { + throw JinjaError.syntax("Expected identifier for the filter") } - + var args: [Expression] = [] + var kwargs: [KeywordArgumentExpression] = [] + var dyn_args: Expression? + var dyn_kwargs: Expression? 
if typeof(.openParen) { - filter = try parseCallExpression(callee: filter) + // Handle filter with arguments + (args, kwargs, dyn_args, dyn_kwargs) = try parseCallArgs() } + operand = FilterExpression( + operand: operand, + filter: filterName, + args: args, + kwargs: kwargs, + dyn_args: dyn_args, + dyn_kwargs: dyn_kwargs + ) + } + return operand + } - if let filter = filter as? Filter { - operand = FilterExpression(operand: operand as! Expression, filter: filter) + func parseCallArgs() throws -> ( + [Expression], [KeywordArgumentExpression], Expression?, Expression? + ) { + try expect(type: .openParen, error: "Expected opening parenthesis for arguments list") + var args: [Expression] = [] + var kwargs: [KeywordArgumentExpression] = [] + var dynArgs: Expression? + var dynKwargs: Expression? + var requireComma = false + while !typeof(.closeParen) { + if requireComma { + try expect(type: .comma, error: "Expected comma between arguments") + if typeof(.closeParen) { + break + } + } + if typeof(.multiplicativeBinaryOperator), tokens[current].value == "*" { + current += 1 // Consume * + if dynArgs != nil || dynKwargs != nil { + throw JinjaError.syntax("Multiple dynamic positional arguments are not allowed.") + } + dynArgs = try parseExpression() + } else if typeof(.multiplicativeBinaryOperator), tokens[current].value == "**" { + current += 1 // Consume ** + if dynKwargs != nil { + throw JinjaError.syntax("Multiple dynamic keyword arguments are not allowed.") + } + dynKwargs = try parseExpression() + } else { + if typeof(.identifier), tokens.count > current + 1, tokens[current + 1].type == .equals { + // Parse keyword argument + guard let key = try parsePrimaryExpression() as? 
Identifier else { + throw JinjaError.syntax("Expected identifier for keyword argument key") + } + try expect(type: .equals, error: "Expected '=' after keyword argument key") + let value = try parseExpression() + if dynKwargs != nil { + throw JinjaError.syntax("Keyword arguments must be after dynamic keyword arguments") + } + kwargs.append(KeywordArgumentExpression(key: key, value: value)) + } else { + // Parse positional argument + if !kwargs.isEmpty || dynKwargs != nil { + throw JinjaError.syntax("Positional argument after keyword argument") + } + if dynArgs != nil { + throw JinjaError.syntax("Positional arguments must be after dynamic positional arguments") + } + args.append(try parseExpression()) + } } + requireComma = true } - - return operand + try expect(type: .closeParen, error: "Expected closing parenthesis for arguments list") + return (args, kwargs, dynArgs, dynKwargs) } - func parseTestExpression() throws -> Statement { + func parseTestExpression() throws -> Expression { var operand = try parseFilterExpression() - while typeof(.is) { current += 1 let negate = typeof(.not) @@ -194,7 +228,7 @@ func parse(tokens: [Token]) throws -> Program { filter = Identifier(value: "none") } if let test = filter as? Identifier { - operand = TestExpression(operand: operand as! Expression, negate: negate, test: test) + operand = TestExpression(operand: operand, negate: negate, test: test) } else { throw JinjaError.syntax("Expected identifier for the test") } @@ -202,96 +236,116 @@ func parse(tokens: [Token]) throws -> Program { return operand } - func parseMultiplicativeExpression() throws -> Statement { + func parseMultiplicativeExpression() throws -> Expression { var left = try parseTestExpression() - while typeof(.multiplicativeBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseTestExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! 
Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseAdditiveExpression() throws -> Statement { + func parseAdditiveExpression() throws -> Expression { var left = try parseMultiplicativeExpression() while typeof(.additiveBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseMultiplicativeExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseComparisonExpression() throws -> Statement { + func parseComparisonExpression() throws -> Expression { var left = try parseAdditiveExpression() - while typeof(.comparisonBinaryOperator) || typeof(.in) || typeof(.notIn) { + while typeof(.comparisonBinaryOperator) || typeof(.in) || typeof(.notIn) + || (typeof(.is) + && (tokens.count > current + 1 + && (tokens[current + 1].type == .identifier || tokens[current + 1].type == .not))) + { let operation = tokens[current] current += 1 - let right = try parseAdditiveExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! 
Expression) + if operation.type == .is { + if typeof(.not) { + current += 1 + if typeof(.identifier), tokens[current].value == "none" { + current += 1 + left = TestExpression(operand: left, negate: true, test: Identifier(value: "none")) + continue + } else { + throw JinjaError.syntax("Expected 'none' after 'is not'") + } + } else if typeof(.identifier), tokens[current].value == "defined" { + current += 1 + left = TestExpression(operand: left, negate: false, test: Identifier(value: "defined")) + continue + } else { + throw JinjaError.syntax("Expected 'defined' or 'not' after 'is'") + } + } else if operation.type == .notIn { + let right = try parseAdditiveExpression() + left = BinaryExpression(operation: operation, left: left, right: right) + } else { + let right = try parseAdditiveExpression() + left = BinaryExpression(operation: operation, left: left, right: right) + } } - return left } - func parseLogicalNegationExpression() throws -> Statement { - var right: UnaryExpression? - - while typeof(.not) { + func parseLogicalNegationExpression() throws -> Expression { + if typeof(.not) { let operation = tokens[current] current += 1 let argument = try parseLogicalNegationExpression() - right = UnaryExpression(operation: operation, argument: argument as! Expression) - } - - if let right { - return right + return UnaryExpression(operation: operation, argument: argument) } else { return try parseComparisonExpression() } } - func parseLogicalAndExpression() throws -> Statement { + func parseLogicalAndExpression() throws -> Expression { var left = try parseLogicalNegationExpression() while typeof(.and) { let operation = tokens[current] current += 1 let right = try parseLogicalNegationExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! 
Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } - return left } - func parseLogicalOrExpression() throws -> Statement { + func parseLogicalOrExpression() throws -> Expression { var left = try parseLogicalAndExpression() - while typeof(.or) { - let operation = tokens[current] - current += 1 + current += 1 // Consume 'or' let right = try parseLogicalAndExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: Token(value: "or", type: .or), left: left, right: right) } return left } - func parseTernaryExpression() throws -> Statement { + func parseTernaryExpression() throws -> Expression { let a = try parseLogicalOrExpression() if typeof(.if) { - current += 1 - let test = try parseLogicalOrExpression() - try expect(type: .else, error: "Expected else token") - let b = try parseLogicalOrExpression() - return If(test: test as! Expression, body: [a], alternate: [b]) + current += 1 // consume if token + let predicate = try parseLogicalOrExpression() + if typeof(.else) { + // Ternary expression with else + current += 1 // consume else token + let b = try parseLogicalOrExpression() + return If(test: predicate, body: [a], alternate: [b]) + } else { + // Select expression on iterable + return SelectExpression(iterable: a, test: predicate) + } } - return a } - func parseExpression() throws -> Statement { + func parseExpression() throws -> Expression { try parseTernaryExpression() } @@ -299,66 +353,66 @@ func parse(tokens: [Token]) throws -> Program { guard current + types.count <= tokens.count else { return false } - for (index, type) in types.enumerated() { if type != tokens[current + index].type { return false } } - return true } func parseSetStatement() throws -> Statement { let left = try parseExpression() - if typeof(.equals) { current += 1 - let value = try parseSetStatement() - - return Set(assignee: left as! Expression, value: value as! 
Expression) + // Parse the right-hand side as an expression + let value = try parseExpression() + // Explicitly cast 'value' to 'Expression' + return Set(assignee: left, value: value) } - return left } func parseIfStatement() throws -> Statement { let test = try parseExpression() - try expect(type: .closeStatement, error: "Expected closing statement token") - var body: [Statement] = [] var alternate: [Statement] = [] - while !(tokens[current].type == .openStatement && (tokens[current + 1].type == .elseIf || tokens[current + 1].type == .else || tokens[current + 1].type == .endIf)) { - try body.append(parseAny()) + body.append(try parseAny()) } if tokens[current].type == .openStatement, tokens[current + 1].type != .endIf { current += 1 if typeof(.elseIf) { try expect(type: .elseIf, error: "Expected elseif token") - try alternate.append(parseIfStatement()) + alternate.append(try parseIfStatement()) } else { try expect(type: .else, error: "Expected else token") try expect(type: .closeStatement, error: "Expected closing statement token") while !(tokens[current].type == .openStatement && tokens[current + 1].type == .endIf) { - try alternate.append(parseAny()) + alternate.append(try parseAny()) } } } - return If(test: test as! Expression, body: body, alternate: alternate) + return If(test: test, body: body, alternate: alternate) } - func parsePrimaryExpression() throws -> Statement { + func parsePrimaryExpression() throws -> Expression { let token = tokens[current] switch token.type { case .numericLiteral: current += 1 - return NumericLiteral(value: Int(token.value) ?? 
0) + if let intValue = Int(token.value) { + return NumericLiteral(value: intValue) + } else if let doubleValue = Double(token.value) { + return NumericLiteral(value: doubleValue) + } else { + throw JinjaError.parser("Invalid numeric literal: \(token.value)") + } case .stringLiteral: current += 1 return StringLiteral(value: token.value) @@ -383,7 +437,7 @@ func parse(tokens: [Token]) throws -> Program { current += 1 var values: [Expression] = [] while !typeof(.closeSquareBracket) { - try values.append(parseExpression() as! Expression) + try values.append(parseExpression()) if typeof(.comma) { current += 1 } @@ -392,12 +446,20 @@ func parse(tokens: [Token]) throws -> Program { return ArrayLiteral(value: values) case .openCurlyBracket: current += 1 - var values: [(Expression, Expression)] = [] + var values = OrderedDictionary() while !typeof(.closeCurlyBracket) { let key = try parseExpression() try expect(type: .colon, error: "Expected colon between key and value in object literal") let value = try parseExpression() - values.append((key as! Expression, value as! Expression)) + + if let key = key as? StringLiteral { + values[key.value] = value + } else if let key = key as? Identifier { + values[key.value] = value + } else { + throw JinjaError.syntax("Expected string literal or identifier as key in object literal") + } + if typeof(.comma) { current += 1 } @@ -409,18 +471,18 @@ func parse(tokens: [Token]) throws -> Program { } } - func parseExpressionSequence(primary: Bool = false) throws -> Statement { + func parseExpressionSequence(primary: Bool = false) throws -> Expression { let fn = primary ? parsePrimaryExpression : parseExpression - var expressions: [Expression] = try [fn() as! Expression] + var expressions: [Expression] = try [fn()] let isTuple = typeof(.comma) while isTuple { - current += 1 - try expressions.append(fn() as! 
Expression) + current += 1 // consume comma + try expressions.append(fn()) if !typeof(.comma) { break } } - + // Return either a tuple or single expression return isTuple ? TupleLiteral(value: expressions) : expressions[0] } @@ -428,7 +490,6 @@ func parse(tokens: [Token]) throws -> Program { guard current + types.count <= tokens.count else { return false } - return types.enumerated().contains { i, type -> Bool in type != tokens[current + i].type } @@ -436,50 +497,80 @@ func parse(tokens: [Token]) throws -> Program { func parseForStatement() throws -> Statement { let loopVariable = try parseExpressionSequence(primary: true) - if !(loopVariable is Identifier || loopVariable is TupleLiteral) { throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + "Expected identifier/tuple for the loop variable, got \(type(of: loopVariable)) instead" ) } - try expect(type: .in, error: "Expected `in` keyword following loop variable") - let iterable = try parseExpression() - + // Handle optional if condition for filtering + var ifCondition: Expression? = nil + if typeof(.if) { + current += 1 // consume if token + ifCondition = try parseExpression() + } try expect(type: .closeStatement, error: "Expected closing statement token") - var body: [Statement] = [] - while not(.openStatement, .endFor) { - try body.append(parseAny()) + var defaultBlock: [Statement] = [] + while not(.openStatement, .endFor) && not(.openStatement, .else) { + body.append(try parseAny()) } + if typeof(.openStatement, .else) { + current += 1 // consume {% + try expect(type: .else, error: "Expected else token") + try expect(type: .closeStatement, error: "Expected closing statement token") - if let loopVariable = loopVariable as? Loopvar { - return For(loopvar: loopVariable, iterable: iterable as! 
Expression, body: body) + while not(.openStatement, .endFor) { + defaultBlock.append(try parseAny()) + } } - - throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + return For( + loopvar: loopVariable, + iterable: iterable, + body: body, + defaultBlock: defaultBlock, + ifCondition: ifCondition ) } + func parseMacroStatement() throws -> Macro { + let name = try parsePrimaryExpression() + if !(name is Identifier) { + throw JinjaError.syntax("Expected identifier following macro statement") + } + let args = try parseArgs() + try expect(type: .closeStatement, error: "Expected closing statement token") + var body: [Statement] = [] + while not(.openStatement, .endMacro) { + body.append(try parseAny()) + } + return Macro(name: name as! Identifier, args: args, body: body) + } + func parseJinjaStatement() throws -> Statement { + // Consume {% %} tokens try expect(type: .openStatement, error: "Expected opening statement token") var result: Statement - switch tokens[current].type { case .set: - current += 1 + current += 1 // consume 'set' token result = try parseSetStatement() try expect(type: .closeStatement, error: "Expected closing statement token") case .if: - current += 1 + current += 1 // consume 'if' token result = try parseIfStatement() try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endIf, error: "Expected endif token") try expect(type: .closeStatement, error: "Expected %} token") + case .macro: + current += 1 // consume 'macro' token + result = try parseMacroStatement() + try expect(type: .openStatement, error: "Expected {% token") + try expect(type: .endMacro, error: "Expected endmacro token") + try expect(type: .closeStatement, error: "Expected %} token") case .for: - current += 1 + current += 1 // consume 'for' token result = try parseForStatement() try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endFor, error: "Expected endfor token") @@ 
-487,17 +578,13 @@ func parse(tokens: [Token]) throws -> Program { default: throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)") } - return result } func parseJinjaExpression() throws -> Statement { try expect(type: .openExpression, error: "Expected opening expression token") - let result = try parseExpression() - try expect(type: .closeExpression, error: "Expected closing expression token") - return result } diff --git a/Sources/Runtime.swift b/Sources/Runtime.swift index 73a0d48..4331c20 100644 --- a/Sources/Runtime.swift +++ b/Sources/Runtime.swift @@ -6,11 +6,12 @@ // import Foundation +import OrderedCollections protocol RuntimeValue { - associatedtype T - var value: T { get set } + associatedtype ValueType + var value: ValueType { get } var builtins: [String: any RuntimeValue] { get set } func bool() -> Bool @@ -21,7 +22,12 @@ struct NumericValue: RuntimeValue { var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { - self.value as? Int != 0 + if let intValue = self.value as? Int { + return intValue != 0 + } else if let doubleValue = self.value as? Double { + return doubleValue != 0.0 + } + return false } } @@ -35,7 +41,7 @@ struct BooleanValue: RuntimeValue { } struct NullValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -44,7 +50,7 @@ struct NullValue: RuntimeValue { } struct UndefinedValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? 
= nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -64,56 +70,85 @@ struct ArrayValue: RuntimeValue { } func bool() -> Bool { - !self.value.isEmpty + return !self.value.isEmpty } } struct TupleValue: RuntimeValue { - var value: ArrayValue + var value: [any RuntimeValue] var builtins: [String: any RuntimeValue] = [:] + init(value: [any RuntimeValue]) { + self.value = value + self.builtins["length"] = FunctionValue(value: { _, _ in + NumericValue(value: value.count) + }) + } + func bool() -> Bool { - self.value.bool() + !self.value.isEmpty } } -struct ObjectValue: RuntimeValue { - var value: [String: any RuntimeValue] - var builtins: [String: any RuntimeValue] = [:] +struct ObjectValue: RuntimeValue, Sequence { + var storage: OrderedDictionary + var builtins: [String: any RuntimeValue] - init(value: [String: any RuntimeValue]) { - self.value = value + var value: [String: any RuntimeValue] { Dictionary(uniqueKeysWithValues: storage.map { ($0, $1) }) } + var orderedKeys: [String] { Array(storage.keys) } + + init(value: [String: any RuntimeValue], keyOrder: [String]? = nil) { + // If keyOrder is provided, use it; otherwise, maintain the original order from the dictionary + let orderedKeys = keyOrder ?? Array(value.keys) + let orderedPairs = orderedKeys.compactMap { key in + value[key].map { (key, $0) } + } + + // Recursively create OrderedDictionary for nested objects + let processedPairs = orderedPairs.map { key, value -> (String, any RuntimeValue) in + if let objectValue = value as? ObjectValue { + // Already an ObjectValue, use it directly + return (key, objectValue) + } else if let dictValue = value.value as? [String: any RuntimeValue] { + // If the value contains a dictionary, convert it to ObjectValue + return (key, ObjectValue(value: dictValue)) + } + return (key, value) + } + + self.storage = OrderedDictionary(uniqueKeysWithValues: processedPairs) self.builtins = [ "get": FunctionValue(value: { args, _ in - if let key = args[0] as? 
StringValue { - if let value = value.first(where: { $0.0 == key.value }) { - return value as! (any RuntimeValue) - } else if args.count > 1 { - return args[1] - } else { - return NullValue() - } - } else { - throw JinjaError.runtime("Object key must be a string: got \(type(of:args[0]))") + guard let key = args[0] as? StringValue else { + throw JinjaError.runtime("Object key must be a string: got \(type(of: args[0]))") } + if let value = value[key.value] { + return value + } else if args.count > 1 { + return args[1] + } + return NullValue() }), "items": FunctionValue(value: { _, _ in - var items: [ArrayValue] = [] - for (k, v) in value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) - } - return items as! (any RuntimeValue) + ArrayValue( + value: orderedPairs.map { key, value in + ArrayValue(value: [StringValue(value: key), value]) + } + ) }), ] } + mutating func setValue(key: String, value: any RuntimeValue) { + storage[key] = value + } + func bool() -> Bool { - !self.value.isEmpty + !storage.isEmpty + } + + func makeIterator() -> OrderedDictionary.Iterator { + return storage.makeIterator() } } @@ -136,22 +171,24 @@ struct StringValue: RuntimeValue { "upper": FunctionValue(value: { _, _ in StringValue(value: value.uppercased()) }), - "lower": FunctionValue(value: { _, _ in StringValue(value: value.lowercased()) }), - "strip": FunctionValue(value: { _, _ in StringValue(value: value.trimmingCharacters(in: .whitespacesAndNewlines)) }), - "title": FunctionValue(value: { _, _ in - StringValue(value: value.capitalized) + StringValue(value: value.titleCase()) }), - "length": FunctionValue(value: { _, _ in NumericValue(value: value.count) }), + "rstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "\\s+$", with: "", options: .regularExpression)) + }), + "lstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "^\\s+", with: "", options: .regularExpression)) + }), ] } 
@@ -175,23 +212,24 @@ struct Interpreter { var result = "" for statement in statements { let lastEvaluated = try self.evaluate(statement: statement, environment: environment) - if !(lastEvaluated is NullValue), !(lastEvaluated is UndefinedValue) { - if let value = lastEvaluated.value as? String { - result += value + if let stringValue = lastEvaluated as? StringValue { + result += stringValue.value + } else if let numericValue = lastEvaluated as? NumericValue { + result += String(describing: numericValue.value) + } else if let booleanValue = lastEvaluated as? BooleanValue { + result += String(booleanValue.value) + } else if let arrayValue = lastEvaluated as? ArrayValue { + // Convert array to JSON string + result += try toJSON(arrayValue) + } else if let objectValue = lastEvaluated as? ObjectValue { + // Convert object to JSON string + result += try toJSON(objectValue) } else { - switch lastEvaluated.value { - case let value as Int: - result += String(value) - case let value as String: - result += value - default: - throw JinjaError.runtime("Unknown value type:\(type(of: lastEvaluated.value))") - } + throw JinjaError.runtime("Cannot convert to string: \(type(of: lastEvaluated))") } } } - return StringValue(value: result) } @@ -206,229 +244,543 @@ struct Interpreter { try environment.setVariable(name: variableName, value: rhs) } else if let member = node.assignee as? MemberExpression { let object = try self.evaluate(statement: member.object, environment: environment) - - if var object = object as? ObjectValue { - if let property = member.property as? Identifier { - object.value[property.value] = rhs - } else { - throw JinjaError.runtime("Cannot assign to member with non-identifier property") - } - } else { + guard var objectValue = object as? ObjectValue else { throw JinjaError.runtime("Cannot assign to member of non-object") } + guard let property = member.property as? 
Identifier else { + throw JinjaError.runtime("Cannot assign to member with non-identifier property") + } + // Modify the copy + objectValue.setValue(key: property.value, value: rhs) + // Update the environment with the modified copy + if let parentIdentifier = member.object as? Identifier { + try environment.setVariable(name: parentIdentifier.value, value: objectValue) + } else { + throw JinjaError.runtime("Cannot assign to computed member expression") + } } else { - throw JinjaError.runtime("Invalid assignee type: \(type(of: node.assignee))") + throw JinjaError.runtime("Invalid LHS inside assignment expression: \(node.assignee)") } - return NullValue() } func evaluateIf(node: If, environment: Environment) throws -> StringValue { let test = try self.evaluate(statement: node.test, environment: environment) - return try self.evaluateBlock(statements: test.bool() ? node.body : node.alternate, environment: environment) } func evaluateIdentifier(node: Identifier, environment: Environment) throws -> any RuntimeValue { - environment.lookupVariable(name: node.value) + let value = environment.lookupVariable(name: node.value) + return value } - func evaluateFor(node: For, environment: Environment) throws -> any RuntimeValue { + func evaluateFor(node: For, environment: Environment) throws -> StringValue { + // Scope for the for loop let scope = Environment(parent: environment) - - let iterable = try self.evaluate(statement: node.iterable, environment: scope) - var result = "" - if let iterable = iterable as? ArrayValue { - for i in 0 ..< iterable.value.count { - let loop: [String: any RuntimeValue] = [ - "index": NumericValue(value: i + 1), - "index0": NumericValue(value: i), - "revindex": NumericValue(value: iterable.value.count - i), - "revindex0": NumericValue(value: iterable.value.count - i - 1), - "first": BooleanValue(value: i == 0), - "last": BooleanValue(value: i == iterable.value.count - 1), - "length": NumericValue(value: iterable.value.count), - "previtem": i > 0 ? 
iterable.value[i - 1] : UndefinedValue(), - "nextitem": i < iterable.value.count - 1 ? iterable.value[i + 1] : UndefinedValue(), - ] - - try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) - - let current = iterable.value[i] - + let test: Expression? + let iterable: any RuntimeValue + if let selectExpression = node.iterable as? SelectExpression { + iterable = try self.evaluate(statement: selectExpression.iterable, environment: scope) + test = selectExpression.test + } else { + iterable = try self.evaluate(statement: node.iterable, environment: scope) + test = nil + } + var items: [any RuntimeValue] = [] + var scopeUpdateFunctions: [(Environment) throws -> Void] = [] + // Keep track of the indices of the original iterable that passed the test + var filteredIndices: [Int] = [] + var originalIndex = 0 + // Handle ArrayValue + if let arrayIterable = iterable as? ArrayValue { + for current in arrayIterable.value { + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void if let identifier = node.loopvar as? Identifier { - try scope.setVariable(name: identifier.value, value: current) - } else { - } - - switch node.loopvar { - case let identifier as Identifier: - try scope.setVariable(name: identifier.value, value: current) - case let tupleLiteral as TupleLiteral: - if let current = current as? ArrayValue { - if tupleLiteral.value.count != current.value.count { - throw JinjaError.runtime( - "Too \(tupleLiteral.value.count > current.value.count ? "few" : "many") items to unpack" - ) - } - - for j in 0 ..< tupleLiteral.value.count { - if let identifier = tupleLiteral.value[j] as? 
Identifier { - try scope.setVariable(name: identifier.value, value: current.value[j]) - } else { - throw JinjaError.runtime( - "Cannot unpack non-identifier type: \(type(of:tupleLiteral.value[j]))" - ) + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + guard let currentArray = current as? ArrayValue else { + throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of: current))") + } + if tupleLiteral.value.count != currentArray.value.count { + throw JinjaError.runtime( + "Too \(tupleLiteral.value.count > currentArray.value.count ? "few" : "many") items to unpack" + ) + } + scopeUpdateFunction = { scope in + for (i, value) in tupleLiteral.value.enumerated() { + guard let identifier = value as? Identifier else { + throw JinjaError.runtime("Cannot unpack non-identifier type: \(type(of: value))") } + try scope.setVariable(name: identifier.value, value: currentArray.value[i]) } - } else { - throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))") } - default: - throw JinjaError.syntaxNotSupported(String(describing: node.loopvar)) + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") } - - let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) - result += evaluated.value + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 + } + // Handle StringValue as a special case + } else if let stringIterable = iterable as? 
StringValue { + // Treat the string as an iterable of characters + for char in stringIterable.value { + let current = StringValue(value: String(char)) + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 + } + // Handle ObjectValue (dictionary) + } else if let objectIterable = iterable as? ObjectValue { + // Treat the dictionary as an iterable of key-value pairs + for (key, value) in objectIterable { + let current = ArrayValue(value: [StringValue(value: key), value]) + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + // Support unpacking of key-value pairs into two variables + if tupleLiteral.value.count != 2 { + throw JinjaError.runtime( + "Cannot unpack dictionary entry: expected 2 variables, got \(tupleLiteral.value.count)" + ) + } + guard let keyIdentifier = tupleLiteral.value[0] as? Identifier else { + throw JinjaError.runtime( + "Cannot unpack dictionary entry into non-identifier: \(type(of: tupleLiteral.value[0]))" + ) + } + guard let valueIdentifier = tupleLiteral.value[1] as? 
Identifier else { + throw JinjaError.runtime( + "Cannot unpack dictionary entry into non-identifier: \(type(of: tupleLiteral.value[1]))" + ) + } + scopeUpdateFunction = { scope in + try scope.setVariable(name: keyIdentifier.value, value: StringValue(value: key)) + try scope.setVariable(name: valueIdentifier.value, value: value) + } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 } } else { - throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of:iterable))") + throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of: iterable))") + } + var result = "" + var noIteration = true + for i in 0 ..< items.count { + // Get the previous and next items that passed the filter + let previousIndex = filteredIndices.firstIndex(of: filteredIndices[i])! - 1 + let nextIndex = filteredIndices.firstIndex(of: filteredIndices[i])! + 1 + let previtem: any RuntimeValue + if previousIndex >= 0 { + let previousFilteredIndex = filteredIndices[previousIndex] + if let arrayIterable = iterable as? ArrayValue { + previtem = arrayIterable.value[previousFilteredIndex] + } else if let stringIterable = iterable as? StringValue { + let index = stringIterable.value.index( + stringIterable.value.startIndex, + offsetBy: previousFilteredIndex + ) + previtem = StringValue(value: String(stringIterable.value[index])) + } else if let objectIterable = iterable as? 
ObjectValue { + let (key, value) = objectIterable.storage.elements[previousFilteredIndex] + previtem = ArrayValue(value: [StringValue(value: key), value]) + } else { + previtem = UndefinedValue() + } + } else { + previtem = UndefinedValue() + } + let nextitem: any RuntimeValue + if nextIndex < filteredIndices.count { + let nextFilteredIndex = filteredIndices[nextIndex] + if let arrayIterable = iterable as? ArrayValue { + nextitem = arrayIterable.value[nextFilteredIndex] + } else if let stringIterable = iterable as? StringValue { + let index = stringIterable.value.index(stringIterable.value.startIndex, offsetBy: nextFilteredIndex) + nextitem = StringValue(value: String(stringIterable.value[index])) + } else if let objectIterable = iterable as? ObjectValue { + let (key, value) = objectIterable.storage.elements[nextFilteredIndex] + nextitem = ArrayValue(value: [StringValue(value: key), value]) + } else { + nextitem = UndefinedValue() + } + } else { + nextitem = UndefinedValue() + } + let loop: [String: any RuntimeValue] = [ + "index": NumericValue(value: i + 1), + "index0": NumericValue(value: i), + "revindex": NumericValue(value: items.count - i), + "revindex0": NumericValue(value: items.count - i - 1), + "first": BooleanValue(value: i == 0), + "last": BooleanValue(value: i == items.count - 1), + "length": NumericValue(value: items.count), + "previtem": previtem, + "nextitem": nextitem, + ] + try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) + try scopeUpdateFunctions[i](scope) + let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) + result += evaluated.value + noIteration = false + } + if noIteration { + let defaultEvaluated = try self.evaluateBlock(statements: node.defaultBlock, environment: scope) + result += defaultEvaluated.value } - return StringValue(value: result) } func evaluateBinaryExpression(node: BinaryExpression, environment: Environment) throws -> any RuntimeValue { let left = try self.evaluate(statement: 
node.left, environment: environment) - + let right = try self.evaluate(statement: node.right, environment: environment) + // Handle 'or' + if node.operation.value == "or" { + if left.bool() { + return left + } else { + return right + } + } + // Handle 'and' if node.operation.value == "and" { - return left.bool() ? try self.evaluate(statement: node.right, environment: environment) : left - } else if node.operation.value == "or" { - return left.bool() ? left : try self.evaluate(statement: node.right, environment: environment) + if !left.bool() { + return left + } else { + return right + } } - - let right = try self.evaluate(statement: node.right, environment: environment) - + // == if node.operation.value == "==" { - switch left.value { - case let value as String: - return BooleanValue(value: value == right.value as! String) - case let value as Int: - return BooleanValue(value: value == right.value as! Int) - case let value as Bool: - return BooleanValue(value: value == right.value as! Bool) - default: - throw JinjaError.runtime( - "Unknown left value type:\(type(of: left.value)), right value type:\(type(of: right.value))" - ) + // Handle array indexing for right operand + if let memberExpr = node.right as? MemberExpression, + let arrayValue = try self.evaluate(statement: memberExpr.object, environment: environment) + as? ArrayValue, + let indexExpr = memberExpr.property as? NumericLiteral, + let index = indexExpr.value as? Int + { + + // Handle negative indices + let actualIndex = index < 0 ? arrayValue.value.count + index : index + if actualIndex >= 0 && actualIndex < arrayValue.value.count { + let rightValue = arrayValue.value[actualIndex] + return BooleanValue(value: try areEqual(left, rightValue)) + } } - } else if node.operation.value == "!=" { - if type(of: left) != type(of: right) { + + return BooleanValue(value: try areEqual(left, right)) + } + // != + if node.operation.value == "!=" { + if let left = left as? StringValue, let right = right as? 
StringValue { + return BooleanValue(value: left.value != right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt != rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble != rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) != rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble != Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for inequality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value != right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: false) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: false) + } else if type(of: left) == type(of: right) { return BooleanValue(value: true) } else { - return BooleanValue(value: left.value as! AnyHashable != right.value as! AnyHashable) + return BooleanValue(value: true) } } - if left is UndefinedValue || right is UndefinedValue { throw JinjaError.runtime("Cannot perform operation on undefined values") } else if left is NullValue || right is NullValue { throw JinjaError.runtime("Cannot perform operation on null values") } else if let left = left as? NumericValue, let right = right as? NumericValue { switch node.operation.value { - case "+": throw JinjaError.syntaxNotSupported("+") - case "-": throw JinjaError.syntaxNotSupported("-") - case "*": throw JinjaError.syntaxNotSupported("*") - case "/": throw JinjaError.syntaxNotSupported("/") + case "+": + if let leftInt = left.value as? 
Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt + rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble + rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) + rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble + Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for addition") + } + case "-": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt - rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble - rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) - rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble - Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for subtraction") + } + case "*": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt * rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble * rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) * rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? 
Int { + return NumericValue(value: leftDouble * Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for multiplication") + } + case "/": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt / rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble / rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) / rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble / Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for division") + } case "%": - switch left.value { - case is Int: - return NumericValue(value: left.value as! Int % (right.value as! Int)) - default: - throw JinjaError.runtime("Unknown value type:\(type(of: left.value))") - } - case "<": throw JinjaError.syntaxNotSupported("<") - case ">": throw JinjaError.syntaxNotSupported(">") - case ">=": throw JinjaError.syntaxNotSupported(">=") - case "<=": throw JinjaError.syntaxNotSupported("<=") + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt % rightInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for modulus") + } + case "<": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt < rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble < rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) < rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? 
Int { + return BooleanValue(value: leftDouble < Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than comparison") + } + case ">": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt > rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble > rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) > rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble > Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than comparison") + } + case ">=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt >= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble >= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) >= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble >= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than or equal to comparison") + } + case "<=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt <= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble <= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? 
Double { + return BooleanValue(value: Double(leftInt) <= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble <= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than or equal to comparison") + } default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if left is ArrayValue && right is ArrayValue { + } else if let left = left as? ArrayValue, let right = right as? ArrayValue { switch node.operation.value { - case "+": break + case "+": + return ArrayValue(value: left.value + right.value) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if right is ArrayValue { - throw JinjaError.syntaxNotSupported("right is ArrayValue") - } - - if left is StringValue || right is StringValue { - switch node.operation.value { - case "+": - var rightValue = "" - var leftValue = "" - switch right.value { - case let value as String: - rightValue = value - case let value as Int: - rightValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown right value type:\(type(of: right.value))") - } - - switch left.value { - case let value as String: - leftValue = value - case let value as Int: - leftValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown left value type:\(type(of: left.value))") - } - - return StringValue(value: leftValue + rightValue) - default: - break + } else if let right = right as? ArrayValue { + let member: Bool + if let left = left as? StringValue { + member = right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + } else if let left = left as? NumericValue { + member = right.value.contains { + if let item = $0 as? NumericValue { + return item.value as! 
Int == left.value as! Int + } + return false + } + } else if let left = left as? BooleanValue { + member = right.value.contains { + if let item = $0 as? BooleanValue { + return item.value == left.value + } + return false + } + } else { + throw JinjaError.runtime("Unsupported left type for 'in'/'not in' operation with ArrayValue") } - } - - if let left = left as? StringValue, let right = right as? StringValue { switch node.operation.value { case "in": - return BooleanValue(value: right.value.contains(left.value)) + return BooleanValue(value: member) case "not in": - return BooleanValue(value: !right.value.contains(left.value)) + return BooleanValue(value: !member) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } } - - if left is StringValue, right is ObjectValue { + if let left = left as? StringValue { switch node.operation.value { + case "+": + let rightValue: String + if let rightString = right as? StringValue { + rightValue = rightString.value + } else if let rightNumeric = right as? NumericValue { + rightValue = String(describing: rightNumeric.value) + } else if let rightBoolean = right as? BooleanValue { + rightValue = String(rightBoolean.value) + } else if right is UndefinedValue { + rightValue = "" + } else { + throw JinjaError.runtime("Unsupported right operand type for string concatenation") + } + return StringValue(value: left.value + rightValue) case "in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: right.value.contains { + if let item = $0 as? 
StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime("Right operand of 'in' must be a StringValue, ArrayValue, or ObjectValue") } case "not in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: !rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: !right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: !right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: !right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime( + "Right operand of 'not in' must be a StringValue, ArrayValue, or ObjectValue" + ) } + default: + break + } + } else if let right = right as? StringValue { + if node.operation.value == "+" { + if let leftString = left as? StringValue { + return StringValue(value: leftString.value + right.value) + } else if let leftNumeric = left as? NumericValue { + return StringValue(value: String(describing: leftNumeric.value) + right.value) + } else if let leftBoolean = left as? BooleanValue { + return StringValue(value: String(leftBoolean.value) + right.value) + } else { + throw JinjaError.runtime("Unsupported left operand type for string concatenation") + } + } + } + if let left = left as? StringValue, let right = right as? 
ObjectValue { + switch node.operation.value { + case "in": + return BooleanValue(value: right.value.keys.contains(left.value)) + case "not in": + return BooleanValue(value: !right.value.keys.contains(left.value)) default: throw JinjaError.runtime( "Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue" ) } } - throw JinjaError.syntax( "Unknown operator '\(node.operation.value)' between \(type(of:left)) and \(type(of:right))" ) @@ -442,49 +794,42 @@ struct Interpreter { if !(object is ArrayValue || object is StringValue) { throw JinjaError.runtime("Slice object must be an array or string") } - let start = try self.evaluate(statement: expr.start, environment: environment) let stop = try self.evaluate(statement: expr.stop, environment: environment) let step = try self.evaluate(statement: expr.step, environment: environment) - if !(start is NumericValue || start is UndefinedValue) { throw JinjaError.runtime("Slice start must be numeric or undefined") } - if !(stop is NumericValue || stop is UndefinedValue) { throw JinjaError.runtime("Slice stop must be numeric or undefined") } - if !(step is NumericValue || step is UndefinedValue) { throw JinjaError.runtime("Slice step must be numeric or undefined") } - if let object = object as? ArrayValue { return ArrayValue( value: slice( object.value, - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int ) ) } else if let object = object as? StringValue { return StringValue( value: slice( - Array(arrayLiteral: object.value), - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int - ).joined() + Array(object.value), + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? 
Int + ).map { String($0) }.joined() ) } - throw JinjaError.runtime("Slice object must be an array or string") } func evaluateMemberExpression(expr: MemberExpression, environment: Environment) throws -> any RuntimeValue { let object = try self.evaluate(statement: expr.object, environment: environment) - var property: any RuntimeValue if expr.computed { if let property = expr.property as? SliceExpression { @@ -495,7 +840,6 @@ struct Interpreter { } else { property = StringValue(value: (expr.property as! Identifier).value) } - var value: (any RuntimeValue)? if let object = object as? ObjectValue { if let property = property as? StringValue { @@ -503,34 +847,55 @@ struct Interpreter { } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } - } else if object is ArrayValue || object is StringValue { + } else if let object = object as? ArrayValue { if let property = property as? NumericValue { - if let object = object as? ArrayValue { - let index = property.value as! Int - if index >= 0 { + if let index = property.value as? Int { + if index >= 0 && index < object.value.count { value = object.value[index] - } else { + } else if index < 0 && index >= -object.value.count { value = object.value[object.value.count + index] + } else { + value = UndefinedValue() + } + } else { + throw JinjaError.runtime("Array index must be an integer") + } + } else if let property = property as? StringValue { + value = object.builtins[property.value] + } else { + throw JinjaError.runtime( + "Cannot access property with non-string/non-number: got \(type(of: property))" + ) + } + } else if let object = object as? StringValue { + if let property = property as? NumericValue { + if let index = property.value as? 
Int { + if index >= 0 && index < object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: index) + value = StringValue(value: String(object.value[strIndex])) + } else if index < 0 && index >= -object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: object.value.count + index) + value = StringValue(value: String(object.value[strIndex])) + } else { + value = UndefinedValue() } - } else if let object = object as? StringValue { - let index = object.value.index(object.value.startIndex, offsetBy: property.value as! Int) - value = StringValue(value: String(object.value[index])) + } else { + throw JinjaError.runtime("String index must be an integer") } } else if let property = property as? StringValue { value = object.builtins[property.value] } else { throw JinjaError.runtime( - "Cannot access property with non-string/non-number: got \(type(of:property))" + "Cannot access property with non-string/non-number: got \(type(of: property))" ) } } else { if let property = property as? StringValue { - value = object.builtins[property.value]! + value = object.builtins[property.value] } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } } - if let value { return value } else { @@ -540,7 +905,6 @@ struct Interpreter { func evaluateUnaryExpression(node: UnaryExpression, environment: Environment) throws -> any RuntimeValue { let argument = try self.evaluate(statement: node.argument, environment: environment) - switch node.operation.value { case "not": return BooleanValue(value: !argument.bool()) @@ -552,7 +916,6 @@ struct Interpreter { func evaluateCallExpression(expr: CallExpression, environment: Environment) throws -> any RuntimeValue { var args: [any RuntimeValue] = [] var kwargs: [String: any RuntimeValue] = [:] - for argument in expr.args { if let argument = argument as? 
KeywordArgumentExpression { kwargs[argument.key.value] = try self.evaluate(statement: argument.value, environment: environment) @@ -560,13 +923,10 @@ struct Interpreter { try args.append(self.evaluate(statement: argument, environment: environment)) } } - - if kwargs.count > 0 { + if !kwargs.isEmpty { args.append(ObjectValue(value: kwargs)) } - let fn = try self.evaluate(statement: expr.callee, environment: environment) - if let fn = fn as? FunctionValue { return try fn.value(args, environment) } else { @@ -574,89 +934,108 @@ struct Interpreter { } } - func evaluateFilterExpression(node: FilterExpression, environment: Environment) throws -> any RuntimeValue { - let operand = try evaluate(statement: node.operand, environment: environment) - - if let identifier = node.filter as? Identifier { - if let arrayValue = operand as? ArrayValue { - switch identifier.value { - case "list": - return arrayValue - case "first": - return arrayValue.value.first ?? UndefinedValue() - case "last": - return arrayValue.value.last ?? UndefinedValue() - case "length": - return NumericValue(value: arrayValue.value.count) - case "reverse": - return ArrayValue(value: arrayValue.value.reversed()) - case "sort": - throw JinjaError.todo("TODO: ArrayValue filter sort") - default: - throw JinjaError.runtime("Unknown ArrayValue filter: \(identifier.value)") - } - } else if let stringValue = operand as? 
StringValue { - switch identifier.value { - case "length": - return NumericValue(value: stringValue.value.count) - case "upper": - return StringValue(value: stringValue.value.uppercased()) - case "lower": - return StringValue(value: stringValue.value.lowercased()) - case "title": - return StringValue(value: stringValue.value.capitalized) - case "capitalize": - return StringValue(value: stringValue.value.capitalized) - case "trim": - return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) - default: - throw JinjaError.runtime("Unknown StringValue filter: \(identifier.value)") - } - } else if let numericValue = operand as? NumericValue { - switch identifier.value { - case "abs": - return NumericValue(value: abs(numericValue.value as! Int32)) - default: - throw JinjaError.runtime("Unknown NumericValue filter: \(identifier.value)") - } - } else if let objectValue = operand as? ObjectValue { - switch identifier.value { - case "items": - var items: [ArrayValue] = [] - for (k, v) in objectValue.value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) - } - return items as! 
(any RuntimeValue) - case "length": - return NumericValue(value: objectValue.value.count) - default: - throw JinjaError.runtime("Unknown ObjectValue filter: \(identifier.value)") - } + func evaluateFilterExpression(node: FilterExpression, environment: Environment, whitespaceControl: Bool) throws + -> any RuntimeValue + { + let operand = try self.evaluate(statement: node.operand, environment: environment) + let filterName = node.filter.value + guard let filter = environment.filters[filterName] else { + throw JinjaError.runtime("No filter named '\(filterName)'") + } + // Evaluate positional arguments + let evaluatedPositionalArgs = try node.args.map { arg in + try self.evaluate(statement: arg, environment: environment) + } + // Create args array starting with operand + var args: [any RuntimeValue] = [operand] + args.append(contentsOf: evaluatedPositionalArgs) + // If we have keyword arguments, add them as a final ObjectValue argument + if !node.kwargs.isEmpty { + var kwargs: [String: any RuntimeValue] = [:] + for kwarg in node.kwargs { + kwargs[kwarg.key.value] = try self.evaluate(statement: kwarg.value, environment: environment) } - - throw JinjaError.runtime("Cannot apply filter \(operand.value) to type: \(type(of:operand))") + args.append(ObjectValue(value: kwargs)) } - - throw JinjaError.runtime("Unknown filter: \(node.filter)") + return try filter(args, environment) } func evaluateTestExpression(node: TestExpression, environment: Environment) throws -> any RuntimeValue { let operand = try self.evaluate(statement: node.operand, environment: environment) - - if let testFunction = environment.tests[node.test.value] { - let result = try testFunction(operand) - return BooleanValue(value: node.negate ? !result : result) - } else { + guard let testFunction = environment.tests[node.test.value] else { throw JinjaError.runtime("Unknown test: \(node.test.value)") } + let result = try testFunction(operand) + return BooleanValue(value: node.negate ? 
!result : result) + } + + func evaluateMacro(node: Macro, environment: Environment) throws -> NullValue { + try environment.setVariable( + name: node.name.value, + value: FunctionValue(value: { args, scope in + let macroScope = Environment(parent: scope) + var args = args + var kwargs: [String: any RuntimeValue] = [:] + if let lastArg = args.last, let keywordArgsValue = lastArg as? KeywordArgumentsValue { + kwargs = keywordArgsValue.value + args.removeLast() + } + for i in 0 ..< node.args.count { + let nodeArg = node.args[i] + let passedArg = args.count > i ? args[i] : nil + + if let identifier = nodeArg as? Identifier { + if passedArg == nil { + if let defaultValue = kwargs[identifier.value] { + try macroScope.setVariable(name: identifier.value, value: defaultValue) + } else { + throw JinjaError.runtime("Missing argument: \(identifier.value)") + } + } else { + try macroScope.setVariable(name: identifier.value, value: passedArg!) + } + } else if let kwarg = nodeArg as? KeywordArgumentExpression { + let value = + try kwargs[kwarg.key.value] + ?? (passedArg ?? (try self.evaluate(statement: kwarg.value, environment: macroScope))) + + try macroScope.setVariable(name: kwarg.key.value, value: value) + } else { + throw JinjaError.runtime("Unknown argument type: \(type(of: nodeArg))") + } + } + return try self.evaluateBlock(statements: node.body, environment: macroScope) + }) + ) + return NullValue() } - func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue { + func evaluateArguments( + args: [Expression], + environment: Environment + ) throws -> ([any RuntimeValue], [String: any RuntimeValue]) { + var positionalArguments: [any RuntimeValue] = [] + var keywordArguments: [String: any RuntimeValue] = [:] + for argument in args { + if let keywordArgument = argument as? 
KeywordArgumentExpression { + keywordArguments[keywordArgument.key.value] = try self.evaluate( + statement: keywordArgument.value, + environment: environment + ) + } else { + if !keywordArguments.isEmpty { + throw JinjaError.runtime("Positional arguments must come before keyword arguments") + } + positionalArguments.append(try self.evaluate(statement: argument, environment: environment)) + } + } + + return (positionalArguments, keywordArguments) + } + + func evaluate(statement: Statement?, environment: Environment, whitespaceControl: Bool = false) throws + -> any RuntimeValue + { if let statement { switch statement { case let statement as Program: @@ -678,15 +1057,41 @@ struct Interpreter { case let statement as UnaryExpression: return try self.evaluateUnaryExpression(node: statement, environment: environment) case let statement as NumericLiteral: - return NumericValue(value: statement.value) + if let intValue = statement.value as? Int { + return NumericValue(value: intValue) + } else if let doubleValue = statement.value as? 
Double { + return NumericValue(value: doubleValue) + } else { + throw JinjaError.runtime("Invalid numeric literal value") + } case let statement as CallExpression: return try self.evaluateCallExpression(expr: statement, environment: environment) case let statement as BoolLiteral: return BooleanValue(value: statement.value) case let statement as FilterExpression: - return try self.evaluateFilterExpression(node: statement, environment: environment) + return try self.evaluateFilterExpression( + node: statement, + environment: environment, + whitespaceControl: whitespaceControl + ) case let statement as TestExpression: return try self.evaluateTestExpression(node: statement, environment: environment) + case let statement as ArrayLiteral: + return ArrayValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as TupleLiteral: + return TupleValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as ObjectLiteral: + var mapping: [String: any RuntimeValue] = [:] + for (key, value) in statement.value { + mapping[key] = try self.evaluate(statement: value, environment: environment) + } + return ObjectValue(value: mapping) + case let statement as Macro: + return try self.evaluateMacro(node: statement, environment: environment) case is NullLiteral: return NullValue() default: diff --git a/Sources/Utilities.swift b/Sources/Utilities.swift index c01870b..7017acb 100644 --- a/Sources/Utilities.swift +++ b/Sources/Utilities.swift @@ -21,7 +21,6 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) let stopValue = stop ?? arrayCount let step = step ?? 1 var slicedArray = [T]() - if step > 0 { let startIndex = startValue < 0 ? max(arrayCount + startValue, 0) : min(startValue, arrayCount) let stopIndex = stopValue < 0 ? 
max(arrayCount + stopValue, 0) : min(stopValue, arrayCount) @@ -35,6 +34,197 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) slicedArray.append(array[i]) } } - return slicedArray } + +func toJSON(_ input: any RuntimeValue, indent: Int? = nil, depth: Int = 0, whitespaceControl: Bool = false) throws + -> String +{ + // If whitespaceControl is true, output compact JSON + if whitespaceControl { + switch input { + case is NullValue, is UndefinedValue: + return "null" + case let value as NumericValue: + return String(describing: value.value) + case let value as StringValue: + let escapedValue = value.value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") + return "\"\(escapedValue)\"" + case let value as BooleanValue: + return value.value ? "true" : "false" + case let arr as ArrayValue: + let elements = try arr.value.map { + try toJSON($0, indent: nil, depth: 0, whitespaceControl: true) + } + return "[\(elements.joined(separator: ", "))]" + case let obj as ObjectValue: + let pairs = try obj.orderedKeys.map { key in + guard let value = obj.value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try toJSON(value, indent: nil, depth: 0, whitespaceControl: true) + return "\"\(key)\": \(jsonValue)" + } + return "{\(pairs.joined(separator: ", "))}" + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } + } + let currentDepth = depth + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: currentDepth) : "" + let childrenPadding = indent != nil ? 
basePadding + indentValue : "" + switch input { + case is NullValue, is UndefinedValue: + return "null" + case let value as NumericValue: + return String(describing: value.value) + case let value as StringValue: + // Properly escape special characters for JSON strings + let escapedValue = value.value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") + return "\"\(escapedValue)\"" + case let value as BooleanValue: + return value.value ? "true" : "false" + case let arr as ArrayValue: + let core = try arr.value.map { + try toJSON($0, indent: indent, depth: currentDepth + 1, whitespaceControl: whitespaceControl) + } + if indent != nil && !whitespaceControl { + return "[\(childrenPadding)\(core.joined(separator: ",\(childrenPadding)"))\(basePadding)]" + } else { + return "[\(core.joined(separator: ", "))]" + } + case let obj as ObjectValue: + // Use orderedKeys to maintain insertion order + let pairs = try obj.orderedKeys.map { key in + guard let value = obj.value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try toJSON( + value, + indent: indent, + depth: currentDepth + 1, + whitespaceControl: whitespaceControl + ) + return "\"\(key)\": \(jsonValue)" + } + if indent != nil && !whitespaceControl { + return "{\(childrenPadding)\(pairs.joined(separator: ",\(childrenPadding)"))\(basePadding)}" + } else { + return "{\(pairs.joined(separator: ", "))}" + } + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } +} + +// Helper function to convert values to JSON strings +func jsonString(_ value: Any) throws -> String { + let data = try JSONSerialization.data(withJSONObject: value) + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert value to JSON string") + } 
+ return string +} + +extension String { + func titleCase() -> String { + return self.components(separatedBy: .whitespacesAndNewlines) + .map { word in + guard let firstChar = word.first else { return "" } + return String(firstChar).uppercased() + word.dropFirst() + } + .joined(separator: " ") + } + + func indent(_ width: Int, first: Bool = false, blank: Bool = false) -> String { + let indentString = String(repeating: " ", count: width) + return self.components(separatedBy: .newlines) + .enumerated() + .map { (index, line) in + if line.isEmpty && !blank { + return line + } + if index == 0 && !first { + return line + } + return indentString + line + } + .joined(separator: "\n") + } +} + +func stringify(_ value: any RuntimeValue, indent: Int = 4, whitespaceControl: Bool = false) throws -> String { + if let stringValue = value as? StringValue { + return "\"\(stringValue.value)\"" + } else if let numericValue = value as? NumericValue { + return String(describing: numericValue.value) + } else if let booleanValue = value as? BooleanValue { + return booleanValue.value ? "true" : "false" + } else if let objectValue = value as? ObjectValue { + return try toJSON(objectValue, indent: indent, whitespaceControl: whitespaceControl) + } else if let arrayValue = value as? ArrayValue { + return try toJSON(arrayValue, indent: indent, whitespaceControl: whitespaceControl) + } else if value is NullValue { + return "null" + } else if value is UndefinedValue { + return "undefined" + } else { + return "" + } +} + +func areEqual(_ left: any RuntimeValue, _ right: any RuntimeValue) throws -> Bool { + if let leftObj = left as? ObjectValue, let rightObj = right as? 
ObjectValue { + // Compare ObjectValues by their contents + guard leftObj.storage.keys == rightObj.storage.keys else { + return false + } + + for key in leftObj.storage.keys { + guard let leftValue = leftObj.storage[key], + let rightValue = rightObj.storage[key], + try areEqual(leftValue, rightValue) + else { + return false + } + } + return true + } else if let leftStr = left as? StringValue, let rightStr = right as? StringValue { + return leftStr.value == rightStr.value + } else if let leftNum = left as? NumericValue, let rightNum = right as? NumericValue { + if let leftInt = leftNum.value as? Int, let rightInt = rightNum.value as? Int { + return leftInt == rightInt + } else if let leftDouble = leftNum.value as? Double, let rightDouble = rightNum.value as? Double { + return leftDouble == rightDouble + } + } else if let leftArr = left as? ArrayValue, let rightArr = right as? ArrayValue { + guard leftArr.value.count == rightArr.value.count else { + return false + } + for (leftItem, rightItem) in zip(leftArr.value, rightArr.value) { + guard try areEqual(leftItem, rightItem) else { + return false + } + } + return true + } else if left is NullValue && right is NullValue { + return true + } else if left is UndefinedValue && right is UndefinedValue { + return true + } else if let leftBool = left as? BooleanValue, let rightBool = right as? BooleanValue { + return leftBool.value == rightBool.value + } + // If types don't match, return false + return false +} diff --git a/Tests/ChatTemplateTests.swift b/Tests/ChatTemplateTests.swift deleted file mode 100644 index 4b9ab6b..0000000 --- a/Tests/ChatTemplateTests.swift +++ /dev/null @@ -1,238 +0,0 @@ -// -// ChatTemplateTests.swift -// -// -// Created by John Mai on 2024/3/24. -// - -import XCTest - -@testable import Jinja - -let messages: [[String: String]] = [ - [ - "role": "user", - "content": "Hello, how are you?", - ], - [ - "role": "assistant", - "content": "I'm doing great. 
How can I help you today?", - ], - [ - "role": "user", - "content": "I'd like to show off how chat templating works!", - ], -] - -let messagesWithSystem: [[String: String]] = - [ - [ - "role": "system", - "content": "You are a friendly chatbot who always responds in the style of a pirate", - ] - ] + messages - -final class ChatTemplateTests: XCTestCase { - struct Test { - let chatTemplate: String - let data: [String: Any] - let target: String - } - - let defaultTemplates: [Test] = [ - Test( - chatTemplate: - "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messages, - "add_generation_prompt": false, - ], - target: - "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" - ), - // facebook/blenderbot-400M-distill - Test( - chatTemplate: - "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" - ), - // facebook/blenderbot_small-90M - Test( - chatTemplate: - "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" 
- ), - // bigscience/bloom - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!" - ), - // EleutherAI/gpt-neox-20b - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // gpt2 - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 
'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": messagesWithSystem, - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": messages, - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nDEFAULT_SYSTEM_MESSAGE\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! 
[/INST]" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": [ - [ - "role": "user", - "content": "<>\nYou are a helpful assistant\n<> Hello, how are you?", - ], - [ - "role": "assistant", - "content": "I'm doing great. How can I help you today?", - ], - [ - "role": "user", - "content": "I'd like to show off how chat templating works!", - ], - ], - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nYou are a helpful assistant\n<> Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // openai/whisper-large-v3 - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. 
How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messages, - "add_generation_prompt": true, - ], - target: - "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messagesWithSystem, - "add_generation_prompt": true, - ], - target: - "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messagesWithSystem - ], - target: - "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" - ), - // THUDM/chatglm3-6b - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - data: [ - "messages": messagesWithSystem - ], - target: - "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" 
- ), - // google/gemma-2b-it - Test( - chatTemplate: - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - data: [ - "messages": messages - ], - target: - "user\nHello, how are you?\nmodel\nI\'m doing great. How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" - ), - // Qwen/Qwen2.5-0.5B-Instruct - Test( - chatTemplate: - "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - data: [ - "messages": messages - ], - target: - "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" - ), - ] - - func testDefaultTemplates() throws { - for test in defaultTemplates { - let template = try Template(test.chatTemplate) - let result = try template.render(test.data) - XCTAssertEqual(result.debugDescription, test.target.debugDescription) - } - } -} diff --git a/Tests/InterpreterTests.swift b/Tests/InterpreterTests.swift index d402f84..631d2e6 100644 --- a/Tests/InterpreterTests.swift +++ b/Tests/InterpreterTests.swift @@ -141,17 +141,18 @@ final class InterpreterTests: XCTestCase { for test in tests { let env = Environment() try env.set(name: "True", value: true) - for (key, value) in test.data { try env.set(name: key, value: value) } - let tokens = try tokenize(test.template, options: test.options) let parsed = try parse(tokens: tokens) let interpreter = Interpreter(env: env) - let result = try interpreter.run(program: parsed).value as! String - - XCTAssertEqual(result.debugDescription, test.target.debugDescription) + let result = try interpreter.run(program: parsed) + if let stringResult = result as? StringValue { + XCTAssertEqual(stringResult.value.debugDescription, test.target.debugDescription) + } else { + XCTFail("Expected a StringValue, but got \(type(of: result))") + } } } } diff --git a/Tests/Template tests/ChatTemplateTests.swift b/Tests/Template tests/ChatTemplateTests.swift new file mode 100644 index 0000000..95d8d05 --- /dev/null +++ b/Tests/Template tests/ChatTemplateTests.swift @@ -0,0 +1,878 @@ +// +// ChatTemplateTests.swift +// +// +// Created by John Mai on 2024/3/24. +// + +import XCTest + +@testable import Jinja + +final class ChatTemplateTests: XCTestCase { + let messages: [[String: String]] = [ + [ + "role": "user", + "content": "Hello, how are you?", + ], + [ + "role": "assistant", + "content": "I'm doing great. 
How can I help you today?", + ], + [ + "role": "user", + "content": "I'd like to show off how chat templating works!", + ], + ] + + lazy var messagesWithSystemPrompt: [[String: String]] = + [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ] + ] + messages + + func testGenericChatTemplate() throws { + let chatTemplate = + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": false, + ]) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testGenericChatTemplate failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testFacebookBlenderbot400MDistill() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" 
+ + if target != result { + print("::: testFacebookBlenderbot400MDistill failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testFacebookBlenderbotSmall90M() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" + + if target != result { + print("::: testFacebookBlenderbotSmall90M failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testBigscienceBloom() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!" + + if target != result { + print("::: testBigscienceBloom failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testEleutherAIGptNeox20b() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. 
How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + + if target != result { + print("::: testEleutherAIGptNeox20b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testGPT2() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + + if target != result { + print("::: testGPT2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer1() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() 
+ ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt, + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testHfInternalTestingLlamaTokenizer1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer2() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + 
]) + let target = + "[INST] <>\nDEFAULT_SYSTEM_MESSAGE\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testHfInternalTestingLlamaTokenizer2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer3() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": "<>\nYou are a helpful assistant\n<> Hello, how are you?", + ], + [ + "role": "assistant", + "content": "I'm doing great. 
How can I help you today?", + ], + [ + "role": "user", + "content": "I'd like to show off how chat templating works!", + ], + ], + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nYou are a helpful assistant\n<> Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testHfInternalTestingLlamaTokenizer3 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testOpenaiWhisperLargeV3() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + + if target != result { + print("::: testOpenaiWhisperLargeV3 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat1() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": true, + ]) + let target = + "<|im_start|>system\nYou are a helpful 
assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" + + if target != result { + print("::: testQwenQwen1_5_1_8BChat1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat2() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt, + "add_generation_prompt": true, + ]) + let target = + "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" + + if target != result { + print("::: testQwenQwen1_5_1_8BChat2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat3() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt + ]) + let target = + "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" 
+ + if target != result { + print("::: testQwenQwen1_5_1_8BChat3 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testTHUDMChatglm36b() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt + ]) + let target = + "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" + + if target != result { + print("::: testTHUDMChatglm36b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testGoogleGemma2bIt() throws { + let chatTemplate = + "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages + ]) + let target = + "user\nHello, how are you?\nmodel\nI\'m doing great. 
How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" + + if target != result { + print("::: testGoogleGemma2bIt failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen2_5_0_5BInstruct() throws { + let chatTemplate = + "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages + ]) + let target = + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testQwenQwen2_5_0_5BInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bBetaAddGenerationPromptFalse() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt, "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. 
How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n" + + if target != result { + print("::: testHuggingFaceH4Zephyr7bBetaAddGenerationPromptFalse failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bBetaAddGenerationPromptTrue() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ], + ["role": "user", "content": "How many helicopters can a human eat in one sitting?"], + ], "eos_token": "", "add_generation_prompt": true, + ] as [String: Any] + ) + let target = + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHow many helicopters can a human eat in one sitting?\n<|assistant|>\n" + + if target != result { + print("::: testHuggingFaceH4Zephyr7bBetaAddGenerationPromptTrue failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bGemmaV0_1() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif 
messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testHuggingFaceH4Zephyr7bGemmaV0_1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testTheBlokeMistral7BInstructV0_1GPTQ() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! 
[/INST]" + + if target != result { + print("::: testTheBlokeMistral7BInstructV0_1GPTQ failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMistralaiMixtral8x7BInstructV0_1() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testMistralaiMixtral8x7BInstructV0_1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testCognitivecomputationsDolphin2_5Mixtral8x7b() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testCognitivecomputationsDolphin2_5Mixtral8x7b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testOpenchatOpenchat3_5_0106() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>" + + if target != result { + print("::: testOpenchatOpenchat3_5_0106 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testUpstageSOLAR10_7BInstructV1_0() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "### User:\nHello, how are you?\n\n### Assistant:\nI'm doing great. 
How can I help you today?### User:\nI'd like to show off how chat templating works!\n\n" + + if target != result { + print("::: testUpstageSOLAR10_7BInstructV1_0 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testCodellamaCodeLlama70bInstructHf() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}"; + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! 
Source: assistant\nDestination: user\n\n " + + if target != result { + print("::: testCodellamaCodeLlama70bInstructHf failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testDeciDeciLM7BInstruct() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "### User:\nHello, how are you?\n### Assistant:\nI'm doing great. How can I help you today?\n### User:\nI'd like to show off how chat templating works!\n" + + if target != result { + print("::: testDeciDeciLM7BInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_72BChat() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testQwenQwen1_5_72BChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testDeepseekAiDeepseekLlm7bChat() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "<|begin of sentence|>", + "eos_token": "<|end of sentence|>", + ] as [String: Any] + ) + let target = + "<|begin of sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. 
How can I help you today?<|end of sentence|>User: I'd like to show off how chat templating works!\n\n" + + if target != result { + print("::: testDeepseekAiDeepseekLlm7bChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testH2oaiH2oDanube1_8bChat() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "eos_token": "", + ] as [String: Any] + ) + let target = + "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!" + + if target != result { + print("::: testH2oaiH2oDanube1_8bChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testInternlmInternlm2Chat7b() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + + if target != result { + print("::: testInternlmInternlm2Chat7b failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testTheBlokedeepseekCoder33BInstructAWQ() throws { + let chatTemplate = + "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. 
How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n" + + if target != result { + print("::: testTheBlokedeepseekCoder33BInstructAWQ failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testEriczzzFalconRw1bChat() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'].strip() }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "eos_token": "<|endoftext|>", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!" 
+ + if target != result { + print("::: testEriczzzFalconRw1bChat failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testAbacusaiSmaug34BV0_1() throws { + let chatTemplate = + "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + + if target != result { + print("::: testAbacusaiSmaug34BV0_1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMaywellSynatraMixtral8x7B() throws { + let chatTemplate = + "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + + if target != result { + print("::: testMaywellSynatraMixtral8x7B failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testDeepseekAiDeepseekCoder33bInstruct() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "<|begin of sentence|>", "eos_token": "<|EOT|>", + ] as [String: Any] + ) + let target = + "<|begin of sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n" + + if target != result { + print("::: testDeepseekAiDeepseekCoder33bInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMaywellSynatraMixtral8x7B_2() throws { + let chatTemplate = + "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt + ] as [String: Any] + ) + let target = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\nYou are a friendly chatbot who always responds in the style of a pirate### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" 
+ + if target != result { + print("::: testMaywellSynatraMixtral8x7B_2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMistralNemoInstruct2407() throws { + let chatTemplate = + "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{%- for message in loop_messages | rejectattr(\"role\", \"equalto\", \"tool\") | rejectattr(\"role\", \"equalto\", \"tool_results\") | selectattr(\"tool_calls\", \"undefined\") %}\n {%- if (message[\"role\"] == \"user\") != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + 
message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message[\"role\"] == \"tool_calls\" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = message.content %}\n {%- endif %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "bos_token": "", + "eos_token": "", + ]) + let target = + "[INST]Hello, how are you?[/INST]I'm doing great. 
How can I help you today?[INST]I'd like to show off how chat templating works![/INST]" + + XCTAssertEqual(result, target) + } + + func testQwen2VLTextOnly() throws { + let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + Hello, how are you?<|im_end|> + <|im_start|>assistant + I'm doing great. 
How can I help you today?<|im_end|> + <|im_start|>user + I'd like to show off how chat templating works!<|im_end|> + <|im_start|>assistant + + """ + + if target != result { + print("::: testQwen2VLTextOnly failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } +} diff --git a/Tests/Template tests/ChatTemplates.swift b/Tests/Template tests/ChatTemplates.swift new file mode 100644 index 0000000..f7e72e1 --- /dev/null +++ b/Tests/Template tests/ChatTemplates.swift @@ -0,0 +1,21 @@ +// +// ChatTemplates.swift +// Jinja +// +// Created by Anthony DePasquale on 02.01.2025. +// + +struct ChatTemplate { + static let llama3_1 = """ + {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. 
To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + """ + static let llama3_2 = """ + {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- 
set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + """ + static let qwen2_5 = """ + {%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' 
}}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n + """ + static let mistral7b = """ + {{bos_token}}{% set 
user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{'[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]'}}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% elif message['role'] == 'tool_results' %}{{'[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]'}}{% elif message['role'] == 'tool_calls' %}{{'[TOOL_CALLS]' + message['content']|string + eos_token}}{% endif %}{% endfor %} + """ +} diff --git a/Tests/Template tests/Messages.swift b/Tests/Template tests/Messages.swift new file mode 100644 index 0000000..2159be3 --- /dev/null +++ b/Tests/Template tests/Messages.swift @@ -0,0 +1,15 @@ +// +// Messages.swift +// Jinja +// +// Created by Anthony DePasquale on 02.01.2025. +// + +struct Messages { + static let weatherQuery: [[String: String]] = [ + [ + "role": "user", + "content": "What is the weather in Paris today?", + ] + ] +} diff --git a/Tests/Template tests/ToolSpecs.swift b/Tests/Template tests/ToolSpecs.swift new file mode 100644 index 0000000..daadbce --- /dev/null +++ b/Tests/Template tests/ToolSpecs.swift @@ -0,0 +1,48 @@ +// +// ToolSpecs.swift +// Jinja +// +// Created by Anthony DePasquale on 02.01.2025. 
+// + +import OrderedCollections + +struct ToolSpec { + static let getCurrentWeather = OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather") as (String, Any), + ("description", "Get the current weather in a given location") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The city and state, e.g. San Francisco, CA") + as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) +} diff --git a/Tests/Template tests/ToolUseTests.swift b/Tests/Template tests/ToolUseTests.swift new file mode 100644 index 0000000..2d8f700 --- /dev/null +++ b/Tests/Template tests/ToolUseTests.swift @@ -0,0 +1,763 @@ +// +// ToolUseTests.swift +// Jinja +// +// Created by Anthony DePasquale on 30.12.2024. 
+// + +import XCTest +import OrderedCollections + +/* + Recent models that don't support tool use: + - Gemma 2 + - Phi 3.5 + - Mistral NeMo + */ + +@testable import Jinja + +final class ToolUseTests: XCTestCase { + let messagesWithFunctionCalling: [[String: Any?]] = [ + [ + "role": "assistant", + "content": nil, + "tool_calls": [ + [ + "type": "function", + "function": [ + "name": "get_current_weather", + "arguments": "{\n \"location\": \"Hanoi\"\n}", + ], + ] + ], + ], + [ + "role": "user", + "content": "What's the weather like in Hanoi?", + ], + ] + + // Example adapted from https://huggingface.co/fireworks-ai/firefunction-v1 + let exampleFunctionSpec: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_stock_price") as (String, Any), + ("description", "Get the current stock price") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "symbol", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The stock symbol, e.g. 
AAPL, GOOG") as (String, Any), + ]) + ) + ]) + ) as (String, Any), + ("required", ["symbol"]) as (String, Any), + ]) + ) as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "check_word_anagram") as (String, Any), + ("description", "Check if two words are anagrams of each other") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "word1", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The first word") as (String, Any), + ]) + ) as (String, Any), + ( + "word2", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The second word") as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["word1", "word2"]) as (String, Any), + ]) + ) as (String, Any), + ]), + ] + + lazy var messagesWithFunctionCallingAndSystemPrompt: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "system") as (String, Any), + ("content", "You are a helpful assistant with access to functions. 
Use them if required.") as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "functions") as (String, Any), + ("content", exampleFunctionSpec) as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "user") as (String, Any), + ("content", "Hi, can you tell me the current stock price of AAPL?") as (String, Any), + ]), + ] + + let exampleToolJSONSchemas: OrderedDictionary> = OrderedDictionary( + uniqueKeysWithValues: [ + ( + "get_current_weather", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather") as (String, Any), + ("description", "Get the current weather in a given location") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The city and state, e.g. 
San Francisco, CA") + as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_temperature_v1", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_temperature") as (String, Any), + ("description", "Get the current temperature at a location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ( + "description", + "The current temperature at the specified location in the specified units, as a float." 
+ ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_temperature_v2", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_temperature") as (String, Any), + ("description", "Get the current temperature at a location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ("description", "The unit to return the temperature in.") + as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location", "unit"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ( + "description", + "The current temperature at the specified location in the specified units, as a float." 
+ ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_wind_speed", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_wind_speed") as (String, Any), + ("description", "Get the current wind speed in km/h at a given location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ("description", "The current wind speed at the given location in km/h, as a float.") + as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ]) + + lazy var exampleListOfTools: [OrderedDictionary] = [ + exampleToolJSONSchemas["get_current_temperature_v2"]!, + exampleToolJSONSchemas["get_current_wind_speed"]!, + ] + + func testMeetKaiFunctionaryMediumV2_2() throws { + let chatTemplate = """ + {#v2.2#}\n{% for message in messages %}\n{% if message['role'] == 'user' or message['role'] == 'system' %}\n{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% elif message['role'] == 'tool' %}\n{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% else %}\n{% set contain_content='no'%}\n{% if message['content'] is not none %}\n{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}{% set contain_content='yes'%}\n{% endif 
%}\n{% if 'tool_calls' in message and message['tool_calls'] is not none %}\n{% for tool_call in message['tool_calls'] %}\n{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}\n{% if loop.index == 1 and contain_content == "no" %}\n{{ prompt }}{% else %}\n{{ '\n' + prompt}}{% endif %}\n{% endfor %}\n{% endif %}\n{{ '<|stop|>\n' }}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithFunctionCalling, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = + """ + <|from|>assistant\n<|recipient|>get_current_weather\n<|content|>{\n "location": "Hanoi"\n}<|stop|>\n<|from|>user\n<|recipient|>all\n<|content|>What's the weather like in Hanoi?\n + """ + + if target != result { + print("::: testMeetKaiFunctionaryMediumV2_2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testFireworksAIFireFunctionV1() throws { + let chatTemplate = """ + {%- set message_roles = ['SYSTEM', 'FUNCTIONS', 'USER', 'ASSISTANT', 'TOOL'] -%}\n{%- set ns = namespace(seen_non_system=false, messages=messages, content='', functions=[]) -%}\n{{ bos_token }}\n{#- Basic consistency checks -#}\n{%- if not ns.messages -%}\n {{ raise_exception('No messages') }}\n{%- endif -%}\n{%- if ns.messages[0]['role'] | upper != 'SYSTEM' -%}\n {%- set ns.messages = [{'role': 'SYSTEM', 'content': 'You are a helpful assistant with access to functions. 
Use them if required.'}] + ns.messages -%}\n{%- endif -%}\n{%- if ns.messages | length < 2 or ns.messages[0]['role'] | upper != 'SYSTEM' or ns.messages[1]['role'] | upper != 'FUNCTIONS' -%}\n {{ raise_exception('Expected either "functions" or ["system", "functions"] as the first messages') }}\n{%- endif -%}\n{%- for message in ns.messages -%}\n {%- set role = message['role'] | upper -%}\n {#- Validation -#}\n {%- if role not in message_roles -%}\n {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles + ' are supported.') }}\n {%- endif -%}\n {%- set ns.content = message['content'] if message.get('content') else '' -%}\n {#- Move tool calls inside the content -#}\n {%- if 'tool_calls' in message -%}\n {%- for call in message['tool_calls'] -%}\n {%- set ns.content = ns.content + '{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}' -%}\n {%- endfor -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' and '' not in ns.content -%}\n {%- set ns.content = '' + ns.content -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' -%}\n {%- set ns.content = ns.content + eos_token -%}\n {%- endif -%}\n {{ role }}: {{ ns.content }}{{ '\\n\\n' }}\n{%- endfor -%}\nASSISTANT:{{ ' ' }}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithFunctionCallingAndSystemPrompt, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = """ + SYSTEM: You are a helpful assistant with access to functions. Use them if required. + + FUNCTIONS: [{"name": "get_stock_price", "description": "Get the current stock price", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "The stock symbol, e.g. 
AAPL, GOOG"}}, "required": ["symbol"]}}, {"name": "check_word_anagram", "description": "Check if two words are anagrams of each other", "parameters": {"type": "object", "properties": {"word1": {"type": "string", "description": "The first word"}, "word2": {"type": "string", "description": "The second word"}}, "required": ["word1", "word2"]}}] + + USER: Hi, can you tell me the current stock price of AAPL? + + ASSISTANT: + """ + + if target != result { + print("::: testFireworksAIFireFunctionV1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + // Fails because tools are omitted in the output, and the result is indented. + // func testMistral7BInstructV0_3JSONSchema() throws { + // let chatTemplate = + // "{{- bos_token }}\n{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {%- if tools and (message == user_messages[-1]) %}\n {{- ' [AVAILABLE_TOOLS] [' }}\n {%- for tool in tools %}\n\t\t{%- set tool = tool.function %}\n\t\t{{- '{\"type\": \"function\", \"function\": {' }}\n\t\t{%- for key, val in tool|items if key != \"return\" %}\n\t\t {%- if val is string %}\n\t\t\t{{- '\"' + key + '\": \"' + val + '\"' }}\n\t\t {%- else %}\n\t\t\t{{- '\"' + key + '\": ' + val|tojson }}\n\t\t {%- endif %}\n\t\t {%- if not loop.last %}\n\t\t\t{{- \", \" }}\n\t\t {%- endif %}\n\t\t{%- endfor %}\n\t\t{{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- ' [/AVAILABLE_TOOLS]' }}\n {%- endif %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- elif message['role'] == 'assistant' %}\n {%- if message.tool_calls is defined and message.tool_calls|length > 0 %}\n {{- ' [TOOL_CALLS] [' }}\n {%- for tool_call in message.tool_calls %}\n {{- {\"name\": tool_call.function.name, \"arguments\": tool_call.function.arguments, \"id\": 
tool_call.id}|tojson }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- '] ' }}\n {{- eos_token }}\n \t{%- elif message.content is defined %}\n\t {{- ' ' + message.content + ' ' + eos_token}}\n {%- endif %}\n {%- elif message['role'] == 'tool' %}\n {{- ' [TOOL_RESULTS] ' }}\n {{- '{\"call_id\": \"' + message.tool_call_id + '\", \"content\": ' + message.content|string + '}' }}\n {{- ' [/TOOL_RESULTS] ' }}\n {%- endif %}\n{%- endfor %}\n" + // let template = try Template(chatTemplate) + // + // let result = try template.render([ + // "messages": [ + // [ + // "role": "system", + // "content": + // "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.", + // ], + // ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + // [ + // "role": "assistant", + // "tool_calls": [ + // [ + // "id": "abcdef123", + // "type": "function", + // "function": [ + // "name": "get_current_temperature", + // "arguments": ["location": "Paris, France", "unit": "celsius"], + // ], + // ] + // ], + // ], + // ["role": "tool", "tool_call_id": "abcdef123", "name": "get_current_temperature", "content": "22.0"], + // ], + // "tools": exampleListOfTools, + // // "tools_json": "", // TODO: Figure out how to convert the array of OrderedDictionaries to JSON + // "bos_token": "", + // "eos_token": "", + // ]) + // let target = """ + // [AVAILABLE_TOOLS] [{"type": "function", "function": {"name": "get_current_temperature", "description": "Get the current temperature at a location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The unit to return the temperature in."}}, "required": ["location", "unit"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": 
"Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}}, "required": ["location"]}}}] [/AVAILABLE_TOOLS] [INST] Hey, what\'s the temperature in Paris right now? [/INST] [TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "abcdef123"}] [TOOL_RESULTS] {"call_id": "abcdef123", "content": 22.0} [/TOOL_RESULTS] + // """ + // + // if target != result { + // print("::: testMistral7BInstructV0_3JSONSchema failed.") + // print("::: target:") + // print(target) + // print("::: result:") + // print(result) + // } + // XCTAssertEqual(result, target) + // } + + // Previously failed because tools are omitted in the output, now fails because of error with `map`: runtime("map filter requires either an attribute name or a function") + // func testCISCaiMistral7BInstructV0_3SOTAGGUF() throws { + // let chatTemplate = """ + // {{ bos_token }}{% set ns = namespace(lastuser=-1, system=false, functions=false) %}{% if tools %}{% for message in messages %}{% if message['role'] == 'user' %}{% set ns.lastuser = loop.index0 %}{% elif message['role'] == 'system' %}{% set ns.system = message['content'] %}{% endif %}{% endfor %}{% set ns.functions = tools|selectattr('type','eq','function')|map(attribute='function')|list|tojson %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{% if loop.index0 == ns.lastuser and ns.functions %}{{ '[AVAILABLE_TOOLS] ' }}{{ ns.functions }}{{ '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST] ' }}{% if loop.index0 == ns.lastuser and ns.system %}{{ ns.system + ' ' }}{% endif %}{{ message['content'] }}{{ '[/INST]' }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS] ' }}{{ dict(call_id=message['tool_call_id'], content=message['content'])|tojson }}{{ '[/TOOL_RESULTS]' }}{% elif message['role'] == 
'assistant' %}{% if message['tool_calls'] %}{{ '[TOOL_CALLS] [' }}{% for call in message['tool_calls'] %}{% if call['type'] == 'function' %}{{ dict(id=call['id'], name=call['function']['name'], arguments=call['function']['arguments'])|tojson }}{% endif %}{% if not loop.last %}{{ ', ' }}{% endif %}{% endfor %}{{ ']' }}{% else %}{{ message['content'] }}{% endif %}{{ eos_token }}{% endif %}{% endfor %} + // """ + // let template = try Template(chatTemplate) + // + // let result = try template.render([ + // "messages": [ + // [ + // "role": "user", + // "content": "What's the weather like in Oslo and Stockholm?", + // ] + // ], + // "tools": [exampleToolJSONSchemas["get_current_temperature_v2"]!], + // "bos_token": "", + // "eos_token": "", + // ]) + // let target = + // """ + // [AVAILABLE_TOOLS] [{"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. 
San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}][/AVAILABLE_TOOLS][INST] What's the weather like in Oslo and Stockholm?[/INST] + // """ + // + // if target != result { + // print("::: testCISCaiMistral7BInstructV0_3SOTAGGUF failed.") + // print("::: target:") + // print(target) + // print("::: result:") + // print(result) + // } + // XCTAssertEqual(result, target) + // } + + func testNousResearchHermes2ProLlama38BJSONSchema() throws { + let chatTemplate = """ + {%- macro json_to_python_type(json_spec) %}\n{%- set basic_type_map = {\n "string": "str",\n "number": "float",\n "integer": "int",\n "boolean": "bool"\n} %}\n\n{%- if basic_type_map[json_spec.type] is defined %}\n {{- basic_type_map[json_spec.type] }}\n{%- elif json_spec.type == "array" %}\n {{- "list[" + json_to_python_type(json_spec|items) + "]"}}\n{%- elif json_spec.type == "object" %}\n {%- if json_spec.additionalProperties is defined %}\n {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}}\n {%- else %}\n {{- "dict" }}\n {%- endif %}\n{%- elif json_spec.type is iterable %}\n {{- "Union[" }}\n {%- for t in json_spec.type %}\n {{- json_to_python_type({"type": t}) }}\n {%- if not loop.last %}\n {{- "," }} \n {%- endif %}\n {%- endfor %}\n {{- "]" }}\n{%- else %}\n {{- "Any" }}\n{%- endif %}\n{%- endmacro %}\n\n\n{{- bos_token }}\n{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. 
Here are the available tools: " }}\n{%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- '{"type": "function", "function": ' }}\n {{- '{"name": ' + tool.name + '", ' }}\n {{- '"description": "' + tool.name + '(' }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- param_name + ": " + json_to_python_type(param_fields) }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- if tool.return is defined %}\n {{- " -> " + json_to_python_type(tool.return) }}\n {%- endif %}\n {{- " - " + tool.description + "\\n\\n" }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {%- if loop.first %}\n {{- " Args:\\n" }}\n {%- endif %}\n {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}\n {%- endfor %}\n {%- if tool.return is defined and tool.return.description is defined %}\n {{- "\\n Returns:\\n " + tool.return.description }}\n {%- endif %}\n {{- '"' }}\n {{- ', "parameters": ' }}\n {%- if tool.parameters.properties | length == 0 %}\n {{- "{}" }}\n {%- else %}\n {{- tool.parameters | tojson}}\n {%- endif %}\n {{- "}" }}\n {%- if not loop.last %}\n {{- "\\n" }}\n {%- endif %}\n{%- endfor %}\n{{- " " }}\n{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\n' }}\n{{- "For each function call return a json object with function name and arguments within XML tags as follows:\n" }}\n{{- "\n" }}\n{{- '{"arguments": , "name": }\n' }}\n{{- '<|im_end|>' }}\n{%- for message in messages %}\n {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}\n {{- '<|im_start|>' + message.role + '\\n' + 
message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == "assistant" %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '{ ' }}\n {%- if tool_call.arguments is defined %}\n {{- '"arguments": ' }}\n {{- tool_call.arguments|tojson }}\n {{- ', '}}\n {%- endif %}\n {{- '"name": "' }}\n {{- tool_call.name }}\n {{- '"}' }}\n {{- '\\n ' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == "tool" %}\n {%- if not message.name is defined %}\n {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }}\n {%- endif %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {{- '{"name": "' }}\n {{- message.name }}\n {{- '", "content": ' }}\n {{- message.content|tojson + '}' }}\n {{- '\\n <|im_end|>\\n' }} \n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "user") as (String, Any), + ("content", "Fetch the stock fundamentals data for Tesla (TSLA)") as (String, Any), + ]) + ], + "tools": [ + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_stock_fundamentals") as (String, Any), + ("description", "Get fundamental data for a given stock symbol using yfinance API.") + as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "symbol", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The stock symbol.") as (String, Any), + ]) + ) as (String, 
Any) + ]) + ) as (String, Any), + ("required", ["symbol"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "description", + """ + A dictionary containing fundamental data. + + Keys: + - 'symbol': The stock symbol. + - 'company_name': The long name of the company. + - 'sector': The sector to which the company belongs. + - 'industry': The industry to which the company belongs. + - 'market_cap': The market capitalization of the company. + - 'pe_ratio': The forward price-to-earnings ratio. + - 'pb_ratio': The price-to-book ratio. + - 'dividend_yield': The dividend yield. + - 'eps': The trailing earnings per share. + - 'beta': The beta value of the stock. + - '52_week_high': The 52-week high price of the stock. + - '52_week_low': The 52-week low price of the stock. + """ + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ], + "bos_token": "<|begin_of_text|>", + "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|>You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. 
Here are the available tools: {"type": "function", "function": {"name": get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\n\n Args:\n symbol(str): The stock symbol.\n Returns:\n A dictionary containing fundamental data.\n\nKeys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "The stock symbol."}}, "required": ["symbol"]}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\nFor each function call return a json object with function name and arguments within XML tags as follows:\n\n{"arguments": , "name": }\n<|im_end|><|im_start|>user\nFetch the stock fundamentals data for Tesla (TSLA)<|im_end|>\n<|im_start|>assistant\n + """ + + if target != result { + print("::: testNousResearchHermes2ProLlama38BJSONSchema failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + // func testMetaLlamaLlama3_18BInstruct() throws { + // let chatTemplate = """ + // {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- 
if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = "26 Jul 2024" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\\n\\n"}}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' 
}}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- "<|python_tag|>" + tool_call.name + ".call(" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '="' + arg_val + '"' }}\n {%- if not loop.last %}\n {{- ", " 
}}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{"name": "' + tool_call.name + '", ' }}\n {{- '"parameters": ' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- "<|eom_id|>" }}\n {%- else %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + // """ + // let template = try Template(chatTemplate) + // let result = try template.render([ + // "messages": [ + // ["role": "system", "content": "You are a bot that responds to weather queries."], + // ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + // ], + // "tools": [exampleToolJSONSchemas["get_current_temperature_v1"]!], + // "bos_token": "<|begin_of_text|>", + // "eos_token": "<|im_end|>", + // "add_generation_prompt": true, + // ]) + // let target = """ + // <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.\n\n{\n "type": "function",\n "function": {\n "name": "get_current_temperature",\n "description": "Get the 
current temperature at a location.",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The location to get the temperature for, in the format \\"City, Country\\""\n }\n },\n "required": [\n "location"\n ]\n },\n "return": {\n "type": "number",\n "description": "The current temperature at the specified location in the specified units, as a float."\n }\n }\n}\n\nHey, what's the temperature in Paris right now?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
+    //         """
+    //
+    //     if target != result {
+    //         print("::: testMetaLlamaLlama3_18BInstruct failed.")
+    //         print("::: target:")
+    //         print(target)
+    //         print("::: result:")
+    //         print(result)
+    //     }
+    //     XCTAssertEqual(result, target)
+    // }
+
+    //
+
+    /// Renders `ChatTemplate.llama3_1` with the weather-query messages and a single
+    /// tool spec, then compares the output with the expected prompt text.
+    func testLlama3_1() throws {
+        let chatTemplate = ChatTemplate.llama3_1
+        let template = try Template(chatTemplate)
+        let result = try template.render([
+            "messages": Messages.weatherQuery,
+            "tools": [ToolSpec.getCurrentWeather],
+            "bos_token": "<|begin_of_text|>",
+            // "eos_token": "<|im_end|>",
+            "add_generation_prompt": true,
+        ])
+        let target = """
+            <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+            Environment: ipython
+            Cutting Knowledge Date: December 2023
+            Today Date: 26 Jul 2024
+
+            <|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+
+            Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.
+
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA"
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": [
+                                    "celsius",
+                                    "fahrenheit"
+                                ]
+                            }
+                        },
+                        "required": [
+                            "location"
+                        ]
+                    }
+                }
+            }
+
+            What is the weather in Paris today?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+
+            """
+
+        // Dump both strings on mismatch so the difference is readable in the test log.
+        if target != result {
+            print("::: testLlamaLlama3_1 failed.")
+            print("::: target:")
+            print(target)
+            print("::: result:")
+            print(result)
+        }
+        XCTAssertEqual(result, target)
+    }
+
+    /// Renders `ChatTemplate.llama3_2` with the same inputs as `testLlama3_1` and
+    /// compares the output with the expected prompt text.
+    func testLlama3_2() throws {
+        let chatTemplate = ChatTemplate.llama3_2
+        let template = try Template(chatTemplate)
+        let result = try template.render([
+            "messages": Messages.weatherQuery,
+            "tools": [ToolSpec.getCurrentWeather],
+            "bos_token": "<|begin_of_text|>",
+            // "eos_token": "<|im_end|>",
+            "add_generation_prompt": true,
+        ])
+        let target = """
+            <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+            Environment: ipython
+            Cutting Knowledge Date: December 2023
+            Today Date: 26 Jul 2024
+
+            <|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+
+            Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.
+
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA"
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": [
+                                    "celsius",
+                                    "fahrenheit"
+                                ]
+                            }
+                        },
+                        "required": [
+                            "location"
+                        ]
+                    }
+                }
+            }
+
+            What is the weather in Paris today?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+
+            """
+
+        // The banner names this test; it previously reused the label of testLlama3_1,
+        // which made a failure here look like a failure of the other test.
+        if target != result {
+            print("::: testLlama3_2 failed.")
+            print("::: target:")
+            print(target)
+            print("::: result:")
+            print(result)
+        }
+        XCTAssertEqual(result, target)
+    }
+
+    /// Renders `ChatTemplate.qwen2_5` with the weather-query messages and a single
+    /// tool spec, then compares the output with the expected prompt text.
+    func testQwen2_5() throws {
+        let chatTemplate = ChatTemplate.qwen2_5
+        let template = try Template(chatTemplate)
+        let result = try template.render([
+            "messages": Messages.weatherQuery,
+            "tools": [ToolSpec.getCurrentWeather],
+            "bos_token": "<|begin_of_text|>",
+            // "eos_token": "<|im_end|>",
+            "add_generation_prompt": true,
+        ])
+        // NOTE(review): "within XML tags" and the empty `{"name": , "arguments": }`
+        // placeholders look like angle-bracket markup was lost from this fixture;
+        // confirm the expected text against ChatTemplate.qwen2_5.
+        let target = """
+            <|im_start|>system
+            You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
+
+            # Tools
+
+            You may call one or more functions to assist with the user query.
+
+            You are provided with function signatures within XML tags:
+
+            {"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}}
+
+
+            For each function call, return a json object with function name and arguments within XML tags:
+
+            {"name": , "arguments": }
+            <|im_end|>
+            <|im_start|>user
+            What is the weather in Paris today?<|im_end|>
+            <|im_start|>assistant
+
+            """
+
+        if target != result {
+            print("::: qwen2_5 failed.")
+            print("::: target:")
+            print(target)
+            print("::: result:")
+            print(result)
+        }
+        XCTAssertEqual(result, target)
+    }
+
+    /// Renders `ChatTemplate.mistral7b` with the weather-query messages and a single
+    /// tool spec, then compares the output with the expected single-line prompt.
+    func testMistral7b() throws {
+        let chatTemplate = ChatTemplate.mistral7b
+        let template = try Template(chatTemplate)
+        let result = try template.render([
+            "messages": Messages.weatherQuery,
+            "tools": [ToolSpec.getCurrentWeather],
+            "bos_token": "<|begin_of_text|>",
+            // "eos_token": "<|im_end|>",
+            "add_generation_prompt": true,
+        ])
+        let target = """
+            <|begin_of_text|>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]What is the weather in Paris today?[/INST]
+            """
+
+        if target != result {
+            print("::: testMistral7b failed.")
+            print("::: target:")
+            print(target)
+            print("::: result:")
+            print(result)
+        }
+        XCTAssertEqual(result, target)
+    }
+}
+
+extension Data {
+    /// The receiver decoded as UTF-8 text, or `nil` when decoding fails.
+    var string: String? {
+        String(data: self, encoding: .utf8)
+    }
+}
diff --git a/Tests/Template tests/VisionTests.swift b/Tests/Template tests/VisionTests.swift
new file mode 100644
index 0000000..fcdb5b9
--- /dev/null
+++ b/Tests/Template tests/VisionTests.swift
@@ -0,0 +1,297 @@
+//
+//  VisionTests.swift
+//  Jinja
+//
+//  Created by Anthony DePasquale on 31.12.2024.
+//
+
+import XCTest
+import OrderedCollections
+
+@testable import Jinja
+
+/// Rendering tests for the Llama 3.2 Vision and Qwen2-VL chat templates.
+///
+/// Each test renders one of the template strings below with a fixed context and
+/// compares the output with a hard-coded expected string.
+final class VisionTests: XCTestCase {
+    /// Template source rendered by the `testLlama3_2_11BVisionInstruct…` tests.
+    let llama3_2visionChatTemplate =
+        "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
+    /// Template source rendered by the `testQwen2VL…` tests.
+    let qwen2VLChatTemplate =
+        "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+
+    /// Asserts that `result` equals `target`; on mismatch, prints both strings first.
+    ///
+    /// - Parameters:
+    ///   - result: The rendered template output.
+    ///   - target: The expected output.
+    ///   - testName: Name printed in the failure banner.
+    ///   - file: Forwarded to `XCTAssertEqual` so the failure is reported at the call site.
+    ///   - line: Forwarded to `XCTAssertEqual` so the failure is reported at the call site.
+    private func assertRendered(
+        _ result: String,
+        equals target: String,
+        in testName: String,
+        file: StaticString = #filePath,
+        line: UInt = #line
+    ) {
+        if target != result {
+            print("::: \(testName) failed.")
+            print("::: target:")
+            print(target)
+            print("::: result:")
+            print(result)
+        }
+        XCTAssertEqual(result, target, file: file, line: line)
+    }
+
+    /// Renders the Llama 3.2 Vision template for a three-turn conversation whose
+    /// content parts are all of type `text`.
+    func testLlama3_2_11BVisionInstructTextChatOnly() throws {
+        let template = try Template(llama3_2visionChatTemplate)
+        let result = try template.render([
+            "messages": [
+                [
+                    "role": "user",
+                    "content": [
+                        [
+                            "type": "text",
+                            "text": "Hello, how are you?",
+                        ] as [String: Any]
+                    ] as [[String: Any]],
+                ] as [String: Any],
+                [
+                    "role": "assistant",
+                    "content": [
+                        [
+                            "type": "text",
+                            "text": "I'm doing great. How can I help you today?",
+                        ] as [String: Any]
+                    ] as [[String: Any]],
+                ] as [String: Any],
+                [
+                    "role": "user",
+                    "content": [
+                        [
+                            "type": "text",
+                            "text": "I'd like to show off how chat templating works!",
+                        ] as [String: Any]
+                    ] as [[String: Any]],
+                ] as [String: Any],
+            ] as [[String: Any]] as Any,
+            "bos_token": "" as Any,
+            "date_string": "26 Jul 2024" as Any,
+            "tools_in_user_message": true as Any,
+            "system_message": "You are a helpful assistant." as Any,
+            "add_generation_prompt": true as Any,
+        ])
+        let target =
+            "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+
+        assertRendered(result, equals: target, in: "testLlama3_2_11BVisionInstructTextChatOnly")
+    }
+
+    /// Renders the Llama 3.2 Vision template for a user message that contains one
+    /// `text` part followed by one `image` part.
+    func testLlama3_2_11BVisionInstructWithImages() throws {
+        let template = try Template(llama3_2visionChatTemplate)
+        let result = try template.render([
+            "messages": [
+                [
+                    "role": "user",
+                    "content": [
+                        [
+                            "type": "text",
+                            "text": "What's in this image?",
+                        ] as [String: Any],
+                        [
+                            "type": "image",
+                            "image": "base64_encoded_image_data",
+                        ] as [String: Any],
+                    ] as [[String: Any]],
+                ] as [String: Any]
+            ] as [[String: Any]],
+            "bos_token": "" as Any,
+            "add_generation_prompt": true as Any,
+        ])
+        let target =
+            "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat's in this image?<|image|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+
+        assertRendered(result, equals: target, in: "testLlama3_2_11BVisionInstructWithImages")
+    }
+
+    /// Renders the Qwen2-VL template for a user message with a `text` part and an
+    /// `image` part, with `add_vision_id` enabled.
+    func testQwen2VLWithImages() throws {
+        let template = try Template(qwen2VLChatTemplate)
+        let result = try template.render([
+            "messages": [
+                [
+                    "role": "user",
+                    "content": [
+                        [
+                            "type": "text",
+                            "text": "What's in this image?",
+                        ] as [String: String],
+                        [
+                            "type": "image",
+                            "image_url": "example.jpg",
+                        ] as [String: String],
+                    ] as [[String: String]],
+                ] as [String: Any]
+            ] as [[String: Any]],
+            "add_generation_prompt": true,
+            "add_vision_id": true,
+        ])
+        let target = """
+            <|im_start|>system
+            You are a helpful assistant.<|im_end|>
+            <|im_start|>user
+            What's in this image?Picture 1: <|vision_start|><|image_pad|><|vision_end|><|im_end|>
+            <|im_start|>assistant
+
+            """
+
+        assertRendered(result, equals: target, in: "testQwen2VLWithImages")
+    }
+
+    /// Renders the Qwen2-VL template for a user message with a `text` part and a
+    /// `video` part, with `add_vision_id` enabled.
+    func testQwen2VLWithVideo() throws {
+        let template = try Template(qwen2VLChatTemplate)
+        let result = try template.render([
+            "messages": [
+                [
+                    "role": "user",
+                    "content": [
+                        [
+                            "type": "text",
+                            "text": "What's happening in this video?",
+                        ] as [String: String],
+                        [
+                            "type": "video",
+                            "video_url": "example.mp4",
+                        ] as [String: String],
+                    ] as [[String: String]],
+                ] as [String: Any]
+            ] as [[String: Any]],
+            "add_generation_prompt": true,
+            "add_vision_id": true,
+        ])
+        let target = """
+            <|im_start|>system
+            You are a helpful assistant.<|im_end|>
+            <|im_start|>user
+            What's happening in this video?Video 1: <|vision_start|><|video_pad|><|vision_end|><|im_end|>
+            <|im_start|>assistant
+
+            """
+
+        assertRendered(result, equals: target, in: "testQwen2VLWithVideo")
+    }
+
+    /// Renders the Llama 3.2 Vision template with a system message, a user message
+    /// and one tool definition passed through `tools`.
+    func testLlama3_2_11BVisionInstructWithTools() throws {
+        let template = try Template(llama3_2visionChatTemplate)
+
+        // Ordered dictionaries keep the keys in the order written here; the expected
+        // JSON below lists them in this same order.
+        let tools: [OrderedDictionary<String, Any>] = [
+            OrderedDictionary(uniqueKeysWithValues: [
+                ("type", "function" as Any),
+                (
+                    "function",
+                    OrderedDictionary(uniqueKeysWithValues: [
+                        ("name", "get_current_weather" as Any),
+                        ("description", "Get the current weather in a given location" as Any),
+                        (
+                            "parameters",
+                            OrderedDictionary(uniqueKeysWithValues: [
+                                ("type", "object" as Any),
+                                (
+                                    "properties",
+                                    OrderedDictionary(uniqueKeysWithValues: [
+                                        (
+                                            "location",
+                                            OrderedDictionary(uniqueKeysWithValues: [
+                                                ("type", "string" as Any),
+                                                ("description", "The city and state, e.g. San Francisco, CA" as Any),
+                                            ]) as Any
+                                        ),
+                                        (
+                                            "unit",
+                                            OrderedDictionary(uniqueKeysWithValues: [
+                                                ("type", "string" as Any),
+                                                ("enum", ["celsius", "fahrenheit"] as Any),
+                                            ]) as Any
+                                        ),
+                                    ]) as Any
+                                ),
+                                ("required", ["location"] as Any),
+                            ]) as Any
+                        ),
+                    ]) as Any
+                ),
+            ])
+        ]
+
+        let result = try template.render([
+            "messages": [
+                [
+                    "role": "system",
+                    "content": "You are a helpful assistant.",
+                ],
+                [
+                    "role": "user",
+                    "content": "What's the weather like in San Francisco?",
+                ] as [String: Any],
+            ] as [[String: Any]] as Any,
+            "bos_token": "" as Any,
+            "add_generation_prompt": true as Any,
+            "tools": tools as Any,
+            "tools_in_user_message": true as Any,
+        ])
+        let target = """
+
+            <|start_header_id|>system<|end_header_id|>
+
+            Environment: ipython
+            Cutting Knowledge Date: December 2023
+            Today Date: 26 Jul 2024
+
+            You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+
+            Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.
+
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA"
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": [
+                                    "celsius",
+                                    "fahrenheit"
+                                ]
+                            }
+                        },
+                        "required": [
+                            "location"
+                        ]
+                    }
+                }
+            }
+
+            What's the weather like in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+
+            """
+
+        assertRendered(result, equals: target, in: "testLlama3_2_11BVisionInstructWithTools")
+    }
+}