From 08277e4eccaf326c82feaa4f1f5f4a8e1f8bd143 Mon Sep 17 00:00:00 2001 From: Simon Friis Vindum Date: Thu, 29 May 2025 15:31:17 +0200 Subject: [PATCH] Rust: Refactor type equality --- .../codeql/rust/internal/TypeInference.qll | 109 +++++++----------- 1 file changed, 43 insertions(+), 66 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index fcacfd5d3dad..8399bde8aa80 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -207,81 +207,58 @@ private Type inferAssignmentOperationType(AstNode n, TypePath path) { } /** - * Holds if the type of `n1` at `path1` is the same as the type of `n2` at - * `path2` and type information should propagate in both directions through the - * type equality. + * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree + * of `n2` at `prefix2` and type information should propagate in both directions + * through the type equality. */ -bindingset[path1] -bindingset[path2] -private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { - exists(Variable v | - path1 = path2 and - n1 = v.getAnAccess() - | - n2 = v.getPat() +private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + prefix1.isEmpty() and + prefix2.isEmpty() and + ( + exists(Variable v | n1 = v.getAnAccess() | + n2 = v.getPat() + or + n2 = v.getParameter().(SelfParam) + ) or - n2 = v.getParameter().(SelfParam) - ) - or - exists(LetStmt let | - let.getPat() = n1 and - let.getInitializer() = n2 and - path1 = path2 - ) - or - n1 = n2.(ParenExpr).getExpr() and - path1 = path2 - or - n1 = n2.(BlockExpr).getStmtList().getTailExpr() and - path1 = path2 - or - n1 = n2.(IfExpr).getABranch() and - path1 = path2 - or - n1 = n2.(MatchExpr).getAnArm().getExpr() and - path1 = path2 - or - exists(BreakExpr break | - break.getExpr() = n1 and - break.getTarget() = n2.(LoopExpr) and - path1 = path2 - ) - or - exists(AssignmentExpr be | - n1 = be.getLhs() and - n2 = be.getRhs() and - path1 = path2 - ) -} - -bindingset[path1] -private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { - typeEquality(n1, path1, n2, path2) - or - n2 = - any(DerefExpr pe | - pe.getExpr() = n1 and - path1.isCons(TRefTypeParameter(), path2) + exists(LetStmt let | + let.getPat() = n1 and + let.getInitializer() = n2 ) -} - -bindingset[path2] -private predicate typeEqualityRight(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { - typeEquality(n1, path1, n2, path2) - or - n2 = - any(DerefExpr pe | - pe.getExpr() = n1 and - path1 = TypePath::cons(TRefTypeParameter(), path2) + or + n1 = n2.(ParenExpr).getExpr() + or + n1 = n2.(BlockExpr).getStmtList().getTailExpr() + or + n1 = n2.(IfExpr).getABranch() + or + n1 = n2.(MatchExpr).getAnArm().getExpr() + or + exists(BreakExpr break | + break.getExpr() = n1 and + break.getTarget() = n2.(LoopExpr) + ) + or + exists(AssignmentExpr be | + n1 = be.getLhs() and + n2 = be.getRhs() ) + ) + or + n1 = n2.(DerefExpr).getExpr() and + prefix1 = TypePath::singleton(TRefTypeParameter()) and + prefix2.isEmpty() } pragma[nomagic] private Type inferTypeEquality(AstNode n, TypePath path) { - exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) | - typeEqualityRight(n, path, n2, path2) + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) + | + typeEquality(n, prefix1, n2, prefix2) or - typeEqualityLeft(n2, path2, n, path) + typeEquality(n2, prefix2, n, prefix1) ) }