Skip to content

Rust: Apply inherent method prioritization inside type inference loop #19903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 51 additions & 67 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}

Declaration getTarget() {
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
or
result = CallExprImpl::getResolvedFunction(this)
}
Expand Down Expand Up @@ -1178,14 +1178,14 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
methodCandidate(type, name, arity, impl)
}

private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getNumberOfArguments()
}
pragma[nomagic]
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getNumberOfArguments()
}

private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
exists(Type rootType, string name, int arity |
Expand Down Expand Up @@ -1334,17 +1334,46 @@ private predicate methodResolutionDependsOnArgument(
)
}

/**
* Holds if the method call `mc` has no inherent target, i.e., it does not
* resolve to a method in an `impl` block for the type of the receiver.
*/
pragma[nomagic]
private predicate methodCallHasNoInherentTarget(MethodCall mc) {
exists(Type rootType, string name, int arity |
isMethodCall(mc, rootType, name, arity) and
forall(Impl impl |
methodCandidate(rootType, name, arity, impl) and
not impl.hasTrait()
|
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isNotInstantiationOf(mc, impl, _)
)
)
}

pragma[nomagic]
private predicate methodCallHasImplCandidate(MethodCall mc, Impl impl) {
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
if impl.hasTrait() and not exists(mc.getTrait())
then
// inherent methods take precedence over trait methods, so only allow
// trait methods when there are no matching inherent methods
methodCallHasNoInherentTarget(mc)
else any()
}

/** Gets a method from an `impl` block that matches the method call `mc`. */
pragma[nomagic]
private Function getMethodFromImpl(MethodCall mc) {
exists(Impl impl |
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
result = getMethodSuccessor(impl, mc.getMethodName())
exists(Impl impl, string name |
methodCallHasImplCandidate(mc, impl) and
name = mc.getMethodName() and
result = getMethodSuccessor(impl, name)
|
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
result = getMethodSuccessor(impl, mc.getMethodName())
not methodResolutionDependsOnArgument(impl, _, _, _, _, _)
or
exists(int pos, TypePath path, Type type |
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
inferType(mc.getPositionalArgument(pos), path) = type
)
)
Expand All @@ -1356,22 +1385,6 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
}

/**
* Gets a method that the method call `mc` resolves to based on type inference,
* if any.
*/
private Function inferMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

cached
private module Cached {
private import codeql.rust.internal.CachedStages
Expand Down Expand Up @@ -1400,47 +1413,18 @@ private module Cached {
)
}

private predicate isInherentImplFunction(Function f) {
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

private predicate isTraitImplFunction(Function f) {
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
result = inferMethodCallTarget(mc) and
(if result.fromSource() then fromSource = true else fromSource = false) and
(
// prioritize inherent implementation methods first
isInherentImplFunction(result)
or
not isInherentImplFunction(inferMethodCallTarget(mc)) and
(
// then trait implementation methods
isTraitImplFunction(result)
or
not isTraitImplFunction(inferMethodCallTarget(mc)) and
(
// then trait methods with default implementations
result.hasBody()
or
// and finally trait methods without default implementations
not inferMethodCallTarget(mc).hasBody()
)
)
)
}

/** Gets a method that the method call `mc` resolves to, if any. */
cached
Function resolveMethodCallTarget(MethodCall mc) {
// Functions in source code also gets extracted as library code, due to
// this duplication we prioritize functions from source code.
result = resolveMethodCallTargetFrom(mc, true)
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
not exists(resolveMethodCallTargetFrom(mc, true)) and
result = resolveMethodCallTargetFrom(mc, false)
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

pragma[inline]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
multipleCallTargets
| main.rs:362:14:362:30 | ... .lt(...) |
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ multipleCallTargets
| test.rs:168:26:168:111 | ...::_print(...) |
| test.rs:178:30:178:68 | ...::_print(...) |
| test.rs:187:26:187:105 | ...::_print(...) |
| test.rs:228:22:228:72 | ... .read_to_string(...) |
| test.rs:482:22:482:50 | file.read_to_end(...) |
| test.rs:488:22:488:53 | file.read_to_string(...) |
| test.rs:609:18:609:38 | ...::_print(...) |
| test.rs:614:18:614:45 | ...::_print(...) |
| test.rs:618:25:618:49 | address.to_socket_addrs() |
Expand Down
2 changes: 1 addition & 1 deletion rust/ql/test/library-tests/dataflow/sources/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ fn test_io_stdin() -> std::io::Result<()> {
{
let mut buffer = Vec::<u8>::new();
let _bytes = std::io::stdin().read_to_end(&mut buffer)?; // $ Alert[rust/summary/taint-sources]
sink(&buffer); // $ MISSING: hasTaintFlow
sink(&buffer); // $ hasTaintFlow -- @hvitved: works in CI, but not for me locally
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's weird. I've written an issue to look into this, my suspicion is that when we finish updating all the models and then generalize what we can to trait models, the platform specific behaviours are going to go away.

}

{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:1963:13:1963:31 | ...::from(...) |
| main.rs:1964:13:1964:31 | ...::from(...) |
| main.rs:1965:13:1965:31 | ...::from(...) |
| main.rs:1970:13:1970:31 | ...::from(...) |
| main.rs:1971:13:1971:31 | ...::from(...) |
| main.rs:1972:13:1972:31 | ...::from(...) |
| main.rs:2006:21:2006:43 | ...::from(...) |
| main.rs:2032:13:2032:31 | ...::from(...) |
| main.rs:2033:13:2033:31 | ...::from(...) |
| main.rs:2034:13:2034:31 | ...::from(...) |
| main.rs:2040:13:2040:31 | ...::from(...) |
| main.rs:2041:13:2041:31 | ...::from(...) |
| main.rs:2042:13:2042:31 | ...::from(...) |
| main.rs:2078:21:2078:43 | ...::from(...) |
89 changes: 81 additions & 8 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,12 @@ mod impl_overlap {
impl OverlappingTrait for S1 {
// <S1_as_OverlappingTrait>::common_method
fn common_method(self) -> S1 {
panic!("not called");
S1
}

// <S1_as_OverlappingTrait>::common_method_2
fn common_method_2(self, s1: S1) -> S1 {
panic!("not called");
S1
}
}

Expand All @@ -427,10 +427,78 @@ mod impl_overlap {
}
}

struct S2<T2>(T2);

impl S2<i32> {
// S2<i32>::common_method
fn common_method(self) -> S1 {
S1
}

// S2<i32>::common_method
fn common_method_2(self) -> S1 {
S1
}
}

impl OverlappingTrait for S2<i32> {
// <S2<i32>_as_OverlappingTrait>::common_method
fn common_method(self) -> S1 {
S1
}

// <S2<i32>_as_OverlappingTrait>::common_method_2
fn common_method_2(self, s1: S1) -> S1 {
S1
}
}

impl OverlappingTrait for S2<S1> {
// <S2<S1>_as_OverlappingTrait>::common_method
fn common_method(self) -> S1 {
S1
}

// <S2<S1>_as_OverlappingTrait>::common_method_2
fn common_method_2(self, s1: S1) -> S1 {
S1
}
}

#[derive(Debug)]
struct S3<T3>(T3);

trait OverlappingTrait2<T> {
fn m(&self, x: &T) -> &Self;
}

impl<T> OverlappingTrait2<T> for S3<T> {
// <S3<T>_as_OverlappingTrait2<T>>::m
fn m(&self, x: &T) -> &Self {
self
}
}

impl<T> S3<T> {
// S3<T>::m
fn m(&self, x: T) -> &Self {
self
}
}

pub fn f() {
let x = S1;
println!("{:?}", x.common_method()); // $ method=S1::common_method
println!("{:?}", x.common_method_2()); // $ method=S1::common_method_2

let y = S2(S1);
println!("{:?}", y.common_method()); // $ method=<S2<S1>_as_OverlappingTrait>::common_method

let z = S2(0);
println!("{:?}", z.common_method()); // $ method=S2<i32>::common_method

let w = S3(S1);
println!("{:?}", w.m(x)); // $ method=S3<T>::m
}
}

Expand Down Expand Up @@ -1959,22 +2027,25 @@ mod loops {
for s in &mut strings1 {} // $ MISSING: type=s:&T.str
for s in strings1 {} // $ type=s:str

let strings2 = [ // $ type=strings2:[T;...].String
let strings2 = // $ type=strings2:[T;...].String
[
String::from("foo"),
String::from("bar"),
String::from("baz"),
];
for s in strings2 {} // $ type=s:String

let strings3 = &[ // $ type=strings3:&T.[T;...].String
let strings3 = // $ type=strings3:&T.[T;...].String
&[
String::from("foo"),
String::from("bar"),
String::from("baz"),
];
for s in strings3 {} // $ MISSING: type=s:String

let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ MISSING: type=callables:[T;...].MyCallable; 3
for c in callables // $ type=c:MyCallable
for c // $ type=c:MyCallable
in callables
{
let result = c.call(); // $ type=result:i64 method=call
}
Expand All @@ -1986,7 +2057,8 @@ mod loops {
let range = 0..10; // $ MISSING: type=range:Range type=range:Idx.i32
for i in range {} // $ MISSING: type=i:i32

let range1 = std::ops::Range { // $ type=range1:Range type=range1:Idx.u16
let range1 = // $ type=range1:Range type=range1:Idx.u16
std::ops::Range {
start: 0u16,
end: 10u16,
};
Expand Down Expand Up @@ -2031,10 +2103,11 @@ mod loops {
// while loops

let mut a: i64 = 0; // $ type=a:i64
while a < 10 // $ method=lt type=a:i64
#[rustfmt::skip]
let _ = while a < 10 // $ method=lt type=a:i64
{
a += 1; // $ type=a:i64 method=add_assign
}
};
}
}

Expand Down
Loading
Loading