From 376406cad6819757721fa1f8c438659e80cd8435 Mon Sep 17 00:00:00 2001 From: Rich McKeever Date: Thu, 9 Jan 2025 20:54:30 -0800 Subject: [PATCH] Support comparison operators in type_system_v2. This change also fixes the handling of expressions used as explicit parametric invocation arguments. PiperOrigin-RevId: 713898572 --- xls/dslx/type_system_v2/inference_table.cc | 9 +- .../type_system_v2/inference_table_test.cc | 29 ---- .../inference_table_to_type_info.cc | 6 + .../type_system_v2/typecheck_module_v2.cc | 64 ++++++++- .../typecheck_module_v2_test.cc | 125 ++++++++++++++++++ 5 files changed, 194 insertions(+), 39 deletions(-) diff --git a/xls/dslx/type_system_v2/inference_table.cc b/xls/dslx/type_system_v2/inference_table.cc index 57dd68ca05..8e3f9fd404 100644 --- a/xls/dslx/type_system_v2/inference_table.cc +++ b/xls/dslx/type_system_v2/inference_table.cc @@ -196,14 +196,7 @@ class InferenceTableImpl : public InferenceTable { callee.parametric_bindings(); const std::vector& explicit_parametrics = node.explicit_parametrics(); - if (explicit_parametrics.size() > bindings.size()) { - return ArgCountMismatchErrorStatus( - node.span(), - absl::Substitute( - "Too many parametric values supplied; limit: $0 given: $1", - callee.parametric_bindings().size(), explicit_parametrics.size()), - file_table_); - } + CHECK(explicit_parametrics.size() <= bindings.size()); absl::flat_hash_map values; for (int i = 0; i < bindings.size(); i++) { const ParametricBinding* binding = bindings[i]; diff --git a/xls/dslx/type_system_v2/inference_table_test.cc b/xls/dslx/type_system_v2/inference_table_test.cc index be16855c2a..01c19d7111 100644 --- a/xls/dslx/type_system_v2/inference_table_test.cc +++ b/xls/dslx/type_system_v2/inference_table_test.cc @@ -364,34 +364,5 @@ TEST_F(InferenceTableTest, ParametricVariableWithUnsupportedAnnotation) { HasSubstr("Inference variables of type T are not supported"))); } -TEST_F(InferenceTableTest, TooManyParametricsInInvocation) { - ParseAndInitModuleAndTable(R"( - fn foo(a: uN[N]) -> uN[N] { a } - fn bar() { - foo(u4:1); - } -)"); - - XLS_ASSERT_OK_AND_ASSIGN(const Function* foo, - module_->GetMemberOrError("foo")); - ASSERT_EQ(foo->parametric_bindings().size(), 1); - ASSERT_EQ(foo->params().size(), 1); - XLS_ASSERT_OK( - table_->DefineParametricVariable(*foo->parametric_bindings()[0])); - for (const Param* param : foo->params()) { - XLS_ASSERT_OK(table_->SetTypeAnnotation(param, param->type_annotation())); - } - XLS_ASSERT_OK_AND_ASSIGN(const Function* bar, - module_->GetMemberOrError("bar")); - ASSERT_EQ(bar->body()->statements().size(), 1); - const Invocation* invocation = down_cast( - ToAstNode(bar->body()->statements().at(0)->wrapped())); - EXPECT_THAT( - table_->AddParametricInvocation(*invocation, *foo, bar, - /*caller_invocation=*/std::nullopt), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Too many parametric values supplied"))); -} - } // namespace } // namespace xls::dslx diff --git a/xls/dslx/type_system_v2/inference_table_to_type_info.cc b/xls/dslx/type_system_v2/inference_table_to_type_info.cc index d23e2dca35..dab3e615d1 100644 --- a/xls/dslx/type_system_v2/inference_table_to_type_info.cc +++ b/xls/dslx/type_system_v2/inference_table_to_type_info.cc @@ -211,6 +211,12 @@ class ConversionOrderVisitor : public AstNodeVisitorWithDefault { absl::Status HandleParametricBindingExprsInternal( const ParametricInvocation* parametric_invocation) { + for (ExprOrType explicit_parametric : + parametric_invocation->node().explicit_parametrics()) { + if (std::holds_alternative(explicit_parametric)) { + XLS_RETURN_IF_ERROR(std::get(explicit_parametric)->Accept(this)); + } + } parametric_invocation_stack_.push(parametric_invocation); for (const ParametricBinding* binding : parametric_invocation->callee().parametric_bindings()) { diff --git a/xls/dslx/type_system_v2/typecheck_module_v2.cc b/xls/dslx/type_system_v2/typecheck_module_v2.cc index 1bb942ca4d..aee72d9688 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2.cc @@ -124,8 +124,27 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { // should have a type that was set when its parent was visited. const NameRef* type_variable = *table_.GetTypeVariable(node); if (GetBinopSameTypeKinds().contains(node->binop_kind())) { + // In the example `const C = a + b;`, the `ConstantDef` establishes a type + // variable that is just propagated down to `a` and `b` here, meaning that + // `a`, `b`, and the result must ultimately be the same type. XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->lhs(), type_variable)); XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->rhs(), type_variable)); + } else if (GetBinopComparisonKinds().contains(node->binop_kind())) { + // In a comparison example, like `const C = a > b;`, the `>` establishes a + // new type variable for `a` and `b` (meaning the two of them must be the + // same type), and attaches a bool annotation to the overall expression, + // which will then be assumed by the type variable for the `ConstantDef`. + XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation( + node, CreateBoolAnnotation(module_, node->span()))); + XLS_ASSIGN_OR_RETURN( + const NameRef* operand_variable, + table_.DefineInternalVariable( + InferenceVariableKind::kType, const_cast(node), + GenerateInternalTypeVariableName(node))); + XLS_RETURN_IF_ERROR( + table_.SetTypeVariable(node->lhs(), operand_variable)); + XLS_RETURN_IF_ERROR( + table_.SetTypeVariable(node->rhs(), operand_variable)); } else { return absl::UnimplementedError( absl::StrCat("Type inference version 2 is a work in progress and " @@ -527,6 +546,9 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { VLOG(5) << "HandleParametricInvocation: " << node->ToString() << ", fn: " << fn.identifier(); CHECK(fn.IsParametric()); + const std::vector& bindings = fn.parametric_bindings(); + const std::vector& explicit_parametrics = + node->explicit_parametrics(); const std::optional caller = GetCurrentFunction(); current_function_stack_.push(&fn); const bool function_processed_before = @@ -537,22 +559,54 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { // The bindings need to be defined in the table up front, because the rest // of the header may depend on them, and we can't even create a // `ParametricInvocation` without them being registered. - for (const ParametricBinding* binding : fn.parametric_bindings()) { + for (const ParametricBinding* binding : bindings) { XLS_RETURN_IF_ERROR(binding->Accept(this)); } } + if (explicit_parametrics.size() > bindings.size()) { + return ArgCountMismatchErrorStatus( + node->span(), + absl::Substitute( + "Too many parametric values supplied; limit: $0 given: $1", + bindings.size(), explicit_parametrics.size()), + file_table_); + } + + // Type-check the subtrees for any explicit parametric values. Note that the + // addition of the invocation above will have verified that a valid number + // of explicit parametrics was passed in. + for (int i = 0; i < explicit_parametrics.size(); i++) { + ExprOrType explicit_parametric = explicit_parametrics[i]; + const ParametricBinding* formal_parametric = bindings[i]; + if (std::holds_alternative(explicit_parametric)) { + const Expr* parametric_value_expr = + std::get(explicit_parametric); + XLS_ASSIGN_OR_RETURN( + const NameRef* type_variable, + table_.DefineInternalVariable( + InferenceVariableKind::kType, + const_cast(parametric_value_expr), + GenerateInternalTypeVariableName(parametric_value_expr))); + XLS_RETURN_IF_ERROR( + table_.SetTypeVariable(parametric_value_expr, type_variable)); + XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation( + parametric_value_expr, formal_parametric->type_annotation())); + XLS_RETURN_IF_ERROR(parametric_value_expr->Accept(this)); + } + } + // Register the parametric invocation in the table, regardless of whether // we have seen the function before. XLS_ASSIGN_OR_RETURN( const ParametricInvocation* parametric_invocation, table_.AddParametricInvocation(*node, fn, caller, GetCurrentParametricInvocation())); - parametric_invocation_stack_.push(parametric_invocation); // We don't need to process the entire function multiple times, if it's // used in multiple contexts. Only the invocation nodes in it need to be // dealt with multiple times. + parametric_invocation_stack_.push(parametric_invocation); if (function_processed_before) { VLOG(5) << "Reprocessing outbound invocations in this context from: " << fn.identifier(); @@ -618,6 +672,12 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { return absl::StrCat("internal_type_actual_member_", formal_member->name(), "_at_", actual_member->span().ToString(file_table_)); } + // Variant for operands of a binary operator. + std::string GenerateInternalTypeVariableName(const Binop* binop) { + return absl::StrCat("internal_type_operand_", + BinopKindToString(binop->binop_kind()), "_at_", + binop->span().ToString(file_table_)); + } // Propagates the type from the def for `ref`, to `ref` itself in the // inference table. This may result in a `TypeAnnotation` being added to the diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index d5223fccb6..09a42a017a 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -366,6 +366,123 @@ const Z = X % 1 % Y % 2; HasSubstr("node: `const Z = X % 1 % Y % 2;`, type: uN[32]"))); } +TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfUntypedLiterals) { + EXPECT_THAT("const Z = 4 > 1;", TopNodeHasType("uN[1]")); +} + +TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfTypedLiterals) { + EXPECT_THAT("const Z = u32:4 < u32:1;", TopNodeHasType("uN[1]")); +} + +TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfLiteralsWithOneType) { + EXPECT_THAT("const Z = 4 < s32:1;", TopNodeHasType("uN[1]")); +} + +TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfVariables) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +const X = u32:3; +const Y = u32:4; +const Z = Y >= X; +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, HasSubstr("node: `Z`, type: uN[1]")); +} + +TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfExprs) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +const X = s24:3; +const Y = s24:4; +const Z = (Y + X * 2) == (1 - Y); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + AllOf(HasSubstr("node: `(Y + X * 2)`, type: sN[24]"), + HasSubstr("node: `(1 - Y)`, type: sN[24]"), + HasSubstr("node: `Z`, type: uN[1]"))); +} + +TEST(TypecheckV2Test, ComparisonAsFunctionArgument) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: bool) -> bool { a } +const Y = foo(1 != 2); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, AllOf(HasSubstr("node: `1 != 2`, type: uN[1]"), + HasSubstr("node: `1`, type: uN[2]"), + HasSubstr("node: `2`, type: uN[2]"))); +} + +TEST(TypecheckV2Test, ComparisonOfReturnValues) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: u32) -> u32 { a } +const Y = foo(1) > foo(2); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + AllOf(HasSubstr("node: `Y`, type: uN[1]"), + HasSubstr("node: `foo(1)`, type: uN[32]"), + HasSubstr("node: `foo(2)`, type: uN[32]"))); +} + +TEST(TypecheckV2Test, ComparisonAsParametricArgument) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: xN[S][32]) -> xN[S][32] { a } +const Y = foo<{2 > 1}>(s32:5); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, AllOf(HasSubstr("node: `Y`, type: sN[32]"), + HasSubstr("node: `2`, type: uN[2]"), + HasSubstr("node: `1`, type: uN[2]"))); +} + +TEST(TypecheckV2Test, ComparisonAsParametricArgumentWithConflictFails) { + EXPECT_THAT(R"( +fn foo(a: xN[S][32]) -> xN[S][32] { a } +const Y = foo<{2 > 1}>(u32:5); +)", + TypecheckFails(HasSignednessMismatch("uN[32]", "sN[32]"))); +} + +TEST(TypecheckV2Test, ComparisonAndSumAsParametricArguments) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +const X = u32:1; +fn foo(a: xN[S][N]) -> xN[S][N] { a } +const Y = foo<{X == 1}, {X + 3}>(s4:3); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, HasSubstr("node: `Y`, type: sN[4]")); +} + +TEST(TypecheckV2Test, ComparisonAndSumParametricArgumentsWithConflictFails) { + EXPECT_THAT(R"( +const X = u32:1; +fn foo(a: xN[S][N]) -> xN[S][N] { a } +const Y = foo<{X == 1}, {X + 4}>(s4:3); +)", + TypecheckFails(HasSizeMismatch("sN[4]", "sN[5]"))); +} + +TEST(TypecheckV2Test, ExplicitParametricExpressionMismatchingBindingTypeFails) { + EXPECT_THAT(R"( +const X = u32:1; +fn foo(a: uN[N]) -> uN[N] { a } +const Y = foo<{X == 1}>(s4:3); +)", + TypecheckFails(HasSizeMismatch("bool", "u32"))); +} + +TEST(TypecheckV2Test, + GlobalBoolConstantEqualsComparisonOfConflictingTypedLiteralsFails) { + EXPECT_THAT("const Z = u32:4 >= s32:1;", + TypecheckFails(HasSignednessMismatch("s32", "u32"))); +} + TEST(TypecheckV2Test, GlobalIntegerConstantEqualsAnotherConstantWithAnnotationOnName) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( @@ -1067,6 +1184,14 @@ const Y = X(1); TypecheckFails(HasSubstr("callee `X` is not a function"))); } +TEST(TypecheckV2Test, ParametricFunctionCallWithTooManyParametricsFails) { + EXPECT_THAT(R"( +fn foo() -> u32 { N } +const X = foo<3, 4>(); +)", + TypecheckFails(HasSubstr("Too many parametric values supplied"))); +} + TEST(TypecheckV2Test, ParametricFunctionReturningIntegerParameter) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( fn foo() -> u32 { N }