diff --git a/jvm/src/test/scala/io/kaitai/struct/translators/ExpressionValidator$Test.scala b/jvm/src/test/scala/io/kaitai/struct/translators/ExpressionValidator$Test.scala new file mode 100644 index 000000000..872441749 --- /dev/null +++ b/jvm/src/test/scala/io/kaitai/struct/translators/ExpressionValidator$Test.scala @@ -0,0 +1,181 @@ +package io.kaitai.struct.translators + +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.exprlang.Expressions +import io.kaitai.struct.exprlang.Expressions.ParseException +import io.kaitai.struct.precompile.{MethodNotFoundErrorWithArg, TypeMismatchError, WrongMethodCall} +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers._ + +class ExpressionValidator$Test extends AnyFunSpec { + val alwaysInt = TestTypeProviders.Always(CalcIntType) + val alwaysIntValidator = new ExpressionValidator(alwaysInt) + + describe("simple literals") { + describe("valid") { + it("123") { + val ex = Expressions.parse("123") + alwaysIntValidator.validate(ex) + } + + it("123.456e12") { + val ex = Expressions.parse("123.456e12") + alwaysIntValidator.validate(ex) + } + + it("\"foo\"") { + val ex = Expressions.parse("\"foo\"") + alwaysIntValidator.validate(ex) + } + } + } + + describe("integer methods") { + describe("valid") { + it("123.to_s") { + val ex = Expressions.parse("123.to_s") + alwaysIntValidator.validate(ex) + } + + it("123.to_s()") { + val ex = Expressions.parse("123.to_s()") + alwaysIntValidator.validate(ex) + } + } + + describe("broken") { + it("123.to_s(3)") { + val ex = Expressions.parse("123.to_s(3)") + val thrown = the[WrongMethodCall] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("wrong arguments to method call `to_s` on integer: expected (), got (IntNum(3))") + } + + it("123.unknown_method") { + val ex = Expressions.parse("123.unknown_method") + val thrown = the[MethodNotFoundErrorWithArg] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("don't know how to call method 'unknown_method' of object type 'integer'") + } + + it("123.unknown_method_with_param(true)") { + val ex = Expressions.parse("123.unknown_method_with_param(true)") + val thrown = the[MethodNotFoundErrorWithArg] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("don't know how to call method 'unknown_method_with_param' of object type 'integer'") + } + } + } + + describe("float methods") { + describe("valid") { + it("1.234.to_i") { + val ex = Expressions.parse("1.234.to_i") + alwaysIntValidator.validate(ex) + } + } + + describe("broken") { + it("1.234.unknown_method") { + val ex = Expressions.parse("1.234.unknown_method") + val thrown = the[MethodNotFoundErrorWithArg] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("don't know how to call method 'unknown_method' of object type 'float'") + } + + it("1.234.unknown_method_with_param(true)") { + val ex = Expressions.parse("1.234.unknown_method_with_param(true)") + val thrown = the[MethodNotFoundErrorWithArg] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("don't know how to call method 'unknown_method_with_param' of object type 'float'") + } + } + } + + describe("string methods") { + it("\"123\".to_i") { + val ex = Expressions.parse("\"123\".to_i") + alwaysIntValidator.validate(ex) + } + + it("\"123\".to_i()") { + val ex = Expressions.parse("\"123\".to_i()") + alwaysIntValidator.validate(ex) + } + + it("\"123\".to_i(16)") { + val ex = Expressions.parse("\"123\".to_i(16)") + alwaysIntValidator.validate(ex) + } + + it("\"123\".to_i(true)") { + val ex = Expressions.parse("\"123\".to_i(true)") + val thrown = the [WrongMethodCall] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("wrong arguments to method call `to_i` on string: expected () or (integer), got (Bool(true))") + } + + it("\"123\".to_i(16, true)") { + val ex = Expressions.parse("\"123\".to_i(16, true)") + val thrown = the [WrongMethodCall] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("wrong arguments to method call `to_i` on string: expected () or (integer), got (IntNum(16), Bool(true))") + } + + it("\"foobar\".substring(2, 3)") { + val ex = Expressions.parse("\"foobar\".substring(2, 3)") + alwaysIntValidator.validate(ex) + } + + it("\"foobar\".substring(2, 3, 5)") { + val ex = Expressions.parse("\"foobar\".substring(2, 3, 5)") + val thrown = the [WrongMethodCall] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("wrong arguments to method call `substring` on string: expected (integer, integer), got (IntNum(2), IntNum(3), IntNum(5))") + } + + it("\"foobar\".substring(\"foo\", 5)") { + val ex = Expressions.parse("\"foobar\".substring(\"foo\", 5)") + val thrown = the [WrongMethodCall] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("wrong arguments to method call `substring` on string: expected (integer, integer), got (Str(foo), IntNum(5))") + } + } + + describe("array methods") { + describe("valid") { + it("[\"foo\", \"bar\"].size") { + val ex = Expressions.parse("[\"foo\", \"bar\"].size") + alwaysIntValidator.validate(ex) + } + + it("[\"foo\", \"bar\"].min") { + val ex = Expressions.parse("[\"foo\", \"bar\"].min") + alwaysIntValidator.validate(ex) + } + + it("[\"foo\", \"bar\"].min()") { + val ex = Expressions.parse("[\"foo\", \"bar\"].min") + alwaysIntValidator.validate(ex) + } + } + + describe("broken") { + it("[\"foo\", \"bar\"].min(42)") { + val ex = Expressions.parse("[\"foo\", \"bar\"].min(42)") + val thrown = the[WrongMethodCall] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("wrong arguments to method call `min` on array: expected (), got (IntNum(42))") + } + } + } + + describe("subscripts") { + it("[1, 3, 14][2]") { + val ex = Expressions.parse("[1, 3, 14][2]") + alwaysIntValidator.validate(ex) + } + + it("[1, 3, 14][\"foo\"]") { + val ex = Expressions.parse("[1, 3, 14][\"foo\"]") + val thrown = the [TypeMismatchError] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("subscript operation on arrays require index to be integer, but found CalcStrType") + } + + it("x[4]") { + val ex = Expressions.parse("x[4]") + val thrown = the [TypeMismatchError] thrownBy alwaysIntValidator.validate(ex) + thrown.getMessage should be("subscript operation is not supported on object type CalcIntType") + } + } +} diff --git a/shared/src/main/scala/io/kaitai/struct/precompile/Exceptions.scala b/shared/src/main/scala/io/kaitai/struct/precompile/Exceptions.scala index 32411d9b4..1ac6035bb 100644 --- a/shared/src/main/scala/io/kaitai/struct/precompile/Exceptions.scala +++ b/shared/src/main/scala/io/kaitai/struct/precompile/Exceptions.scala @@ -2,6 +2,7 @@ package io.kaitai.struct.precompile import io.kaitai.struct.datatype.DataType import io.kaitai.struct.format.ClassSpec +import io.kaitai.struct.translators.MethodArgType /** * Base class for all expression-related errors, not localized to a certain path @@ -10,6 +11,8 @@ import io.kaitai.struct.format.ClassSpec sealed abstract class ExpressionError(msg: String) extends RuntimeException(msg) class TypeMismatchError(msg: String) extends ExpressionError(msg) class TypeUndecidedError(msg: String) extends ExpressionError(msg) +class WrongMethodCall(val dataType: MethodArgType, val methodName: String, val expectedSigs: Iterable[String], val actualSig: String) + extends ExpressionError(s"wrong arguments to method call `$methodName` on $dataType: expected ${expectedSigs.mkString(" or ")}, got $actualSig") sealed abstract class NotFoundError(msg: String) extends ExpressionError(msg) class TypeNotFoundError(val name: String, val curClass: ClassSpec) @@ -18,10 +21,15 @@ class FieldNotFoundError(val name: String, val curClass: ClassSpec) extends NotFoundError(s"unable to access '$name' in ${curClass.nameAsStr} context") class EnumNotFoundError(val name: String, val curClass: ClassSpec) extends NotFoundError(s"unable to find enum '$name', searching from ${curClass.nameAsStr}") -class EnumMemberNotFoundError(val label: String, val enum: String, val enumDefPath: String) - extends NotFoundError(s"unable to find enum member '$enum::$label' (enum '$enum' defined at /$enumDefPath)") +class EnumMemberNotFoundError(val label: String, val enumName: String, val enumDefPath: String) + extends NotFoundError(s"unable to find enum member '$enumName::$label' (enum '$enumName' defined at /$enumDefPath)") + +// TODO: get rid of MethodNotFoundError in favor of MethodNotFoundErrorWithArg, rename it back +// requires refactoring of [[TypeDetector]] class MethodNotFoundError(val name: String, val dataType: DataType) extends NotFoundError(s"don't know how to call method '$name' of object type '$dataType'") +class MethodNotFoundErrorWithArg(val name: String, val argType: MethodArgType) + extends NotFoundError(s"don't know how to call method '$name' of object type '$argType'") /** * Internal compiler logic error: should never happen, but at least we want to diff --git a/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala b/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala index b7e2e79a3..414b7645d 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala @@ -4,9 +4,137 @@ import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format.Identifier -import io.kaitai.struct.precompile.TypeMismatchError +import io.kaitai.struct.precompile.{MethodNotFoundError, MethodNotFoundErrorWithArg, TypeMismatchError, WrongMethodCall} + +sealed trait MethodArgType +object MethodArgType { + case object IntArg extends MethodArgType { + override def toString = "integer" + } + case object FloatArg extends MethodArgType { + override def toString = "float" + } + case object StrArg extends MethodArgType { + override def toString = "string" + } + case object BooleanArg extends MethodArgType { + override def toString = "boolean" + } + case object BytesArg extends MethodArgType { + override def toString = "byte array" + } + case object ArrayArg extends MethodArgType { + override def toString = "array" + } + + def byDataType(dataType: DataType): Option[MethodArgType] = { + dataType match { + case _: IntType => Some(IntArg) + case _: FloatType => Some(FloatArg) + case _: StrType => Some(StrArg) + case _: BooleanType => Some(BooleanArg) + case _: BytesType => Some(BytesArg) + case _: ArrayType => Some(ArrayArg) + case _ => None + } + } + + def isArgAcceptable(actualType: DataType, expectedType: MethodArgType): Boolean = + byDataType(actualType).map((t) => t == expectedType).getOrElse(false) +} abstract trait CommonMethods[T] extends TypeDetector { + import MethodArgType._ + + sealed trait MethodSig { + def name: String + def expectedArgs: String + def accepts(argsValues: Seq[Ast.expr]): Boolean + } + + case class MethodSig0( + name: String, + returnType: DataType, + method: (Ast.expr) => T + ) extends MethodSig { + override def expectedArgs: String = "()" + override def accepts(argsValues: Seq[Ast.expr]): Boolean = argsValues.isEmpty + } + + case class MethodSig1( + name: String, + returnType: DataType, + argTypes: MethodArgType, + method: (Ast.expr, Ast.expr) => T + ) extends MethodSig { + override def expectedArgs: String = s"($argTypes)" + override def accepts(argsValues: Seq[Ast.expr]): Boolean = argsValues match { + case Seq(arg0) => isArgAcceptable(detectType(arg0), argTypes) + case _ => false + } + } + + case class MethodSig2( + name: String, + returnType: DataType, + argTypes: (MethodArgType, MethodArgType), + method: (Ast.expr, Ast.expr, Ast.expr) => T + ) extends MethodSig { + override def expectedArgs: String = s"(${argTypes._1}, ${argTypes._2})" + override def accepts(argsValues: Seq[Ast.expr]): Boolean = argsValues match { + case Seq(arg0, arg1) => + isArgAcceptable(detectType(arg0), argTypes._1) && + isArgAcceptable(detectType(arg1), argTypes._2) + case _ => false + } + } + + val METHODS_BY_TYPE: Map[MethodArgType, List[MethodSig]] = Map( + BytesArg -> List( + MethodSig0("first", Int1Type(false), bytesFirst), + MethodSig0("last", Int1Type(false), bytesLast), + MethodSig0("length", CalcIntType, bytesLength), + MethodSig0("size", CalcIntType, bytesLength), + MethodSig0("min", Int1Type(false), bytesMin), + MethodSig0("max", Int1Type(false), bytesMax), + + // TODO: implement a better way to signal that we want not just any string, but string literal + MethodSig1("to_s", CalcStrType, StrArg, { case (obj, arg0) => + arg0 match { + case Ast.expr.Str(encoding) => + bytesToStr(obj, encoding) + case x => + throw new TypeMismatchError(s"to_s: argument #0: expected string literal, got $x") + } + }) + ), + IntArg -> List( + MethodSig0("to_s", CalcStrType, intToStr), + ), + FloatArg -> List( + MethodSig0("to_i", CalcIntType, floatToInt), + ), + StrArg -> List( + MethodSig0("length", CalcIntType, strLength), + MethodSig0("reverse", CalcStrType, strReverse), + MethodSig0("to_i", CalcIntType, { strToInt(_, Ast.expr.IntNum(10)) }), + MethodSig1("to_i", CalcIntType, IntArg, strToInt), + MethodSig2("substring", CalcStrType, (IntArg, IntArg), strSubstring) + ), + BooleanArg -> List( + MethodSig0("to_i", CalcBooleanType, boolToInt) + ), + + // TODO: do something about return type for arrays here + ArrayArg -> List( + MethodSig0("first", AnyType, arrayFirst), + MethodSig0("last", AnyType, arrayLast), + MethodSig0("size", AnyType, arraySize), + MethodSig0("min", AnyType, arrayMin), + MethodSig0("max", AnyType, arrayMax), + ), + ) + /** * Translates a certain attribute call (as in `foo.bar`) into a rendition * of expression in certain target language. @@ -30,36 +158,6 @@ abstract trait CommonMethods[T] extends TypeDetector { } case ut: UserType => userTypeField(ut, value, attr.name) - case _: BytesType => - attr.name match { - case "first" => bytesFirst(value) - case "last" => bytesLast(value) - case "length" | "size" => bytesLength(value) - case "min" => bytesMin(value) - case "max" => bytesMax(value) - } - case _: StrType => - attr.name match { - case "length" => strLength(value) - case "reverse" => strReverse(value) - case "to_i" => strToInt(value, Ast.expr.IntNum(10)) - } - case _: IntType => - attr.name match { - case "to_s" => intToStr(value) - } - case _: FloatType => - attr.name match { - case "to_i" => floatToInt(value) - } - case _: ArrayType => - attr.name match { - case "first" => arrayFirst(value) - case "last" => arrayLast(value) - case "size" => arraySize(value) - case "min" => arrayMin(value) - case "max" => arrayMax(value) - } case KaitaiStreamType | OwnedKaitaiStreamType => attr.name match { case "size" => kaitaiStreamSize(value) @@ -71,10 +169,10 @@ abstract trait CommonMethods[T] extends TypeDetector { case "to_i" => enumToInt(value, et) case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") } - case _: BooleanType => - attr.name match { - case "to_i" => boolToInt(value) - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + case _ => + MethodArgType.byDataType(valType) match { + case Some(argType) => invokeMethod(argType, attr.name, value) + case _ => throw new TypeMismatchError(s"internal compiler error: tried to call attribute '${attr.name}' on expression of type $valType") } } } @@ -82,6 +180,7 @@ abstract trait CommonMethods[T] extends TypeDetector { /** * Translates a certain function call (as in `foo.bar(arg1, arg2)`) into a * rendition of expression in certain target language. + * * @note Must be kept in sync with [[TypeDetector.detectCallType]] * @param call function call expression to translate * @return result of translation as [[T]] @@ -93,21 +192,40 @@ abstract trait CommonMethods[T] extends TypeDetector { func match { case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => val objType = detectType(obj) - (objType, methodName.name) match { - // TODO: check argument quantity - case (_: StrType, "substring") => strSubstring(obj, args(0), args(1)) - case (_: StrType, "to_i") => strToInt(obj, args(0)) - case (_: BytesType, "to_s") => - args match { - case Seq(Ast.expr.Str(encoding)) => - bytesToStr(obj, encoding) - case Seq(x) => - throw new TypeMismatchError(s"to_s: argument #0: expected string literal, got $x") - case _ => - throw new TypeMismatchError(s"to_s: expected 1 argument, got ${args.length}") + MethodArgType.byDataType(objType) match { + case Some(argType) => + invokeMethod(argType, methodName.name, obj, args) + case None => + throw new MethodNotFoundError(methodName.name, objType) + } + } + } + + private def invokeMethod(argType: MethodArgType, methodName: String, obj: Ast.expr, args: Seq[Ast.expr] = Seq()): T = { + METHODS_BY_TYPE.get(argType) match { + case Some(methodList) => + val methodSigs = methodList.filter(_.name == methodName) + if (methodSigs.isEmpty) { + throw new MethodNotFoundErrorWithArg(methodName, argType) + } else { + val expectedArgProblems: List[String] = methodSigs.map { methodSig => + if (methodSig.accepts(args)) { + return methodSig match { + case ms0: MethodSig0 => + ms0.method(obj) + case ms1: MethodSig1 => + ms1.method(obj, args(0)) + case ms2: MethodSig2 => + ms2.method(obj, args(0), args(1)) + } + } else { + methodSig.expectedArgs } - case _ => throw new TypeMismatchError(s"don't know how to call method '$methodName' of object type '$objType'") + } + throw new WrongMethodCall(argType, methodName, expectedArgProblems, "(" + args.mkString(", ") + ")") } + case None => + throw new MethodNotFoundErrorWithArg(methodName, argType) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala index 84f4fedc7..8d2741282 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala @@ -1,9 +1,10 @@ package io.kaitai.struct.translators import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.datatype.DataType.{ArrayType, BytesType, IntType} import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format.Identifier -import io.kaitai.struct.precompile.EnumMemberNotFoundError +import io.kaitai.struct.precompile.{EnumMemberNotFoundError, TypeMismatchError} /** * Validates expressions usage of types (in typecasting operator, @@ -59,8 +60,18 @@ class ExpressionValidator(val provider: TypeProvider) validate(ifTrue) validate(ifFalse) case Ast.expr.Subscript(container: Ast.expr, idx: Ast.expr) => - validate(container) - validate(idx) + detectType(container) match { + case _: ArrayType | _: BytesType => + validate(container) + detectType(idx) match { + case _: IntType => + validate(idx) + case indexType => + throw new TypeMismatchError(s"subscript operation on arrays require index to be integer, but found $indexType") + } + case x => + throw new TypeMismatchError(s"subscript operation is not supported on object type $x") + } case call: Ast.expr.Attribute => translateAttribute(call) case call: Ast.expr.Call =>