From 5f9d7a882b04ae20297720dd1d1220114401f3e4 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Mon, 30 Dec 2024 16:38:47 -0800 Subject: [PATCH] [DSLX:BC] Remove logic from bytecode emitter, consolidate on type information. Fixes google/xls#1784 --- xls/dslx/bytecode/bytecode.h | 3 +- xls/dslx/bytecode/bytecode_emitter.cc | 122 ++++++++---------- xls/dslx/bytecode/bytecode_emitter.h | 4 + xls/dslx/bytecode/bytecode_interpreter.cc | 53 +------- .../bytecode/bytecode_interpreter_test.cc | 22 ++++ xls/dslx/interp_value.cc | 11 +- 6 files changed, 95 insertions(+), 120 deletions(-) diff --git a/xls/dslx/bytecode/bytecode.h b/xls/dslx/bytecode/bytecode.h index ce0ceb9f3c..93b3bab4d9 100644 --- a/xls/dslx/bytecode/bytecode.h +++ b/xls/dslx/bytecode/bytecode.h @@ -161,7 +161,8 @@ class Bytecode { // otherwise it'll be logical. kShr, // Slices out a subset of the bits-typed value on TOS2, - // starting at index TOS1 and ending at index TOS0. + // starting at index TOS1 with bitwidth at TOS0. + // Note: the start index and the bitwidth should both be non-negative. kSlice, // Creates a new proc interpreter using the data in the optional data member // (as a `SpawnData`). diff --git a/xls/dslx/bytecode/bytecode_emitter.cc b/xls/dslx/bytecode/bytecode_emitter.cc index dd326447be..6d32ab5894 100644 --- a/xls/dslx/bytecode/bytecode_emitter.cc +++ b/xls/dslx/bytecode/bytecode_emitter.cc @@ -979,14 +979,59 @@ absl::Status BytecodeEmitter::HandleFormatMacro(const FormatMacro* node) { return absl::OkStatus(); } -static absl::StatusOr GetValueWidth(const TypeInfo* type_info, - Expr* expr) { - std::optional maybe_type = type_info->GetItem(expr); +absl::Status BytecodeEmitter::HandleSlice(const Index* node, Slice* slice) { + std::optional saw = type_info_->GetSliceStartAndWidth( + slice, + caller_bindings_.has_value() ? *caller_bindings_ : ParametricEnv()); + if (!saw.has_value()) { + return absl::InternalError(absl::StrFormat( + "Expected start-and-width data for slice `%s` @ %s to be populated " + "from type checking.", + slice->ToString(), node->span().ToString(file_table()))); + } + + XLS_RET_CHECK_GE(saw->start, 0); + XLS_RET_CHECK_GE(saw->width, 0); + + // Helper for either getting the span of the given slice index or, if that + // slice index is nullptr, getting the span from the index operation as a + // fallback. + auto span_or_default = [&](Expr* slice_index) -> Span { + if (slice_index != nullptr) { + return slice_index->span(); + } + return node->span(); + }; + + bytecode_.push_back(Bytecode(span_or_default(slice->start()), + Bytecode::Op::kLiteral, + InterpValue::MakeU32(saw->start))); + bytecode_.push_back(Bytecode(span_or_default(slice->limit()), + Bytecode::Op::kLiteral, + InterpValue::MakeU32(saw->width))); + + bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSlice)); + return absl::OkStatus(); +} + +absl::Status BytecodeEmitter::HandleWidthSlice(const Index* node, + WidthSlice* width_slice) { + XLS_RETURN_IF_ERROR(width_slice->start()->AcceptExpr(this)); + + std::optional maybe_type = type_info_->GetItem(width_slice->width()); if (!maybe_type.has_value()) { - return absl::InternalError( - "Could not find concrete type for slice component."); + return absl::InternalError(absl::StrCat( + "Could not find concrete type for slice width parameter \"", + width_slice->width()->ToString(), "\".")); } - return maybe_type.value()->GetTotalBitCount()->GetAsInt64(); + + MetaType* type = dynamic_cast(maybe_type.value()); + XLS_RET_CHECK(type != nullptr) << maybe_type.value()->ToString(); + XLS_RET_CHECK(IsBitsLike(*type->wrapped())) << type->ToString(); + + bytecode_.push_back( + Bytecode(node->span(), Bytecode::Op::kWidthSlice, type->CloneToUnique())); + return absl::OkStatus(); } absl::Status BytecodeEmitter::HandleIndex(const Index* node) { @@ -994,73 +1039,12 @@ absl::Status BytecodeEmitter::HandleIndex(const Index* node) { if (std::holds_alternative(node->rhs())) { Slice* slice = std::get(node->rhs()); - if (slice->start() == nullptr) { - int64_t start_width; - if (slice->limit() == nullptr) { - // TODO(rspringer): Define a uniform `usize` to avoid specifying magic - // numbers here. This is the default size used for untyped numbers in - // the typechecker. - start_width = 32; - } else { - XLS_ASSIGN_OR_RETURN(start_width, - GetValueWidth(type_info_, slice->limit())); - } - bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kLiteral, - InterpValue::MakeSBits(start_width, 0))); - } else { - XLS_RETURN_IF_ERROR(slice->start()->AcceptExpr(this)); - } - - if (slice->limit() == nullptr) { - std::optional maybe_type = type_info_->GetItem(node->lhs()); - if (!maybe_type.has_value()) { - return absl::InternalError("Could not find concrete type for slice."); - } - Type* type = maybe_type.value(); - // These will never fail. - absl::StatusOr dim = type->GetTotalBitCount(); - CHECK_OK(dim); - absl::StatusOr width = dim->GetAsInt64(); - CHECK_OK(width); - - int64_t limit_width; - if (slice->start() == nullptr) { - // TODO(rspringer): Define a uniform `usize` to avoid specifying magic - // numbers here. This is the default size used for untyped numbers in - // the typechecker. - limit_width = 32; - } else { - XLS_ASSIGN_OR_RETURN(limit_width, - GetValueWidth(type_info_, slice->start())); - } - bytecode_.push_back( - Bytecode(node->span(), Bytecode::Op::kLiteral, - InterpValue::MakeSBits(limit_width, *width))); - } else { - XLS_RETURN_IF_ERROR(slice->limit()->AcceptExpr(this)); - } - bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSlice)); - return absl::OkStatus(); + return HandleSlice(node, slice); } if (std::holds_alternative(node->rhs())) { WidthSlice* width_slice = std::get(node->rhs()); - XLS_RETURN_IF_ERROR(width_slice->start()->AcceptExpr(this)); - - std::optional maybe_type = type_info_->GetItem(width_slice->width()); - if (!maybe_type.has_value()) { - return absl::InternalError(absl::StrCat( - "Could not find concrete type for slice width parameter \"", - width_slice->width()->ToString(), "\".")); - } - - MetaType* type = dynamic_cast(maybe_type.value()); - XLS_RET_CHECK(type != nullptr) << maybe_type.value()->ToString(); - XLS_RET_CHECK(IsBitsLike(*type->wrapped())) << type->ToString(); - - bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kWidthSlice, - type->CloneToUnique())); - return absl::OkStatus(); + return HandleWidthSlice(node, width_slice); } // Otherwise, it's a regular [array or tuple] index op. diff --git a/xls/dslx/bytecode/bytecode_emitter.h b/xls/dslx/bytecode/bytecode_emitter.h index cb84256e8a..6b3d270aa3 100644 --- a/xls/dslx/bytecode/bytecode_emitter.h +++ b/xls/dslx/bytecode/bytecode_emitter.h @@ -113,7 +113,11 @@ class BytecodeEmitter : public ExprVisitor { absl::Status HandleFunctionRef(const FunctionRef* node) override; absl::Status HandleZeroMacro(const ZeroMacro* node) override; absl::Status HandleAllOnesMacro(const AllOnesMacro* node) override; + absl::Status HandleIndex(const Index* node) override; + absl::Status HandleSlice(const Index* node, Slice* slice); + absl::Status HandleWidthSlice(const Index* node, WidthSlice* width_slice); + absl::Status HandleInvocation(const Invocation* node) override; absl::Status HandleLet(const Let* node) override; absl::Status HandleMatch(const Match* node) override; diff --git a/xls/dslx/bytecode/bytecode_interpreter.cc b/xls/dslx/bytecode/bytecode_interpreter.cc index c78a08d75d..cd9b8b67cd 100644 --- a/xls/dslx/bytecode/bytecode_interpreter.cc +++ b/xls/dslx/bytecode/bytecode_interpreter.cc @@ -1253,54 +1253,15 @@ absl::Status BytecodeInterpreter::EvalShr(const Bytecode& bytecode) { } absl::Status BytecodeInterpreter::EvalSlice(const Bytecode& bytecode) { - XLS_ASSIGN_OR_RETURN(InterpValue limit, Pop()); + XLS_ASSIGN_OR_RETURN(InterpValue length, Pop()); XLS_ASSIGN_OR_RETURN(InterpValue start, Pop()); XLS_ASSIGN_OR_RETURN(InterpValue basis, Pop()); - XLS_ASSIGN_OR_RETURN(int64_t basis_bit_count, basis.GetBitCount()); - XLS_ASSIGN_OR_RETURN(int64_t start_bit_count, start.GetBitCount()); - - InterpValue zero = InterpValue::MakeSBits(start_bit_count, 0); - InterpValue basis_length = - InterpValue::MakeSBits(start_bit_count, basis_bit_count); - - XLS_ASSIGN_OR_RETURN(InterpValue start_lt_zero, start.Lt(zero)); - if (start_lt_zero.IsTrue()) { - // Remember, start is negative if we're here. - XLS_ASSIGN_OR_RETURN(start, basis_length.Add(start)); - // If start is _still_ less than zero, then we clamp to zero. - XLS_ASSIGN_OR_RETURN(start_lt_zero, start.Lt(zero)); - if (start_lt_zero.IsTrue()) { - start = zero; - } - } - - XLS_ASSIGN_OR_RETURN(InterpValue limit_lt_zero, limit.Lt(zero)); - if (limit_lt_zero.IsTrue()) { - // Ditto. - XLS_ASSIGN_OR_RETURN(limit, basis_length.Add(limit)); - XLS_ASSIGN_OR_RETURN(limit_lt_zero, limit.Lt(zero)); - if (limit_lt_zero.IsTrue()) { - limit = zero; - } - } - - // If limit extends past the basis, then we truncate limit. - XLS_ASSIGN_OR_RETURN(InterpValue limit_ge_basis_length, - limit.Ge(basis_length)); - if (limit_ge_basis_length.IsTrue()) { - limit = - InterpValue::MakeSBits(start_bit_count, basis.GetBitCount().value()); - } - XLS_ASSIGN_OR_RETURN(InterpValue length, limit.Sub(start)); - - // At this point, both start and length must be nonnegative, so we force them - // to UBits, since Slice expects that. - XLS_ASSIGN_OR_RETURN(int64_t start_value, start.GetBitValueViaSign()); - XLS_ASSIGN_OR_RETURN(int64_t length_value, length.GetBitValueViaSign()); - XLS_RET_CHECK_GE(start_value, 0); - XLS_RET_CHECK_GE(length_value, 0); - start = InterpValue::MakeBits(/*is_signed=*/false, start.GetBitsOrDie()); - length = InterpValue::MakeBits(/*is_signed=*/false, length.GetBitsOrDie()); + XLS_RET_CHECK(length.IsUBits()) + << "Slice length is not unsigned bits: " << length.ToString(); + XLS_RET_CHECK(start.IsUBits()) + << "Slice start is not unsigned bits: " << start.ToString(); + XLS_RET_CHECK(basis.IsUBits()) + << "Slice basis is not unsigned bits: " << basis.ToString(); XLS_ASSIGN_OR_RETURN(InterpValue result, basis.Slice(start, length)); stack_.Push(result); return absl::OkStatus(); diff --git a/xls/dslx/bytecode/bytecode_interpreter_test.cc b/xls/dslx/bytecode/bytecode_interpreter_test.cc index 574162b222..d135574b06 100644 --- a/xls/dslx/bytecode/bytecode_interpreter_test.cc +++ b/xls/dslx/bytecode/bytecode_interpreter_test.cc @@ -147,6 +147,9 @@ TEST_F(BytecodeInterpreterTest, DupLiteral) { EXPECT_EQ(result.ToString(), "u32:42"); } +// Note: this test case spews a stack trace as the `!stack_.empty()` error comes +// from a `XLS_RET_CHECK`, this is really an internals-flags-an-error test not a +// behavioral test. TEST_F(BytecodeInterpreterTest, DupEmptyStack) { std::vector bytecodes; bytecodes.emplace_back(kFakeSpan, Bytecode::Op::kDup); @@ -1318,6 +1321,25 @@ fn negative_end_slice() -> u16 { EXPECT_EQ(int_value, 0xbeef); } +// https://github.com/google/xls/issues/1784 -- note that the size of a slice +// can never be negative, but the bytecode interpreter can have inflated +// expectations that don't line up with what the type checker will accept. This +// test is to ensure that we don't crash when we encounter this case. +TEST_F(BytecodeInterpreterTest, NegativeSizeSlice) { + constexpr std::string_view kProgram = R"( +fn negative_size_slice() -> bits[0] { + (u32:0x42)[5:3] +} +)"; + + XLS_ASSERT_OK_AND_ASSIGN(InterpValue value, + Interpret(kProgram, "negative_size_slice")); + ASSERT_TRUE(value.IsUBits()); + const Bits& bits = value.GetBitsOrDie(); + EXPECT_EQ(bits.bit_count(), 0); + EXPECT_EQ(bits.ToUint64().value(), 0); +} + TEST_F(BytecodeInterpreterTest, WidthSlice) { constexpr std::string_view kProgram = R"( fn width_slice() -> s16 { diff --git a/xls/dslx/interp_value.cc b/xls/dslx/interp_value.cc index eb6dcb1bc6..f9a458e904 100644 --- a/xls/dslx/interp_value.cc +++ b/xls/dslx/interp_value.cc @@ -451,9 +451,11 @@ bool InterpValue::operator==(const InterpValue& rhs) const { return Eq(rhs); } const InterpValue& lhs, const InterpValue& rhs, CompareF ucmp, CompareF scmp) { if (lhs.tag_ != rhs.tag_) { - return absl::InvalidArgumentError(absl::StrFormat( - "Same tag is required for a comparison operation: lhs %s rhs %s", - TagToString(lhs.tag_), TagToString(rhs.tag_))); + return absl::InvalidArgumentError( + absl::StrFormat("Same tag is required for a comparison operation: lhs " + "tag: %s, rhs tag: %s, lhs value: %s, rhs value: %s", + TagToString(lhs.tag_), TagToString(rhs.tag_), + lhs.ToString(), rhs.ToString())); } switch (lhs.tag_) { case InterpValueTag::kUBits: @@ -692,7 +694,8 @@ absl::StatusOr InterpValue::GetBits() const { return std::get(payload_).value; } - return absl::InvalidArgumentError("Value does not contain bits."); + return absl::InvalidArgumentError( + absl::StrFormat("Value %s does not contain bits.", ToString())); } const Bits& InterpValue::GetBitsOrDie() const {