diff --git a/CHANGELOG.md b/CHANGELOG.md index f441ed5d1de..37bddc218a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,37 @@ containing scientific notation or trailing zeros (i.e. `100` and `1e2`). ([ptdewey](https://github.com/ptdewey)) +- Type inference now preserves generic type parameters when constructors or functions are used without explicit annotations, eliminating false errors in mutually recursive code: + ```gleam + type Test(a) { + Test(a) + } + + fn it(value: Test(a)) { + it2(value) + } + + fn it2(value: Test(a)) -> Test(a) { + it(value) + } + ``` + Previously this could fail with an incorrect "Type mismatch" error: + ``` + Type mismatch + + The type of this returned value doesn't match the return type + annotation of this function. + + Expected type: + + Test(a) + + Found type: + + Test(a) + ``` + ([Adi Salimgereyev](https://github.com/abs0luty)) + ### Build tool - The help text displayed by `gleam dev --help`, `gleam test --help`, and diff --git a/compiler-core/src/analyse.rs b/compiler-core/src/analyse.rs index 49454e9afb6..90bee85e47f 100644 --- a/compiler-core/src/analyse.rs +++ b/compiler-core/src/analyse.rs @@ -673,8 +673,18 @@ impl<'a, A> ModuleAnalyzer<'a, A> { } // Assert that the inferred type matches the type of any recursive call - if let Err(error) = unify(preregistered_type.clone(), type_) { - self.problems.error(convert_unify_error(error, location)); + if let Err(error) = unify(preregistered_type.clone(), type_.clone()) { + let mut instantiated_ids = im::HashMap::new(); + let flexible_hydrator = Hydrator::new(); + let instantiated_annotation = environment.instantiate( + preregistered_type.clone(), + &mut instantiated_ids, + &flexible_hydrator, + ); + + if unify(instantiated_annotation, type_.clone()).is_err() { + self.problems.error(convert_unify_error(error, location)); + } } // Ensure that the current target has an implementation for the function. @@ -715,10 +725,13 @@ impl<'a, A> ModuleAnalyzer<'a, A> { purity, }; + // Store the inferred type (not the preregistered type) in the environment. + // This ensures concrete type information flows through recursive calls - e.g., if we infer + // `fn() -> Test(Int)`, callers see that instead of the generic `fn() -> Test(a)`. environment.insert_variable( name.clone(), variant, - preregistered_type.clone(), + type_.clone(), publicity, deprecation.clone(), ); @@ -731,6 +744,8 @@ impl<'a, A> ModuleAnalyzer<'a, A> { ReferenceKind::Definition, ); + // Use the inferred return type for the typed AST node. + // This matches the type stored in the environment above. let function = Function { documentation: doc, location, @@ -741,7 +756,7 @@ impl<'a, A> ModuleAnalyzer<'a, A> { body_start, end_position: end_location, return_annotation, - return_type: preregistered_type + return_type: type_ .return_type() .expect("Could not find return type for fn"), body, diff --git a/compiler-core/src/language_server/code_action.rs b/compiler-core/src/language_server/code_action.rs index 5ece4902e2d..566572484de 100644 --- a/compiler-core/src/language_server/code_action.rs +++ b/compiler-core/src/language_server/code_action.rs @@ -1371,6 +1371,73 @@ fn collect_type_variables(printer: &mut Printer<'_>, function: &ast::TypedFuncti } impl<'ast, 'a, 'b> ast::visit::Visit<'ast> for TypeVariableCollector<'a, 'b> { + fn visit_typed_function(&mut self, fun: &'ast ast::TypedFunction) { + for argument in fun.arguments.iter() { + if let Some(annotation) = &argument.annotation { + register_type_variables_from_annotation( + self.printer, + annotation, + argument.type_.as_ref(), + ); + } + } + + if let Some(annotation) = &fun.return_annotation { + register_type_variables_from_annotation( + self.printer, + annotation, + fun.return_type.as_ref(), + ); + } + + ast::visit::visit_typed_function(self, fun); + } + + fn visit_typed_expr_fn( + &mut self, + location: &'ast SrcSpan, + type_: &'ast Arc, + kind: &'ast FunctionLiteralKind, + arguments: &'ast [TypedArg], + body: &'ast Vec1, + return_annotation: &'ast Option, + ) { + if let Type::Fn { + arguments: argument_types, + return_: return_type, + .. + } = type_.as_ref() + { + for (argument, argument_type) in arguments.iter().zip(argument_types) { + if let Some(annotation) = &argument.annotation { + register_type_variables_from_annotation( + self.printer, + annotation, + argument_type.as_ref(), + ); + } + } + + if let Some(annotation) = return_annotation { + register_type_variables_from_annotation( + self.printer, + annotation, + return_type.as_ref(), + ); + } + } + + ast::visit::visit_typed_expr_fn( + self, + location, + type_, + kind, + arguments, + body, + return_annotation, + ); + } + fn visit_type_ast_var(&mut self, _location: &'ast SrcSpan, name: &'ast EcoString) { // Register this type variable so that we don't duplicate names when // adding annotations. @@ -1378,6 +1445,113 @@ impl<'ast, 'a, 'b> ast::visit::Visit<'ast> for TypeVariableCollector<'a, 'b> { } } +fn register_type_variables_from_annotation( + printer: &mut Printer<'_>, + annotation: &ast::TypeAst, + type_: &Type, +) { + // fn wibble(a, b, c) { + // fn(a: b, b: c) -> d { ... } + // ^ + // Without this tracking the printer could rename `d` to a fresh `h`. + match (annotation, type_) { + (ast::TypeAst::Var(ast::TypeAstVar { name, .. }), Type::Var { type_ }) => { + match &*type_.borrow() { + TypeVar::Generic { id } | TypeVar::Unbound { id } => { + let id = *id; + printer.register_type_variable(name.clone()); + printer.register_type_variable_with_id(id, name.clone()); + } + TypeVar::Link { type_ } => { + register_type_variables_from_annotation(printer, annotation, type_.as_ref()); + } + } + } + + ( + ast::TypeAst::Fn(ast::TypeAstFn { + arguments: annotation_arguments, + return_: annotation_return, + .. + }), + Type::Fn { + arguments: type_arguments, + return_: type_return, + .. + }, + ) => { + for (argument_annotation, argument_type) in + annotation_arguments.iter().zip(type_arguments) + { + // Maintain the names from each `fn(arg: name, ...)` position. + register_type_variables_from_annotation( + printer, + argument_annotation, + argument_type.as_ref(), + ); + } + + // And likewise propagate the annotated return variable. + register_type_variables_from_annotation( + printer, + annotation_return.as_ref(), + type_return.as_ref(), + ); + } + + ( + ast::TypeAst::Constructor(ast::TypeAstConstructor { + arguments: annotation_arguments, + .. + }), + Type::Named { + arguments: type_arguments, + .. + }, + ) => { + for (argument_annotation, argument_type) in + annotation_arguments.iter().zip(type_arguments) + { + // Track aliases introduced inside named type arguments. + register_type_variables_from_annotation( + printer, + argument_annotation, + argument_type.as_ref(), + ); + } + } + + ( + ast::TypeAst::Tuple(ast::TypeAstTuple { + elements: annotation_elements, + .. + }), + Type::Tuple { + elements: type_elements, + .. + }, + ) => { + for (element_annotation, element_type) in annotation_elements.iter().zip(type_elements) + { + // Tuples can hide extra annotations; ensure each slot retains its label. + register_type_variables_from_annotation( + printer, + element_annotation, + element_type.as_ref(), + ); + } + } + + (_, Type::Var { type_ }) => { + if let TypeVar::Link { type_ } = &*type_.borrow() { + register_type_variables_from_annotation(printer, annotation, type_.as_ref()); + } + } + + _ => {} + } +} + pub struct QualifiedConstructor<'a> { import: &'a Import, used_name: EcoString, diff --git a/compiler-core/src/type_/expression.rs b/compiler-core/src/type_/expression.rs index 1398ac9051f..471a1f37e65 100644 --- a/compiler-core/src/type_/expression.rs +++ b/compiler-core/src/type_/expression.rs @@ -4707,31 +4707,39 @@ impl<'a, 'b> ExprTyper<'a, 'b> { if let Ok(body) = Vec1::try_from_vec(body) { let mut body = body_typer.infer_statements(body); - // Check that any return type is accurate. - if let Some(return_type) = return_type - && let Err(error) = unify(return_type, body.last().type_()) - { - let error = error - .return_annotation_mismatch() - .into_error(body.last().type_defining_location()); - body_typer.problems.error(error); - - // If the return type doesn't match with the annotation we - // add a new expression to the end of the function to match - // the annotated type and allow type inference to keep - // going. - body.push(Statement::Expression(TypedExpr::Invalid { - // This is deliberately an empty span since this - // placeholder expression is implicitly inserted by the - // compiler and doesn't actually appear in the source - // code. - location: SrcSpan { - start: body.last().location().end, - end: body.last().location().end, - }, - type_: body_typer.new_unbound_var(), - extra_information: None, - })) + // Check that any return type is compatible with the annotation. + if let Some(return_type) = return_type { + let mut instantiated_ids = hashmap![]; + let flexible_hydrator = Hydrator::new(); + let instantiated_annotation = body_typer.environment.instantiate( + return_type.clone(), + &mut instantiated_ids, + &flexible_hydrator, + ); + + if let Err(error) = unify(instantiated_annotation, body.last().type_()) { + let error = error + .return_annotation_mismatch() + .into_error(body.last().type_defining_location()); + body_typer.problems.error(error); + + // If the return type doesn't match with the annotation we + // add a new expression to the end of the function to match + // the annotated type and allow type inference to keep + // going. + body.push(Statement::Expression(TypedExpr::Invalid { + // This is deliberately an empty span since this + // placeholder expression is implicitly inserted by the + // compiler and doesn't actually appear in the source + // code. + location: SrcSpan { + start: body.last().location().end, + end: body.last().location().end, + }, + type_: body_typer.new_unbound_var(), + extra_information: None, + })) + } }; Ok((arguments, body.to_vec())) diff --git a/compiler-core/src/type_/printer.rs b/compiler-core/src/type_/printer.rs index 34dc3829e5c..8b11ff8cf0f 100644 --- a/compiler-core/src/type_/printer.rs +++ b/compiler-core/src/type_/printer.rs @@ -454,6 +454,12 @@ impl<'a> Printer<'a> { _ = self.printed_type_variable_names.insert(name); } + /// Record that the given type-variable id is already using the supplied name. + pub fn register_type_variable_with_id(&mut self, id: u64, name: EcoString) { + _ = self.printed_type_variable_names.insert(name.clone()); + _ = self.printed_type_variables.insert(id, name); + } + pub fn print_type(&mut self, type_: &Type) -> EcoString { let mut buffer = EcoString::new(); self.print(type_, &mut buffer, PrintMode::Normal); diff --git a/compiler-core/src/type_/tests/functions.rs b/compiler-core/src/type_/tests/functions.rs index 7bac5c0813e..8e2b380dc68 100644 --- a/compiler-core/src/type_/tests/functions.rs +++ b/compiler-core/src/type_/tests/functions.rs @@ -169,6 +169,56 @@ pub fn two(x) { ); } +// https://github.com/gleam-lang/gleam/issues/2550 +#[test] +fn mutual_recursion_keeps_generic_return_annotation() { + assert_module_infer!( + r#" +pub type Test(a) { + Test(a) +} + +pub fn it(value: Test(a)) { + it2(value) +} + +pub fn it2(value: Test(a)) -> Test(a) { + it(value) +} + +pub fn main() { + it(Test(1)) +} +"#, + vec![ + (r#"Test"#, r#"fn(a) -> Test(a)"#), + (r#"it"#, r#"fn(Test(a)) -> Test(a)"#), + (r#"it2"#, r#"fn(Test(a)) -> Test(a)"#), + (r#"main"#, r#"fn() -> Test(Int)"#) + ] + ); +} + +// https://github.com/gleam-lang/gleam/issues/2533 +#[test] +fn unbound_type_variable_in_top_level_definition() { + assert_module_infer!( + r#" +pub type Foo(a) { + Foo(value: Int) +} + +pub fn main() { + Foo(1) +} +"#, + vec![ + (r#"Foo"#, r#"fn(Int) -> Foo(a)"#), + (r#"main"#, r#"fn() -> Foo(a)"#), + ] + ); +} + #[test] fn no_impl_function_fault_tolerance() { // A function not having an implementation does not stop analysis.