Skip to content

Commit 09bf05f

Browse files
committed
Rust: Fix types for * to deref overload
1 parent 7d536a3 commit 09bf05f

File tree

6 files changed

+169
-202
lines changed

6 files changed

+169
-202
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Impl {
3535
*/
3636
abstract class Call extends ExprImpl::Expr {
3737
/** Holds if the receiver of this call is implicitly borrowed. */
38-
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }
38+
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition(), _) }
3939

4040
/** Gets the trait targeted by this call, if any. */
4141
abstract Trait getTrait();
@@ -47,7 +47,7 @@ module Impl {
4747
abstract Expr getArgument(ArgumentPosition pos);
4848

4949
/** Holds if the argument at `pos` might be implicitly borrowed. */
50-
abstract predicate implicitBorrowAt(ArgumentPosition pos);
50+
abstract predicate implicitBorrowAt(ArgumentPosition pos, boolean certain);
5151

5252
/** Gets the number of arguments _excluding_ any `self` argument. */
5353
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }
@@ -85,7 +85,7 @@ module Impl {
8585

8686
override Trait getTrait() { none() }
8787

88-
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
88+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }
8989

9090
override Expr getArgument(ArgumentPosition pos) {
9191
result = super.getArgList().getArg(pos.asPosition())
@@ -109,7 +109,7 @@ module Impl {
109109
qualifier.toString() != "Self"
110110
}
111111

112-
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
112+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }
113113

114114
override Expr getArgument(ArgumentPosition pos) {
115115
pos.isSelf() and result = super.getArgList().getArg(0)
@@ -123,7 +123,9 @@ module Impl {
123123

124124
override Trait getTrait() { none() }
125125

126-
override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }
126+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
127+
pos.isSelf() and certain = false
128+
}
127129

128130
override Expr getArgument(ArgumentPosition pos) {
129131
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
@@ -143,10 +145,13 @@ module Impl {
143145

144146
override Trait getTrait() { result = trait }
145147

146-
override predicate implicitBorrowAt(ArgumentPosition pos) {
147-
pos.isSelf() and borrows >= 1
148-
or
149-
pos.asPosition() = 0 and borrows = 2
148+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
149+
(
150+
pos.isSelf() and borrows >= 1
151+
or
152+
pos.asPosition() = 0 and borrows = 2
153+
) and
154+
certain = true
150155
}
151156

152157
override Expr getArgument(ArgumentPosition pos) {

rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ private predicate isOverloaded(string op, int arity, string path, string method,
2222
op = "!" and path = "core::ops::bit::Not" and method = "not" and borrows = 0
2323
or
2424
// Dereference
25-
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 0
25+
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 1
2626
)
2727
or
2828
arity = 2 and

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 96 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
273273
prefix1.isEmpty() and
274274
prefix2 = TypePath::singleton(TRefTypeParameter())
275275
or
276-
n1 = n2.(DerefExpr).getExpr() and
277-
prefix1 = TypePath::singleton(TRefTypeParameter()) and
278-
prefix2.isEmpty()
279-
or
280276
exists(BlockExpr be |
281277
n1 = be and
282278
n2 = be.getStmtList().getTailExpr() and
@@ -640,20 +636,20 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
640636
}
641637

642638
private newtype TAccessPosition =
643-
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed) or
639+
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed, Boolean certain) or
644640
TReturnAccessPosition()
645641

646642
class AccessPosition extends TAccessPosition {
647-
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _) }
643+
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _, _) }
648644

649-
predicate isBorrowed() { this = TArgumentAccessPosition(_, true) }
645+
predicate isBorrowed(boolean certain) { this = TArgumentAccessPosition(_, true, certain) }
650646

651647
predicate isReturn() { this = TReturnAccessPosition() }
652648

653649
string toString() {
654-
exists(ArgumentPosition pos, boolean borrowed |
655-
this = TArgumentAccessPosition(pos, borrowed) and
656-
result = pos + ":" + borrowed
650+
exists(ArgumentPosition pos, boolean borrowed, boolean certain |
651+
this = TArgumentAccessPosition(pos, borrowed, certain) and
652+
result = pos + ":" + borrowed + ":" + certain
657653
)
658654
or
659655
this.isReturn() and
@@ -674,10 +670,15 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
674670
}
675671

676672
AstNode getNodeAt(AccessPosition apos) {
677-
exists(ArgumentPosition pos, boolean borrowed |
678-
apos = TArgumentAccessPosition(pos, borrowed) and
679-
result = this.getArgument(pos) and
680-
if this.implicitBorrowAt(pos) then borrowed = true else borrowed = false
673+
exists(ArgumentPosition pos, boolean borrowed, boolean certain |
674+
apos = TArgumentAccessPosition(pos, borrowed, certain) and
675+
result = this.getArgument(pos)
676+
|
677+
if this.implicitBorrowAt(pos, _)
678+
then borrowed = true and this.implicitBorrowAt(pos, certain)
679+
else (
680+
borrowed = false and certain = true
681+
)
681682
)
682683
or
683684
result = this and apos.isReturn()
@@ -705,51 +706,54 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
705706
predicate adjustAccessType(
706707
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
707708
) {
708-
if apos.isBorrowed()
709-
then
710-
exists(Type selfParamType |
711-
selfParamType =
712-
target
713-
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
714-
TypePath::nil())
715-
|
716-
if selfParamType = TRefType()
709+
apos.isBorrowed(true) and
710+
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
711+
tAdj = t
712+
or
713+
apos.isBorrowed(false) and
714+
exists(Type selfParamType |
715+
selfParamType =
716+
target
717+
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
718+
TypePath::nil())
719+
|
720+
if selfParamType = TRefType()
721+
then
722+
if t != TRefType() and path.isEmpty()
717723
then
718-
if t != TRefType() and path.isEmpty()
724+
// adjust for implicit borrow
725+
pathAdj.isEmpty() and
726+
tAdj = TRefType()
727+
or
728+
// adjust for implicit borrow
729+
pathAdj = TypePath::singleton(TRefTypeParameter()) and
730+
tAdj = t
731+
else
732+
if path.isCons(TRefTypeParameter(), _)
719733
then
734+
pathAdj = path and
735+
tAdj = t
736+
else (
720737
// adjust for implicit borrow
721-
pathAdj.isEmpty() and
722-
tAdj = TRefType()
723-
or
724-
// adjust for implicit borrow
725-
pathAdj = TypePath::singleton(TRefTypeParameter()) and
738+
not (t = TRefType() and path.isEmpty()) and
739+
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
726740
tAdj = t
727-
else
728-
if path.isCons(TRefTypeParameter(), _)
729-
then
730-
pathAdj = path and
731-
tAdj = t
732-
else (
733-
// adjust for implicit borrow
734-
not (t = TRefType() and path.isEmpty()) and
735-
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
736-
tAdj = t
737-
)
738-
else (
739-
// adjust for implicit deref
740-
path.isCons(TRefTypeParameter(), pathAdj) and
741-
tAdj = t
742-
or
743-
not path.isCons(TRefTypeParameter(), _) and
744-
not (t = TRefType() and path.isEmpty()) and
745-
pathAdj = path and
746-
tAdj = t
747-
)
741+
)
742+
else (
743+
// adjust for implicit deref
744+
path.isCons(TRefTypeParameter(), pathAdj) and
745+
tAdj = t
746+
or
747+
not path.isCons(TRefTypeParameter(), _) and
748+
not (t = TRefType() and path.isEmpty()) and
749+
pathAdj = path and
750+
tAdj = t
748751
)
749-
else (
750-
pathAdj = path and
751-
tAdj = t
752752
)
753+
or
754+
not apos.isBorrowed(_) and
755+
pathAdj = path and
756+
tAdj = t
753757
}
754758
}
755759

@@ -766,35 +770,47 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
766770
TypePath path0
767771
|
768772
n = a.getNodeAt(apos) and
769-
result = CallExprBaseMatching::inferAccessType(a, apos, path0) and
770-
if apos.isBorrowed()
771-
then
772-
exists(Type argType | argType = inferType(n) |
773-
if argType = TRefType()
774-
then
775-
path = path0 and
776-
path0.isCons(TRefTypeParameter(), _)
777-
or
778-
// adjust for implicit deref
773+
result = CallExprBaseMatching::inferAccessType(a, apos, path0)
774+
|
775+
(
776+
apos.isBorrowed(true)
777+
or
778+
// The desugaring of the unary `*e` is `*Deref::deref(&e)`. To handle the
779+
// deref expression after the call we must strip a `&` from the type at
780+
// the return position.
781+
apos.isReturn() and a instanceof DerefExpr
782+
) and
783+
path0.isCons(TRefTypeParameter(), path)
784+
or
785+
apos.isBorrowed(false) and
786+
exists(Type argType | argType = inferType(n) |
787+
if argType = TRefType()
788+
then
789+
path = path0 and
790+
path0.isCons(TRefTypeParameter(), _)
791+
or
792+
// adjust for implicit deref
793+
not path0.isCons(TRefTypeParameter(), _) and
794+
not (path0.isEmpty() and result = TRefType()) and
795+
path = TypePath::cons(TRefTypeParameter(), path0)
796+
else (
797+
not (
798+
argType.(StructType).asItemNode() instanceof StringStruct and
799+
result.(StructType).asItemNode() instanceof Builtins::Str
800+
) and
801+
(
779802
not path0.isCons(TRefTypeParameter(), _) and
780803
not (path0.isEmpty() and result = TRefType()) and
781-
path = TypePath::cons(TRefTypeParameter(), path0)
782-
else (
783-
not (
784-
argType.(StructType).asItemNode() instanceof StringStruct and
785-
result.(StructType).asItemNode() instanceof Builtins::Str
786-
) and
787-
(
788-
not path0.isCons(TRefTypeParameter(), _) and
789-
not (path0.isEmpty() and result = TRefType()) and
790-
path = path0
791-
or
792-
// adjust for implicit borrow
793-
path0.isCons(TRefTypeParameter(), path)
794-
)
804+
path = path0
805+
or
806+
// adjust for implicit borrow
807+
path0.isCons(TRefTypeParameter(), path)
795808
)
796809
)
797-
else path = path0
810+
)
811+
or
812+
not apos.isBorrowed(_) and
813+
path = path0
798814
)
799815
}
800816

@@ -1387,7 +1403,7 @@ private module Cached {
13871403
predicate receiverHasImplicitDeref(AstNode receiver) {
13881404
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
13891405
apos.getArgumentPosition().isSelf() and
1390-
apos.isBorrowed() and
1406+
apos.isBorrowed(_) and
13911407
receiver = a.getNodeAt(apos) and
13921408
inferType(receiver) = TRefType() and
13931409
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) != TRefType()
@@ -1399,7 +1415,7 @@ private module Cached {
13991415
predicate receiverHasImplicitBorrow(AstNode receiver) {
14001416
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
14011417
apos.getArgumentPosition().isSelf() and
1402-
apos.isBorrowed() and
1418+
apos.isBorrowed(_) and
14031419
receiver = a.getNodeAt(apos) and
14041420
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) = TRefType() and
14051421
inferType(receiver) != TRefType()

rust/ql/test/library-tests/type-inference/dereference.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn explicit_monomorphic_dereference() {
3434

3535
// Dereference with overloaded dereference operator
3636
let a2 = MyIntPointer { value: 34i64 };
37-
let _b2 = *a2; // $ method=MyIntPointer::deref MISSING: type=_b2:i64
37+
let _b2 = *a2; // $ method=MyIntPointer::deref type=_b2:i64
3838

3939
// Call method on explicitly dereferenced value
4040
let a3 = MyIntPointer { value: 34i64 };
@@ -48,11 +48,11 @@ fn explicit_polymorphic_dereference() {
4848

4949
// Explicit dereference with type parameter
5050
let c2 = MySmartPointer { value: 'a' };
51-
let _d2 = *c2; // $ method=MySmartPointer::deref MISSING: type=_d2:char
51+
let _d2 = *c2; // $ method=MySmartPointer::deref type=_d2:char
5252

5353
// Call method on explicitly dereferenced value with type parameter
5454
let c3 = MySmartPointer { value: 34i64 };
55-
let _d3 = (*c3).is_positive(); // $ method=MySmartPointer::deref MISSING: method=is_positive type=_d3:bool
55+
let _d3 = (*c3).is_positive(); // $ method=MySmartPointer::deref method=is_positive type=_d3:bool
5656
}
5757

5858
fn explicit_ref_dereference() {
@@ -76,11 +76,11 @@ fn explicit_box_dereference() {
7676

7777
// Explicit dereference with type parameter
7878
let g2: Box<char> = Box::new('a');
79-
let _h2 = *g2; // $ method=deref MISSING: type=_h2:char
79+
let _h2 = *g2; // $ method=deref type=_h2:char
8080

8181
// Call method on explicitly dereferenced value with type parameter
8282
let g3: Box<i64> = Box::new(34i64);
83-
let _h3 = (*g3).is_positive(); // $ method=deref MISSING: method=is_positive type=_h3:bool
83+
let _h3 = (*g3).is_positive(); // $ method=deref method=is_positive type=_h3:bool
8484
}
8585

8686
fn implicit_dereference() {

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,7 @@ mod method_call_type_conversion {
11401140
println!("{:?}", x5.m1()); // $ method=m1
11411141
println!("{:?}", x5.0); // $ fieldof=S
11421142

1143-
let x6 = &S(S2); // $ SPURIOUS: type=x6:&T.&T.S
1143+
let x6 = &S(S2);
11441144

11451145
// explicit dereference
11461146
println!("{:?}", (*x6).m1()); // $ method=m1 method=deref
@@ -1717,7 +1717,7 @@ mod overloadable_operators {
17171717

17181718
// Here the type of `default_vec2` must be inferred from the `==` call
17191719
// and the type of the borrowed second argument is unknown at the call.
1720-
let default_vec2 = Default::default(); // $ MISSING: type=default_vec2:Vec2
1720+
let default_vec2 = Default::default(); // $ type=default_vec2:Vec2
17211721
let vec2_zero_plus = Vec2 { x: 0, y: 0 } == default_vec2; // $ method=Vec2::eq
17221722
}
17231723
}

0 commit comments

Comments
 (0)