From 3520970228857eb307345d38e4ebd33595b6ba3b Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 28 Nov 2024 12:15:27 -0300 Subject: [PATCH] WIP: type-check trait default methods --- compiler/noirc_frontend/src/ast/expression.rs | 4 +- compiler/noirc_frontend/src/elaborator/mod.rs | 6 ++ .../noirc_frontend/src/elaborator/traits.rs | 37 ++++++---- .../noirc_frontend/src/elaborator/types.rs | 9 ++- .../src/hir/def_collector/dc_mod.rs | 2 +- .../noirc_frontend/src/hir_def/function.rs | 4 +- compiler/noirc_frontend/src/tests.rs | 2 +- compiler/noirc_frontend/src/tests/traits.rs | 67 ++++++++++++++++++- 8 files changed, 107 insertions(+), 24 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 2c8a9b6508d..b435675adcd 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -815,7 +815,7 @@ impl FunctionDefinition { is_unconstrained: bool, generics: &UnresolvedGenerics, parameters: &[(Ident, UnresolvedType)], - body: &BlockExpression, + body: BlockExpression, where_clause: &[UnresolvedTraitConstraint], return_type: &FunctionReturnType, ) -> FunctionDefinition { @@ -837,7 +837,7 @@ impl FunctionDefinition { visibility: ItemVisibility::Private, generics: generics.clone(), parameters: p, - body: body.clone(), + body, span: name.span(), where_clause: where_clause.to_vec(), return_type: return_type.clone(), diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 20d27fbc9ac..2a3cb79c87c 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -328,6 +328,12 @@ impl<'context> Elaborator<'context> { self.elaborate_functions(functions); } + for (trait_id, unresolved_trait) in items.traits { + self.current_trait = Some(trait_id); + self.elaborate_functions(unresolved_trait.fns_with_default_impl); + } + self.current_trait = None; + for impls in items.impls.into_values() { self.elaborate_impls(impls); } diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index e1be45927ca..0006d624288 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -28,6 +28,11 @@ impl<'context> Elaborator<'context> { self.recover_generics(|this| { this.current_trait = Some(*trait_id); + let the_trait = this.interner.get_trait(*trait_id); + let self_typevar = the_trait.self_type_typevar.clone(); + let self_type = Type::TypeVariable(self_typevar.clone()); + this.self_type = Some(self_type.clone()); + let resolved_generics = this.interner.get_trait(*trait_id).generics.clone(); this.add_existing_generics( &unresolved_trait.trait_def.generics, @@ -48,12 +53,15 @@ impl<'context> Elaborator<'context> { .add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id); } + this.interner.update_trait(*trait_id, |trait_def| { + trait_def.set_trait_bounds(resolved_trait_bounds); + trait_def.set_where_clause(where_clause); + }); + let methods = this.resolve_trait_methods(*trait_id, unresolved_trait); this.interner.update_trait(*trait_id, |trait_def| { trait_def.set_methods(methods); - trait_def.set_trait_bounds(resolved_trait_bounds); - trait_def.set_where_clause(where_clause); }); }); @@ -94,7 +102,7 @@ impl<'context> Elaborator<'context> { parameters, return_type, where_clause, - body: _, + body, is_unconstrained, visibility: _, is_comptime: _, @@ -102,8 +110,9 @@ impl<'context> Elaborator<'context> { { self.recover_generics(|this| { let the_trait = this.interner.get_trait(trait_id); + let the_trait_where_clause = the_trait.where_clause.clone(); + let the_trait_constraint = the_trait.as_constraint(the_trait.name.span()); let self_typevar = the_trait.self_type_typevar.clone(); - let self_type = Type::TypeVariable(self_typevar.clone()); let name_span = the_trait.name.span(); this.add_existing_generic( @@ -115,7 +124,6 @@ impl<'context> Elaborator<'context> { span: name_span, }, ); - this.self_type = Some(self_type.clone()); let func_id = unresolved_trait.method_ids[&name.0.contents]; @@ -127,6 +135,7 @@ impl<'context> Elaborator<'context> { parameters, return_type, where_clause, + body, func_id, ); @@ -135,7 +144,9 @@ impl<'context> Elaborator<'context> { this.interner.set_doc_comments(id, item.doc_comments.clone()); } - let func_meta = this.interner.function_meta(&func_id); + let func_meta = this.interner.function_meta_mut(&func_id); + func_meta.trait_constraints.push(the_trait_constraint); + func_meta.trait_constraints.extend(the_trait_where_clause); let arguments = vecmap(&func_meta.parameters.0, |(_, typ, _)| typ.clone()); let return_type = func_meta.return_type().clone(); @@ -189,11 +200,13 @@ impl<'context> Elaborator<'context> { parameters: &[(Ident, UnresolvedType)], return_type: &FunctionReturnType, where_clause: &[UnresolvedTraitConstraint], + body: &Option, func_id: FuncId, ) { - let old_generic_count = self.generics.len(); - - self.scopes.start_function(); + let body = match body { + Some(body) => body.clone(), + None => BlockExpression { statements: Vec::new() }, + }; let kind = FunctionKind::Normal; let mut def = FunctionDefinition::normal( @@ -201,7 +214,7 @@ impl<'context> Elaborator<'context> { is_unconstrained, generics, parameters, - &BlockExpression { statements: Vec::new() }, + body, where_clause, return_type, ); @@ -210,10 +223,6 @@ impl<'context> Elaborator<'context> { let mut function = NoirFunction { kind, def }; self.define_function_meta(&mut function, func_id, Some(trait_id)); - self.elaborate_function(func_id); - let _ = self.scopes.end_function(); - // Don't check the scope tree for unused variables, they can't be used in a declaration anyway. - self.generics.truncate(old_generic_count); } } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 7e06964b563..8a81bcb666f 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -544,12 +544,17 @@ impl<'context> Elaborator<'context> { } // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) + // or inside a trait default method. // // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option { - let trait_impl = self.current_trait_impl?; - let trait_id = self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id; + let trait_id = if let Some(current_trait) = self.current_trait { + current_trait + } else { + let trait_impl = self.current_trait_impl?; + self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id + }; if path.kind == PathKind::Plain && path.segments.len() == 2 { let name = &path.segments[0].ident.0.contents; diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index e7953aab5a4..1a6366416ef 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -518,7 +518,7 @@ impl<'a> ModCollector<'a> { *is_unconstrained, generics, parameters, - body, + body.clone(), where_clause, return_type, )); diff --git a/compiler/noirc_frontend/src/hir_def/function.rs b/compiler/noirc_frontend/src/hir_def/function.rs index db6c3507b15..f6ee003b179 100644 --- a/compiler/noirc_frontend/src/hir_def/function.rs +++ b/compiler/noirc_frontend/src/hir_def/function.rs @@ -175,12 +175,12 @@ pub enum FunctionBody { impl FuncMeta { /// A stub function does not have a body. This includes Builtin, LowLevel, - /// and Oracle functions in addition to method declarations within a trait. + /// and Oracle functions. /// /// We don't check the return type of these functions since it will always have /// an empty body, and we don't check for unused parameters. pub fn is_stub(&self) -> bool { - self.kind.can_ignore_return_type() || self.trait_id.is_some() + self.kind.can_ignore_return_type() } pub fn function_signature(&self) -> FunctionSignature { diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 605236c8dda..0b52cede19c 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2988,7 +2988,7 @@ fn uses_self_type_inside_trait() { fn uses_self_type_in_trait_where_clause() { let src = r#" pub trait Trait { - fn trait_func() -> bool; + fn trait_func(self) -> bool; } pub trait Foo where Self: Trait { diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs index 811a32bab86..b641f726e47 100644 --- a/compiler/noirc_frontend/src/tests/traits.rs +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -592,7 +592,7 @@ fn trait_bounds_which_are_dependent_on_generic_types_are_resolved_correctly() { // Regression test for https://github.com/noir-lang/noir/issues/6420 let src = r#" trait Foo { - fn foo() -> Field; + fn foo(self) -> Field; } trait Bar: Foo { @@ -613,7 +613,8 @@ fn trait_bounds_which_are_dependent_on_generic_types_are_resolved_correctly() { where T: MarkerTrait, { - fn foo() -> Field { + fn foo(self) -> Field { + let _ = self; 42 } } @@ -652,3 +653,65 @@ fn does_not_crash_on_as_trait_path_with_empty_path() { ); assert!(!errors.is_empty()); } + +#[test] +fn type_checks_trait_default_method_and_errors() { + let src = r#" + pub trait Foo { + fn foo(self) -> i32 { + let _ = self; + true + } + } + + fn main() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::TypeMismatchWithSource { + expected, + actual, + .. + }) = &errors[0].0 + else { + panic!("Expected a type mismatch error, got {:?}", errors[0].0); + }; + + assert_eq!(expected.to_string(), "i32"); + assert_eq!(actual.to_string(), "bool"); +} + +#[test] +fn type_checks_trait_default_method_and_does_not_error() { + let src = r#" + pub trait Foo { + fn foo(self) -> i32 { + let _ = self; + 1 + } + } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn type_checks_trait_default_method_and_does_not_error_using_self() { + let src = r#" + pub trait Foo { + fn foo(self) -> i32 { + self.bar() + } + + fn bar(self) -> i32 { + let _ = self; + 1 + } + } + + fn main() {} + "#; + assert_no_errors(src); +}