Skip to content

Commit 52a26ad

Browse files
committed
Rust: Rework call disambiguation logic
1 parent 4818311 commit 52a26ad

File tree

12 files changed

+443
-225
lines changed

12 files changed

+443
-225
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll

Lines changed: 115 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,32 +77,23 @@ pragma[nomagic]
7777
private predicate implHasSibling(ImplItemNode impl, Trait trait) { implSiblings(trait, impl, _) }
7878

7979
/**
80-
* Holds if type parameter `tp` of `trait` occurs in the function `f` with the name
81-
* `functionName` at position `pos` and path `path`.
82-
*
83-
* Note that `pos` can also be the special `return` position, which is sometimes
84-
* needed to disambiguate associated function calls like `Default::default()`
85-
* (in this case, `tp` is the special `Self` type parameter).
80+
* Holds if `f` is a function declared inside `trait`, and the type of `f` at
81+
* `pos` and `path` is `traitTp`, which is a type parameter of `trait`.
8682
*/
87-
bindingset[trait]
88-
pragma[inline_late]
83+
pragma[nomagic]
8984
predicate traitTypeParameterOccurrence(
9085
TraitItemNode trait, Function f, string functionName, FunctionPosition pos, TypePath path,
91-
TypeParameter tp
86+
TypeParameter traitTp
9287
) {
93-
f = trait.getASuccessor(functionName) and
94-
tp = getAssocFunctionTypeAt(f, trait, pos, path) and
95-
tp = trait.(TraitTypeAbstraction).getATypeParameter()
88+
f = trait.getAssocItem(functionName) and
89+
traitTp = getAssocFunctionTypeInclNonMethodSelfAt(f, trait, pos, path) and
90+
traitTp = trait.(TraitTypeAbstraction).getATypeParameter()
9691
}
9792

98-
/**
99-
* Holds if resolving the function `f` in `impl` with the name `functionName`
100-
* requires inspecting the type of applied _arguments_ at position `pos` in
101-
* order to determine whether it is the correct resolution.
102-
*/
10393
pragma[nomagic]
104-
predicate functionResolutionDependsOnArgument(
105-
ImplItemNode impl, Function f, FunctionPosition pos, TypePath path, Type type
94+
private predicate functionResolutionDependsOnArgumentCand(
95+
ImplItemNode impl, Function f, string functionName, TypeParameter traitTp, FunctionPosition pos,
96+
TypePath path
10697
) {
10798
/*
10899
* As seen in the example below, when an implementation has a sibling for a
@@ -129,11 +120,113 @@ predicate functionResolutionDependsOnArgument(
129120
* method. In that case we will still resolve several methods.
130121
*/
131122

132-
exists(TraitItemNode trait, string functionName |
123+
exists(TraitItemNode trait |
133124
implHasSibling(impl, trait) and
134-
traitTypeParameterOccurrence(trait, _, functionName, pos, path, _) and
135-
type = getAssocFunctionTypeAt(f, impl, pos, path) and
125+
traitTypeParameterOccurrence(trait, _, functionName, pos, path, traitTp) and
136126
f = impl.getASuccessor(functionName) and
127+
not pos.isSelf()
128+
)
129+
}
130+
131+
private predicate functionResolutionDependsOnPositionalArgumentCand(
132+
ImplItemNode impl, Function f, string functionName, TypeParameter traitTp
133+
) {
134+
exists(FunctionPosition pos |
135+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, _) and
136+
pos.isPosition()
137+
)
138+
}
139+
140+
pragma[nomagic]
141+
private Type getAssocFunctionNonTypeParameterTypeAt(
142+
ImplItemNode impl, Function f, FunctionPosition pos, TypePath path
143+
) {
144+
result = getAssocFunctionTypeInclNonMethodSelfAt(f, impl, pos, path) and
145+
not result instanceof TypeParameter
146+
}
147+
148+
/**
149+
* Holds if `f` inside `impl` has a sibling implementation inside `sibling`, where
150+
* those two implementations agree on the instantiation of `traitTp`, which occurs
151+
* in a positional position inside `f`.
152+
*/
153+
pragma[nomagic]
154+
private predicate hasEquivalentPositionalSibling(
155+
ImplItemNode impl, ImplItemNode sibling, Function f, TypeParameter traitTp
156+
) {
157+
exists(string functionName, FunctionPosition pos, TypePath path |
158+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, path) and
137159
pos.isPosition()
160+
|
161+
exists(Function f1 |
162+
implSiblings(_, impl, sibling) and
163+
f1 = sibling.getASuccessor(functionName)
164+
|
165+
forall(TypePath path0, Type t |
166+
t = getAssocFunctionNonTypeParameterTypeAt(impl, f, pos, path0) and
167+
path = path0.getAPrefixOrSelf()
168+
|
169+
t = getAssocFunctionNonTypeParameterTypeAt(sibling, f1, pos, path0)
170+
) and
171+
forall(TypePath path0, Type t |
172+
t = getAssocFunctionNonTypeParameterTypeAt(sibling, f1, pos, path0) and
173+
path = path0.getAPrefixOrSelf()
174+
|
175+
t = getAssocFunctionNonTypeParameterTypeAt(impl, f, pos, path0)
176+
)
177+
)
178+
)
179+
}
180+
181+
/**
182+
* Holds if resolving the function `f` in `impl` requires inspecting the type
183+
* of applied _arguments_ or possibly knowing the return type.
184+
*
185+
* `traitTp` is a type parameter of the trait being implemented by `impl`, and
186+
* we need to check that the type of `f` corresponding to `traitTp` is satisfied
187+
* at any one of the positions `pos` in which that type occurs in `f`.
188+
*
189+
* Type parameters that only occur in return positions are only included when
190+
* all other type parameters that occur in a positional position are insufficient
191+
* to disambiguate.
192+
*
193+
* Example:
194+
*
195+
* ```rust
196+
* trait Trait1<T1> {
197+
* fn f(self, x: T1) -> T1;
198+
* }
199+
*
200+
* impl Trait1<i32> for i32 {
201+
* fn f(self, x: i32) -> i32 { 0 } // f1
202+
* }
203+
*
204+
* impl Trait1<i64> for i32 {
205+
* fn f(self, x: i64) -> i64 { 0 } // f2
206+
* }
207+
* ```
208+
*
209+
* The type for `T1` above occurs in both a positional position and a return position
210+
* in `f`, so both may be used to disambiguate between `f1` and `f2`. That is, `f(0i32)`
211+
* is sufficient to resolve to `f1`, and so is `let y: i64 = f(Default::default())`.
212+
*/
213+
pragma[nomagic]
214+
predicate functionResolutionDependsOnArgument(
215+
ImplItemNode impl, Function f, TypeParameter traitTp, FunctionPosition pos
216+
) {
217+
exists(string functionName |
218+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, _)
219+
|
220+
if functionResolutionDependsOnPositionalArgumentCand(impl, f, functionName, traitTp)
221+
then any()
222+
else
223+
exists(ImplItemNode sibling |
224+
implSiblings(_, impl, sibling) and
225+
forall(TypeParameter otherTraitTp |
226+
functionResolutionDependsOnPositionalArgumentCand(impl, f, functionName, otherTraitTp)
227+
|
228+
hasEquivalentPositionalSibling(impl, sibling, f, otherTraitTp)
229+
)
230+
)
138231
)
139232
}

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 91 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ private newtype TAssocFunctionType =
8282
// through `i`. This ensures that `parent` is either a supertrait of `i` or
8383
// `i` in an `impl` block implementing `parent`.
8484
(parent = i or BaseTypes::rootTypesSatisfaction(_, TTrait(parent), i, _, _)) and
85-
exists(pos.getTypeMention(f))
85+
// We always include the `self` position, even for non-methods, where it is used
86+
// to match type qualifiers against the `impl` or trait type, such as in `Vec::new`.
87+
(exists(pos.getTypeMention(f)) or pos.isSelf())
8688
}
8789

8890
bindingset[abs, constraint, tp]
@@ -116,6 +118,22 @@ Type getAssocFunctionTypeAt(Function f, ImplOrTraitItemNode i, FunctionPosition
116118
)
117119
}
118120

121+
/**
122+
* Same as `getAssocFunctionTypeAt`, but also includes types at the `self` position
123+
* for non-methods.
124+
*/
125+
pragma[nomagic]
126+
Type getAssocFunctionTypeInclNonMethodSelfAt(
127+
Function f, ImplOrTraitItemNode i, FunctionPosition pos, TypePath path
128+
) {
129+
result = getAssocFunctionTypeAt(f, i, pos, path)
130+
or
131+
f = i.getASuccessor(_) and
132+
not f.hasSelfParam() and
133+
pos.isSelf() and
134+
result = resolveImplOrTraitType(i, path)
135+
}
136+
119137
/**
120138
* The type of an associated function at a given position, when its implicit
121139
* `Self` type parameter is specialized to a given trait or `impl` block.
@@ -174,7 +192,7 @@ class AssocFunctionType extends MkAssocFunctionType {
174192
Type getTypeAt(TypePath path) {
175193
exists(Function f, FunctionPosition pos, ImplOrTraitItemNode i, Type t |
176194
this.appliesTo(f, i, pos) and
177-
t = getAssocFunctionTypeAt(f, i, pos, path)
195+
t = getAssocFunctionTypeInclNonMethodSelfAt(f, i, pos, path)
178196
|
179197
not t instanceof SelfTypeParameter and
180198
result = t
@@ -184,9 +202,12 @@ class AssocFunctionType extends MkAssocFunctionType {
184202
}
185203

186204
private TypeMention getTypeMention() {
187-
exists(Function f, FunctionPosition pos |
188-
this.appliesTo(f, _, pos) and
205+
exists(Function f, ImplOrTraitItemNode i, FunctionPosition pos | this.appliesTo(f, i, pos) |
189206
result = pos.getTypeMention(f)
207+
or
208+
pos.isSelf() and
209+
not f.hasSelfParam() and
210+
result = [i.(Impl).getSelfTy().(TypeMention), i.(TypeMention)]
190211
)
191212
}
192213

@@ -294,10 +315,13 @@ module ArgIsInstantiationOf<
294315
*/
295316
signature module ArgsAreInstantiationsOfInputSig {
296317
/**
297-
* Holds if types need to be matched against the type `t` at position `pos` of
298-
* `f` inside `i`.
318+
* Holds if `f` implements (or is itself) a trait function with type parameter
319+
* `traitTp`, where we need to check that the type of `f` for `traitTp` is
320+
* satisfied.
321+
*
322+
* `pos` is one of the positions in `f` in which the relevant type occours.
299323
*/
300-
predicate toCheck(ImplOrTraitItemNode i, Function f, FunctionPosition pos, AssocFunctionType t);
324+
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos);
301325

302326
/** A call whose argument types are to be checked. */
303327
class Call {
@@ -318,23 +342,28 @@ signature module ArgsAreInstantiationsOfInputSig {
318342
*/
319343
module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
320344
pragma[nomagic]
321-
private predicate toCheckRanked(ImplOrTraitItemNode i, Function f, FunctionPosition pos, int rnk) {
322-
Input::toCheck(i, f, pos, _) and
323-
pos =
324-
rank[rnk + 1](FunctionPosition pos0, int j |
325-
Input::toCheck(i, f, pos0, _) and
326-
(
327-
j = pos0.asPosition()
328-
or
329-
pos0.isSelf() and j = -1
330-
or
331-
pos0.isReturn() and j = -2
332-
)
345+
private predicate toCheckRanked(
346+
ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos, int rnk
347+
) {
348+
Input::toCheck(i, f, traitTp, pos) and
349+
traitTp =
350+
rank[rnk + 1](TypeParameter traitTp0, int j |
351+
Input::toCheck(i, f, traitTp0, _) and
352+
j = getTypeParameterId(traitTp0)
333353
|
334-
pos0 order by j
354+
traitTp0 order by j
335355
)
336356
}
337357

358+
pragma[nomagic]
359+
private predicate toCheck(
360+
ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos,
361+
AssocFunctionType t
362+
) {
363+
Input::toCheck(i, f, traitTp, pos) and
364+
t.appliesTo(f, i, pos)
365+
}
366+
338367
private newtype TCallAndPos =
339368
MkCallAndPos(Input::Call call, FunctionPosition pos) { exists(call.getArgType(pos, _)) }
340369

@@ -356,36 +385,34 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
356385
string toString() { result = call.toString() + " [arg " + pos + "]" }
357386
}
358387

388+
pragma[nomagic]
389+
private predicate potentialInstantiationOf0(
390+
CallAndPos cp, Input::Call call, TypeParameter traitTp, FunctionPosition pos, Function f,
391+
TypeAbstraction abs, AssocFunctionType constraint
392+
) {
393+
cp = MkCallAndPos(call, pragma[only_bind_into](pos)) and
394+
call.hasTargetCand(abs, f) and
395+
toCheck(abs, f, traitTp, pragma[only_bind_into](pos), constraint)
396+
}
397+
359398
private module ArgIsInstantiationOfToIndexInput implements
360399
IsInstantiationOfInputSig<CallAndPos, AssocFunctionType>
361400
{
362-
pragma[nomagic]
363-
private predicate potentialInstantiationOf0(
364-
CallAndPos cp, Input::Call call, FunctionPosition pos, int rnk, Function f,
365-
TypeAbstraction abs, AssocFunctionType constraint
366-
) {
367-
cp = MkCallAndPos(call, pragma[only_bind_into](pos)) and
368-
call.hasTargetCand(abs, f) and
369-
toCheckRanked(abs, f, pragma[only_bind_into](pos), rnk) and
370-
Input::toCheck(abs, f, pragma[only_bind_into](pos), constraint)
371-
}
372-
373401
pragma[nomagic]
374402
predicate potentialInstantiationOf(
375403
CallAndPos cp, TypeAbstraction abs, AssocFunctionType constraint
376404
) {
377-
exists(Input::Call call, int rnk, Function f |
378-
potentialInstantiationOf0(cp, call, _, rnk, f, abs, constraint)
405+
exists(Input::Call call, TypeParameter traitTp, FunctionPosition pos, int rnk, Function f |
406+
potentialInstantiationOf0(cp, call, traitTp, pos, f, abs, constraint) and
407+
toCheckRanked(abs, f, traitTp, pos, rnk)
379408
|
380409
rnk = 0
381410
or
382411
argsAreInstantiationsOfToIndex(call, abs, f, rnk - 1)
383412
)
384413
}
385414

386-
predicate relevantConstraint(AssocFunctionType constraint) {
387-
Input::toCheck(_, _, _, constraint)
388-
}
415+
predicate relevantConstraint(AssocFunctionType constraint) { toCheck(_, _, _, _, constraint) }
389416
}
390417

391418
private module ArgIsInstantiationOfToIndex =
@@ -398,39 +425,63 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
398425
exists(FunctionPosition pos |
399426
ArgIsInstantiationOfToIndex::argIsInstantiationOf(MkCallAndPos(call, pos), i, _) and
400427
call.hasTargetCand(i, f) and
401-
toCheckRanked(i, f, pos, rnk)
428+
toCheckRanked(i, f, _, pos, rnk)
429+
|
430+
rnk = 0
431+
or
432+
argsAreInstantiationsOfToIndex(call, i, f, rnk - 1)
402433
)
403434
}
404435

405436
/**
406437
* Holds if all arguments of `call` have types that are instantiations of the
407438
* types of the corresponding parameters of `f` inside `i`.
439+
*
440+
* TODO: Check type parameter constraints as well.
408441
*/
409442
pragma[nomagic]
410443
predicate argsAreInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
411444
exists(int rnk |
412445
argsAreInstantiationsOfToIndex(call, i, f, rnk) and
413-
rnk = max(int r | toCheckRanked(i, f, _, r))
446+
rnk = max(int r | toCheckRanked(i, f, _, _, r))
414447
)
415448
}
416449

450+
private module ArgsAreNotInstantiationOfInput implements
451+
IsInstantiationOfInputSig<CallAndPos, AssocFunctionType>
452+
{
453+
pragma[nomagic]
454+
predicate potentialInstantiationOf(
455+
CallAndPos cp, TypeAbstraction abs, AssocFunctionType constraint
456+
) {
457+
potentialInstantiationOf0(cp, _, _, _, _, abs, constraint)
458+
}
459+
460+
predicate relevantConstraint(AssocFunctionType constraint) { toCheck(_, _, _, _, constraint) }
461+
}
462+
463+
private module ArgsAreNotInstantiationOf =
464+
ArgIsInstantiationOf<CallAndPos, ArgsAreNotInstantiationOfInput>;
465+
417466
pragma[nomagic]
418467
private predicate argsAreNotInstantiationsOf0(
419468
Input::Call call, FunctionPosition pos, ImplOrTraitItemNode i
420469
) {
421-
ArgIsInstantiationOfToIndex::argIsNotInstantiationOf(MkCallAndPos(call, pos), i, _, _)
470+
ArgsAreNotInstantiationOf::argIsNotInstantiationOf(MkCallAndPos(call, pos), i, _, _)
422471
}
423472

424473
/**
425474
* Holds if _some_ argument of `call` has a type that is not an instantiation of the
426475
* type of the corresponding parameter of `f` inside `i`.
476+
*
477+
* TODO: Check type parameter constraints as well.
427478
*/
428479
pragma[nomagic]
429480
predicate argsAreNotInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
430481
exists(FunctionPosition pos |
431482
argsAreNotInstantiationsOf0(call, pos, i) and
432483
call.hasTargetCand(i, f) and
433-
Input::toCheck(i, f, pos, _)
484+
Input::toCheck(i, f, _, pos)
434485
)
435486
}
436487
}

0 commit comments

Comments
 (0)