Skip to content

Commit

Permalink
Replacing Self in generic arguements of fns with associated types. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Jan 14, 2025
1 parent 3fea88e commit d351a54
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 30 deletions.
16 changes: 8 additions & 8 deletions crates/cairo-lang-semantic/src/expr/inference/infers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub trait InferenceEmbeddings {
concrete_trait_function: ConcreteTraitGenericFunctionId,
lookup_context: &ImplLookupContext,
stable_ptr: Option<SyntaxStablePtrId>,
) -> GenericFunctionId;
) -> ImplGenericFunctionId;
fn infer_trait_type(
&mut self,
concrete_trait_type: ConcreteTraitTypeId,
Expand Down Expand Up @@ -402,8 +402,11 @@ impl InferenceEmbeddings for Inference<'_> {
lookup_context: &ImplLookupContext,
stable_ptr: Option<SyntaxStablePtrId>,
) -> InferenceResult<FunctionId> {
let generic_function =
self.infer_trait_generic_function(concrete_trait_function, lookup_context, stable_ptr);
let generic_function = GenericFunctionId::Impl(self.infer_trait_generic_function(
concrete_trait_function,
lookup_context,
stable_ptr,
));
self.infer_generic_function(generic_function, lookup_context, stable_ptr)
}

Expand All @@ -430,16 +433,13 @@ impl InferenceEmbeddings for Inference<'_> {
concrete_trait_function: ConcreteTraitGenericFunctionId,
lookup_context: &ImplLookupContext,
stable_ptr: Option<SyntaxStablePtrId>,
) -> GenericFunctionId {
) -> ImplGenericFunctionId {
let impl_id = self.new_impl_var(
concrete_trait_function.concrete_trait(self.db),
stable_ptr,
lookup_context.clone(),
);
GenericFunctionId::Impl(ImplGenericFunctionId {
impl_id,
function: concrete_trait_function.trait_function(self.db),
})
ImplGenericFunctionId { impl_id, function: concrete_trait_function.trait_function(self.db) }
}

/// Infers the impl to be substituted instead of a trait for a given trait type.
Expand Down
22 changes: 14 additions & 8 deletions crates/cairo-lang-semantic/src/items/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1942,12 +1942,6 @@ pub fn infer_impl_by_self(
.intern(ctx.db);
let trait_func_generic_params =
ctx.db.concrete_trait_function_generic_params(concrete_trait_function_id).unwrap();
let generic_args = ctx.resolver.resolve_generic_args(
ctx.diagnostics,
&trait_func_generic_params,
&generic_args_syntax.unwrap_or_default(),
stable_ptr,
)?;

let impl_lookup_context = ctx.resolver.impl_lookup_context();
let inference = &mut ctx.resolver.inference();
Expand All @@ -1956,10 +1950,22 @@ pub fn infer_impl_by_self(
&impl_lookup_context,
Some(stable_ptr),
);
let generic_args = ctx.resolver.resolve_generic_args(
ctx.diagnostics,
GenericSubstitution::from_impl(generic_function.impl_id),
&trait_func_generic_params,
&generic_args_syntax.unwrap_or_default(),
stable_ptr,
)?;

Ok((
FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
.intern(ctx.db),
FunctionLongId {
function: ConcreteFunction {
generic_function: GenericFunctionId::Impl(generic_function),
generic_args,
},
}
.intern(ctx.db),
n_snapshots,
))
}
Expand Down
58 changes: 44 additions & 14 deletions crates/cairo-lang-semantic/src/resolve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -869,11 +869,12 @@ impl<'db> Resolver<'db> {
}
}
let impl_lookup_context = self.impl_lookup_context();
let generic_function = self.inference().infer_trait_generic_function(
concrete_trait_function,
&impl_lookup_context,
Some(identifier_stable_ptr),
);
let generic_function =
GenericFunctionId::Impl(self.inference().infer_trait_generic_function(
concrete_trait_function,
&impl_lookup_context,
Some(identifier_stable_ptr),
));

Ok(ResolvedConcreteItem::Function(self.specialize_function(
diagnostics,
Expand Down Expand Up @@ -1076,6 +1077,7 @@ impl<'db> Resolver<'db> {
self.db.module_type_alias_generic_params(module_type_alias_id)?;
let generic_args = self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&generic_params,
&generic_args_syntax.unwrap_or_default(),
identifier.stable_ptr().untyped(),
Expand All @@ -1090,6 +1092,7 @@ impl<'db> Resolver<'db> {
let generic_params = self.db.impl_alias_generic_params(impl_alias_id)?;
let generic_args = self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&generic_params,
&generic_args_syntax.unwrap_or_default(),
identifier.stable_ptr().untyped(),
Expand Down Expand Up @@ -1364,8 +1367,13 @@ impl<'db> Resolver<'db> {
.db
.trait_generic_params(trait_id)
.map_err(|_| diagnostics.report(stable_ptr, UnknownTrait))?;
let generic_args =
self.resolve_generic_args(diagnostics, &generic_params, generic_args, stable_ptr)?;
let generic_args = self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&generic_params,
generic_args,
stable_ptr,
)?;

Ok(ConcreteTraitLongId { trait_id, generic_args }.intern(self.db))
}
Expand All @@ -1383,8 +1391,13 @@ impl<'db> Resolver<'db> {
.db
.impl_def_generic_params(impl_def_id)
.map_err(|_| diagnostics.report(stable_ptr, UnknownImpl))?;
let generic_args =
self.resolve_generic_args(diagnostics, &generic_params, generic_args, stable_ptr)?;
let generic_args = self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&generic_params,
generic_args,
stable_ptr,
)?;

Ok(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
}
Expand All @@ -1401,6 +1414,7 @@ impl<'db> Resolver<'db> {
enum_id: variant_id.enum_id(self.db.upcast()),
generic_args: self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&self.db.enum_generic_params(variant_id.enum_id(self.db.upcast()))?,
generic_args,
stable_ptr,
Expand All @@ -1423,8 +1437,18 @@ impl<'db> Resolver<'db> {
) -> Maybe<FunctionId> {
// TODO(lior): Should we report diagnostic if `impl_def_generic_params` failed?
let generic_params: Vec<_> = generic_function.generic_params(self.db)?;
let generic_args =
self.resolve_generic_args(diagnostics, &generic_params, generic_args, stable_ptr)?;
let substitution = if let GenericFunctionId::Impl(id) = generic_function {
GenericSubstitution::from_impl(id.impl_id)
} else {
GenericSubstitution::default()
};
let generic_args = self.resolve_generic_args(
diagnostics,
substitution,
&generic_params,
generic_args,
stable_ptr,
)?;

Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
.intern(self.db))
Expand All @@ -1442,8 +1466,13 @@ impl<'db> Resolver<'db> {
.db
.generic_type_generic_params(generic_type)
.map_err(|_| diagnostics.report(stable_ptr, UnknownType))?;
let generic_args =
self.resolve_generic_args(diagnostics, &generic_params, generic_args, stable_ptr)?;
let generic_args = self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&generic_params,
generic_args,
stable_ptr,
)?;

Ok(TypeLongId::Concrete(ConcreteTypeId::new(self.db, generic_type, generic_args))
.intern(self.db))
Expand Down Expand Up @@ -1471,11 +1500,11 @@ impl<'db> Resolver<'db> {
pub fn resolve_generic_args(
&mut self,
diagnostics: &mut SemanticDiagnostics,
mut substitution: GenericSubstitution,
generic_params: &[GenericParam],
generic_args_syntax: &[ast::GenericArg],
stable_ptr: SyntaxStablePtrId,
) -> Maybe<Vec<GenericArgumentId>> {
let mut substitution = GenericSubstitution::default();
let mut resolved_args = vec![];
let arg_syntax_per_param =
self.get_arg_syntax_per_param(diagnostics, generic_params, generic_args_syntax)?;
Expand Down Expand Up @@ -1901,6 +1930,7 @@ impl<'db> Resolver<'db> {
}
let resolved_args = self.resolve_generic_args(
diagnostics,
GenericSubstitution::default(),
&generic_params,
current_segment_generic_args,
segment_stable_ptr,
Expand Down
16 changes: 16 additions & 0 deletions tests/bug_samples/issue7060.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
pub trait WithAssociated<T> {
type S;
fn foo<+Drop<T>, +Drop<Self::S>>(self: T, s: Self::S) -> Self::S {
s
}
}

impl Impl<T> of WithAssociated<T> {
type S = u8;
}

#[test]
fn test_associated_type_usage() {
WithAssociated::foo(4, 5_u8);
4.foo(5_u8);
}
1 change: 1 addition & 0 deletions tests/bug_samples/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod issue6920;
mod issue6968;
mod issue7031;
mod issue7038;
mod issue7060;
mod loop_break_in_match;
mod loop_only_change;
mod partial_param_local;
Expand Down

0 comments on commit d351a54

Please sign in to comment.