diff --git a/crates/hir-ty/src/chalk_db.rs b/crates/hir-ty/src/chalk_db.rs index cd799c03ddf7..22b96b55cbb9 100644 --- a/crates/hir-ty/src/chalk_db.rs +++ b/crates/hir-ty/src/chalk_db.rs @@ -259,7 +259,7 @@ impl chalk_solve::RustIrDatabase for ChalkContext<'_> { } fn well_known_trait_id( &self, - well_known_trait: rust_ir::WellKnownTrait, + well_known_trait: WellKnownTrait, ) -> Option> { let lang_attr = lang_item_from_well_known_trait(well_known_trait); let trait_ = lang_attr.resolve_trait(self.db, self.krate)?; diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs index f0989d9de91f..f210dd8799f9 100644 --- a/crates/hir-ty/src/display.rs +++ b/crates/hir-ty/src/display.rs @@ -1463,6 +1463,8 @@ impl HirDisplay for Ty { } if f.closure_style == ClosureStyle::RANotation || !sig.ret().is_unit() { write!(f, " -> ")?; + // FIXME: We display `AsyncFn` as `-> impl Future`, but this is hard to fix because + // we don't have a trait environment here, required to normalize `::Output`. sig.ret().hir_fmt(f)?; } } else { diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 800897c6fc3a..bd57ca891620 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -38,7 +38,7 @@ use crate::{ infer::{BreakableKind, CoerceMany, Diverges, coerce::CoerceNever}, make_binders, mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem}, - to_chalk_trait_id, + to_assoc_type_id, to_chalk_trait_id, traits::FnTrait, utils::{self, elaborate_clause_supertraits}, }; @@ -245,7 +245,7 @@ impl InferenceContext<'_> { } fn deduce_closure_kind_from_predicate_clauses( - &self, + &mut self, expected_ty: &Ty, clauses: impl DoubleEndedIterator, closure_kind: ClosureKind, @@ -378,7 +378,7 @@ impl InferenceContext<'_> { } fn deduce_sig_from_projection( - &self, + &mut self, closure_kind: ClosureKind, projection_ty: &ProjectionTy, projected_ty: &Ty, @@ -392,13 +392,16 @@ impl InferenceContext<'_> { // For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits, // for closures and async closures, respectively. - match closure_kind { - ClosureKind::Closure | ClosureKind::Async - if self.fn_trait_kind_from_trait_id(trait_).is_some() => - { - self.extract_sig_from_projection(projection_ty, projected_ty) - } - _ => None, + let fn_trait_kind = self.fn_trait_kind_from_trait_id(trait_)?; + if !matches!(closure_kind, ClosureKind::Closure | ClosureKind::Async) { + return None; + } + if fn_trait_kind.is_async() { + // If the expected trait is `AsyncFn(...) -> X`, we don't know what the return type is, + // but we do know it must implement `Future`. + self.extract_async_fn_sig_from_projection(projection_ty, projected_ty) + } else { + self.extract_sig_from_projection(projection_ty, projected_ty) } } @@ -424,6 +427,39 @@ impl InferenceContext<'_> { ))) } + fn extract_async_fn_sig_from_projection( + &mut self, + projection_ty: &ProjectionTy, + projected_ty: &Ty, + ) -> Option> { + let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner); + + let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else { + return None; + }; + + let ret_param_future_output = projected_ty; + let ret_param_future = self.table.new_type_var(); + let future_output = + LangItem::FutureOutput.resolve_type_alias(self.db, self.resolver.krate())?; + let future_projection = crate::AliasTy::Projection(crate::ProjectionTy { + associated_ty_id: to_assoc_type_id(future_output), + substitution: Substitution::from1(Interner, ret_param_future.clone()), + }); + self.table.register_obligation( + crate::AliasEq { alias: future_projection, ty: ret_param_future_output.clone() } + .cast(Interner), + ); + + Some(FnSubst(Substitution::from_iter( + Interner, + input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new( + Interner, + chalk_ir::GenericArgData::Ty(ret_param_future), + ))), + ))) + } + fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option { FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?) } diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 2b527a4ae12e..e5d1fbe9defe 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -4903,3 +4903,30 @@ fn main() { "#]], ); } + +#[test] +fn async_fn_return_type() { + check_infer( + r#" +//- minicore: async_fn +fn foo R, R>(_: F) -> R { + loop {} +} + +fn main() { + foo(async move || ()); +} + "#, + expect![[r#" + 29..30 '_': F + 40..55 '{ loop {} }': R + 46..53 'loop {}': ! + 51..53 '{}': () + 67..97 '{ ...()); }': () + 73..76 'foo': fn foo impl Future, ()>(impl AsyncFn() -> impl Future) + 73..94 'foo(as...|| ())': () + 77..93 'async ... || ()': impl AsyncFn() -> impl Future + 91..93 '()': () + "#]], + ); +} diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index f9f8776cff7c..7414b4fc6070 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -291,4 +291,9 @@ impl FnTrait { pub fn get_id(self, db: &dyn HirDatabase, krate: Crate) -> Option { self.lang_item().resolve_trait(db, krate) } + + #[inline] + pub(crate) fn is_async(self) -> bool { + matches!(self, FnTrait::AsyncFn | FnTrait::AsyncFnMut | FnTrait::AsyncFnOnce) + } }