Skip to content

Commit

Permalink
Stop using _io when throwing validation errors from _check()
Browse files Browse the repository at this point in the history
Relevant tests were added in these commits:

- Java: kaitai-io/kaitai_struct_tests@e92fb33
- Python: kaitai-io/kaitai_struct_tests@e7869f0

This commit is needed for the `testCheckBadValidOldIo` /
`test_check_bad_valid_old_io` test methods to pass.

The _check() method is intended to verify pure data consistency and is
supposed to be called at the time when the actual `_io` is not available
yet (or is not in the correct state) and should not be used even if it's
not `null`. Before this commit, if we wanted to initialize a KS object
by reading an existing stream and then edit the data and write them,
_check() would read the position from the old `_io` used for reading and
report it in the validation error, which is wrong.
  • Loading branch information
generalmimon committed Sep 25, 2023
1 parent cb0c1eb commit 0179199
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -580,17 +580,21 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = CSharpCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"if (!(${translator.translate(checkExpr)}))")
out.puts("{")
out.inc
out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);")
out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});")
out.dec
out.puts("}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1057,17 +1057,21 @@ class CppCompiler(
}

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else nullPtr,
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
importListSrc.addKaitai("kaitai/exceptions.h")
outSrc.puts(s"if (!(${translator.translate(checkExpr)})) {")
outSrc.inc
outSrc.puts(s"throw ${ksErrorName(err)}($errArgsStr);")
outSrc.puts(s"throw ${ksErrorName(err)}(${errArgsStr.mkString(", ")});")
outSrc.dec
outSrc.puts("}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,16 +531,20 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = GoCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "nil",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"if !(${translator.translate(checkExpr)}) {")
out.inc
val errInst = s"kaitai.New${err.name}($errArgsStr)"
val errInst = s"kaitai.New${err.name}(${errArgsStr.mkString(", ")})"
val noValueAndErr = translator.returnRes match {
case None => errInst
case Some(r) => s"$r, $errInst"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1132,16 +1132,20 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
}

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"if (!(${translator.translate(checkExpr)})) {")
out.inc
out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);")
out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});")
out.dec
out.puts("}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,16 +560,20 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = JavaScriptCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"if (!(${translator.translate(checkExpr)})) {")
out.inc
out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);")
out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});")
out.dec
out.puts("}")
}
Expand Down
17 changes: 7 additions & 10 deletions shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -402,24 +402,21 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = LuaCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsCode = errArgs.map(translator.translate)
val actualStr = expression(Ast.expr.InternalName(attr.id))
out.puts(s"if not(${translator.translate(checkExpr)}) then")
out.inc
val msg = err match {
case _: ValidationNotEqualError => {
val (expected, actual) = (
errArgsCode.lift(0).getOrElse("[expected]"),
errArgsCode.lift(1).getOrElse("[actual]")
)
s""""not equal, expected " .. $expected .. ", but got " .. $actual"""
val expectedStr = expected.get
s""""not equal, expected " .. $expectedStr .. ", but got " .. $actualStr"""
}
case _ => "\"" + ksErrorName(err) + "\""
case _ => expression(Ast.expr.Str(ksErrorName(err)))
}
out.puts(s"error($msg)")
out.dec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,16 +480,20 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = PHPCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"if (!(${translator.translate(checkExpr)})) {")
out.inc
out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);")
out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});")
out.dec
out.puts("}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,16 +887,20 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = PythonCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "None",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"if not ${translator.translate(checkExpr)}:")
out.inc
out.puts(s"raise ${ksErrorName(err)}($errArgsStr)")
out.puts(s"raise ${ksErrorName(err)}(${errArgsStr.mkString(", ")})")
out.dec
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,18 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def ksErrorName(err: KSError): String = RubyCompiler.ksErrorName(err)

override def attrValidateExpr(
attrId: Identifier,
attrType: DataType,
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
useIo: Boolean,
expected: Option[Ast.expr] = None
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
out.puts(s"raise ${ksErrorName(err)}.new($errArgsStr) if not ${translator.translate(checkExpr)}")
val errArgsStr = expected.map(expression) ++ List(
expression(Ast.expr.InternalName(attr.id)),
if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "nil",
expression(Ast.expr.Str(attr.path.mkString("/", "/", "")))
)
out.puts(s"raise ${ksErrorName(err)}.new(${errArgsStr.mkString(", ")}) if not ${translator.translate(checkExpr)}")
}

def types2class(names: List[String]) = names.map(type2class).mkString("::")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ trait CommonReads extends LanguageCompiler {
* @param attr attribute to run validations for
*/
def attrValidateAll(attr: AttrLikeSpec) =
attr.valid.foreach(valid => attrValidate(attr.id, attr, valid))
attr.valid.foreach(valid => attrValidate(attr.id, attr, valid, true))
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ trait GenericChecks extends LanguageCompiler with EveryReadIsExpression {
attr.valid.foreach { (valid) =>
typeProvider._currentIteratorType = Some(attr.dataTypeComposite)
if (bodyShouldDependOnIo.map(shouldDepend => validDependsOnIo(valid) == shouldDepend).getOrElse(true)) {
attrValidate(id, attr, valid)
attrValidate(id, attr, valid, false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ trait ValidateOps extends ExceptionNames {
val translator: AbstractTranslator
val typeProvider: ClassTypeProvider

def attrValidate(attrId: Identifier, attr: AttrLikeSpec, valid: ValidationSpec): Unit = {
def attrValidate(attrId: Identifier, attr: AttrLikeSpec, valid: ValidationSpec, useIo: Boolean): Unit = {
valid match {
case ValidationEq(expected) =>
attrValidateExprCompare(attrId, attr, Ast.cmpop.Eq, expected, ValidationNotEqualError(attr.dataTypeComposite))
attrValidateExprCompare(attrId, attr, Ast.cmpop.Eq, expected, ValidationNotEqualError(attr.dataTypeComposite), useIo)
case ValidationMin(min) =>
attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite))
attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite), useIo)
case ValidationMax(max) =>
attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite))
attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite), useIo)
case ValidationRange(min, max) =>
attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite))
attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite))
attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite), useIo)
attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite), useIo)
case ValidationAnyOf(values) =>
val bigOrExpr = Ast.expr.BoolOp(
Ast.boolop.Or,
Expand All @@ -37,15 +37,10 @@ trait ValidateOps extends ExceptionNames {
)

attrValidateExpr(
attrId,
attr.dataTypeComposite,
attr,
checkExpr = bigOrExpr,
err = ValidationNotAnyOfError(attr.dataTypeComposite),
errArgs = List(
Ast.expr.InternalName(attrId),
Ast.expr.InternalName(IoIdentifier),
Ast.expr.Str(attr.path.mkString("/", "/", ""))
)
useIo
)
case ValidationExpr(expr) =>
blockScopeHeader
Expand All @@ -56,40 +51,37 @@ trait ValidateOps extends ExceptionNames {
translator.translate(Ast.expr.InternalName(attrId))
)
attrValidateExpr(
attrId,
attr.dataTypeComposite,
attr,
expr,
ValidationExprError(attr.dataTypeComposite),
List(
Ast.expr.InternalName(attrId),
Ast.expr.InternalName(IoIdentifier),
Ast.expr.Str(attr.path.mkString("/", "/", ""))
)
useIo
)
blockScopeFooter
}
}

def attrValidateExprCompare(attrId: Identifier, attr: AttrLikeSpec, op: Ast.cmpop, expected: Ast.expr, err: KSError): Unit = {
def attrValidateExprCompare(
attrId: Identifier,
attr: AttrLikeSpec,
op: Ast.cmpop,
expected: Ast.expr,
err: KSError,
useIo: Boolean
): Unit = {
attrValidateExpr(
attrId,
attr.dataTypeComposite,
attr,
checkExpr = Ast.expr.Compare(
Ast.expr.InternalName(attrId),
op,
expected
),
err = err,
errArgs = List(
expected,
Ast.expr.InternalName(attrId),
Ast.expr.InternalName(IoIdentifier),
Ast.expr.Str(attr.path.mkString("/", "/", ""))
)
useIo = useIo,
expected = Some(expected)
)
}

def attrValidateExpr(attrId: Identifier, attrType: DataType, checkExpr: Ast.expr, err: KSError, errArgs: List[Ast.expr]): Unit = {}
def attrValidateExpr(attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, useIo: Boolean, expected: Option[Ast.expr] = None): Unit = {}
def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit
def blockScopeHeader: Unit
def blockScopeFooter: Unit
Expand Down

0 comments on commit 0179199

Please sign in to comment.