Skip to content

Commit

Permalink
Instantiate closure-like bounds with placeholders to deal with binder…
Browse files Browse the repository at this point in the history
…s correctly
  • Loading branch information
compiler-errors committed Mar 26, 2024
1 parent 519d892 commit 24482c8
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 82 deletions.
150 changes: 82 additions & 68 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -678,17 +678,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
fn_host_effect: ty::Const<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
debug!(?obligation, "confirm_fn_pointer_candidate");
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let tcx = self.tcx();

let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
// FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
// but we do not currently. Luckily, such a bound is not
// particularly useful, so we don't expect users to write
// them often.
return Err(SelectionError::Unimplemented);
};

let sig = self_ty.fn_sig(tcx);
let trait_ref = closure_trait_ref_and_return_type(
tcx,
Expand All @@ -700,7 +693,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
)
.map_bound(|(trait_ref, _)| trait_ref);

let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
let mut nested = self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
trait_ref,
)?;
let cause = obligation.derived_cause(BuiltinDerivedObligation);

// Confirm the `type Output: Sized;` bound that is present on `FnOnce`
Expand Down Expand Up @@ -748,10 +746,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -760,23 +756,19 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

let coroutine_sig = args.as_coroutine().sig();

// NOTE: The self-type is a coroutine type and hence is
// in fact unparameterized (or at least does not reference any
// regions bound in the obligation).
let self_ty = obligation
.predicate
.self_ty()
.no_bound_vars()
.expect("unboxed closure type should not capture bound vars from the predicate");

let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
self_ty,
coroutine_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "coroutine candidate obligations");

Ok(nested)
Expand All @@ -786,10 +778,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -801,11 +791,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
self_ty,
coroutine_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "future candidate obligations");

Ok(nested)
Expand All @@ -815,10 +810,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -830,11 +823,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
self_ty,
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
Expand All @@ -844,10 +842,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -859,11 +855,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
self_ty,
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
Expand All @@ -874,14 +875,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on closure types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let trait_ref = match *self_ty.kind() {
ty::Closure(_, args) => {
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
}
ty::Closure(..) => self.closure_trait_ref_unnormalized(
self_ty,
obligation.predicate.def_id(),
self.tcx().consts.true_,
),
ty::CoroutineClosure(_, args) => {
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
ty::TraitRef::new(
Expand All @@ -896,16 +898,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
};

self.confirm_poly_trait_refs(obligation, trait_ref)
self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
trait_ref,
)
}

#[instrument(skip(self), level = "debug")]
fn confirm_async_closure_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let tcx = self.tcx();
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());

let mut nested = vec![];
let (trait_ref, kind_ty) = match *self_ty.kind() {
Expand Down Expand Up @@ -972,7 +981,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
_ => bug!("expected callable type for AsyncFn candidate"),
};

nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
nested.extend(self.equate_trait_refs(
&obligation.cause,
obligation.param_env,
placeholder_predicate.trait_ref,
trait_ref,
)?);

let goal_kind =
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
Expand Down Expand Up @@ -1025,42 +1039,42 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
/// selection of the impl. Therefore, if there is a mismatch, we
/// report an error to the user.
#[instrument(skip(self), level = "trace")]
fn confirm_poly_trait_refs(
fn equate_trait_refs(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
cause: &ObligationCause<'tcx>,
param_env: ty::ParamEnv<'tcx>,
obligation_trait_ref: ty::TraitRef<'tcx>,
found_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let obligation_trait_ref =
self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
cause.span,
HigherRankedType,
self_ty_trait_ref,
found_trait_ref,
);
// Normalize the obligation and expected trait refs together, because why not
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
ensure_sufficient_stack(|| {
normalize_with_depth(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
(obligation_trait_ref, self_ty_trait_ref),
param_env,
cause.clone(),
0,
(obligation_trait_ref, found_trait_ref),
)
});

// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
self.infcx
.at(&obligation.cause, obligation.param_env)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.at(&cause, param_env)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
.map(|InferOk { mut obligations, .. }| {
obligations.extend(nested);
obligations
})
.map_err(|terr| {
SignatureMismatch(Box::new(SignatureMismatchData {
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
found_trait_ref: ty::Binder::dummy(found_trait_ref),
terr,
}))
})
Expand Down
20 changes: 6 additions & 14 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2679,26 +2679,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
#[instrument(skip(self), level = "debug")]
fn closure_trait_ref_unnormalized(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
args: GenericArgsRef<'tcx>,
self_ty: Ty<'tcx>,
fn_trait_def_id: DefId,
fn_host_effect: ty::Const<'tcx>,
) -> ty::PolyTraitRef<'tcx> {
let ty::Closure(_, args) = *self_ty.kind() else {
bug!("expected closure, found {self_ty}");
};
let closure_sig = args.as_closure().sig();

debug!(?closure_sig);

// NOTE: The self-type is an unboxed closure type and hence is
// in fact unparameterized (or at least does not reference any
// regions bound in the obligation).
let self_ty = obligation
.predicate
.self_ty()
.no_bound_vars()
.expect("unboxed closure type should not capture bound vars from the predicate");

closure_trait_ref_and_return_type(
self.tcx(),
obligation.predicate.def_id(),
fn_trait_def_id,
self_ty,
closure_sig,
util::TupleArgumentsFlag::No,
Expand Down
51 changes: 51 additions & 0 deletions tests/ui/higher-ranked/builtin-closure-like-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//@ edition:2024
//@ compile-flags: -Zunstable-options
//@ revisions: current next
//@[next] compile-flags: -Znext-solver

#![feature(unboxed_closures, gen_blocks)]

trait Dispatch {
fn dispatch(self);
}

struct Fut<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Fut<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Future,
{
fn dispatch(self) {
(self.0)(&());
}
}

struct Gen<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Gen<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Iterator,
{
fn dispatch(self) {
(self.0)(&());
}
}

struct Closure<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Closure<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Fn<(&'a (),)>,
{
fn dispatch(self) {
(self.0)(&())(&());
}
}

fn main() {
async fn foo(_: &()) {}
Fut(foo).dispatch();

gen fn bar(_: &()) {}
Gen(bar).dispatch();

fn uwu<'a>(x: &'a ()) -> impl Fn(&'a ()) { |_| {} }
Closure(uwu).dispatch();
}

0 comments on commit 24482c8

Please sign in to comment.