Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DSLX:BC] Remove logic from bytecode emitter, consolidate on type information #1818

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xls/dslx/bytecode/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class Bytecode {
// otherwise it'll be logical.
kShr,
// Slices out a subset of the bits-typed value on TOS2,
// starting at index TOS1 and ending at index TOS0.
// starting at index TOS1 with bitwidth at TOS0.
// Note: the start index and the bitwidth should both be non-negative.
kSlice,
// Creates a new proc interpreter using the data in the optional data member
// (as a `SpawnData`).
Expand Down
122 changes: 53 additions & 69 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -979,88 +979,72 @@ absl::Status BytecodeEmitter::HandleFormatMacro(const FormatMacro* node) {
return absl::OkStatus();
}

static absl::StatusOr<int64_t> GetValueWidth(const TypeInfo* type_info,
Expr* expr) {
std::optional<Type*> maybe_type = type_info->GetItem(expr);
absl::Status BytecodeEmitter::HandleSlice(const Index* node, Slice* slice) {
std::optional<StartAndWidth> saw = type_info_->GetSliceStartAndWidth(
slice,
caller_bindings_.has_value() ? *caller_bindings_ : ParametricEnv());
if (!saw.has_value()) {
return absl::InternalError(absl::StrFormat(
"Expected start-and-width data for slice `%s` @ %s to be populated "
"from type checking.",
slice->ToString(), node->span().ToString(file_table())));
}

XLS_RET_CHECK_GE(saw->start, 0);
XLS_RET_CHECK_GE(saw->width, 0);

// Helper for either getting the span of the given slice index or, if that
// slice index is nullptr, getting the span from the index operation as a
// fallback.
auto span_or_default = [&](Expr* slice_index) -> Span {
if (slice_index != nullptr) {
return slice_index->span();
}
return node->span();
};

bytecode_.push_back(Bytecode(span_or_default(slice->start()),
Bytecode::Op::kLiteral,
InterpValue::MakeU32(saw->start)));
bytecode_.push_back(Bytecode(span_or_default(slice->limit()),
Bytecode::Op::kLiteral,
InterpValue::MakeU32(saw->width)));

bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSlice));
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleWidthSlice(const Index* node,
WidthSlice* width_slice) {
XLS_RETURN_IF_ERROR(width_slice->start()->AcceptExpr(this));

std::optional<Type*> maybe_type = type_info_->GetItem(width_slice->width());
if (!maybe_type.has_value()) {
return absl::InternalError(
"Could not find concrete type for slice component.");
return absl::InternalError(absl::StrCat(
"Could not find concrete type for slice width parameter \"",
width_slice->width()->ToString(), "\"."));
}
return maybe_type.value()->GetTotalBitCount()->GetAsInt64();

MetaType* type = dynamic_cast<MetaType*>(maybe_type.value());
XLS_RET_CHECK(type != nullptr) << maybe_type.value()->ToString();
XLS_RET_CHECK(IsBitsLike(*type->wrapped())) << type->ToString();

bytecode_.push_back(
Bytecode(node->span(), Bytecode::Op::kWidthSlice, type->CloneToUnique()));
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleIndex(const Index* node) {
XLS_RETURN_IF_ERROR(node->lhs()->AcceptExpr(this));

if (std::holds_alternative<Slice*>(node->rhs())) {
Slice* slice = std::get<Slice*>(node->rhs());
if (slice->start() == nullptr) {
int64_t start_width;
if (slice->limit() == nullptr) {
// TODO(rspringer): Define a uniform `usize` to avoid specifying magic
// numbers here. This is the default size used for untyped numbers in
// the typechecker.
start_width = 32;
} else {
XLS_ASSIGN_OR_RETURN(start_width,
GetValueWidth(type_info_, slice->limit()));
}
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kLiteral,
InterpValue::MakeSBits(start_width, 0)));
} else {
XLS_RETURN_IF_ERROR(slice->start()->AcceptExpr(this));
}

if (slice->limit() == nullptr) {
std::optional<Type*> maybe_type = type_info_->GetItem(node->lhs());
if (!maybe_type.has_value()) {
return absl::InternalError("Could not find concrete type for slice.");
}
Type* type = maybe_type.value();
// These will never fail.
absl::StatusOr<TypeDim> dim = type->GetTotalBitCount();
CHECK_OK(dim);
absl::StatusOr<int64_t> width = dim->GetAsInt64();
CHECK_OK(width);

int64_t limit_width;
if (slice->start() == nullptr) {
// TODO(rspringer): Define a uniform `usize` to avoid specifying magic
// numbers here. This is the default size used for untyped numbers in
// the typechecker.
limit_width = 32;
} else {
XLS_ASSIGN_OR_RETURN(limit_width,
GetValueWidth(type_info_, slice->start()));
}
bytecode_.push_back(
Bytecode(node->span(), Bytecode::Op::kLiteral,
InterpValue::MakeSBits(limit_width, *width)));
} else {
XLS_RETURN_IF_ERROR(slice->limit()->AcceptExpr(this));
}
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSlice));
return absl::OkStatus();
return HandleSlice(node, slice);
}

if (std::holds_alternative<WidthSlice*>(node->rhs())) {
WidthSlice* width_slice = std::get<WidthSlice*>(node->rhs());
XLS_RETURN_IF_ERROR(width_slice->start()->AcceptExpr(this));

std::optional<Type*> maybe_type = type_info_->GetItem(width_slice->width());
if (!maybe_type.has_value()) {
return absl::InternalError(absl::StrCat(
"Could not find concrete type for slice width parameter \"",
width_slice->width()->ToString(), "\"."));
}

MetaType* type = dynamic_cast<MetaType*>(maybe_type.value());
XLS_RET_CHECK(type != nullptr) << maybe_type.value()->ToString();
XLS_RET_CHECK(IsBitsLike(*type->wrapped())) << type->ToString();

bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kWidthSlice,
type->CloneToUnique()));
return absl::OkStatus();
return HandleWidthSlice(node, width_slice);
}

// Otherwise, it's a regular [array or tuple] index op.
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ class BytecodeEmitter : public ExprVisitor {
absl::Status HandleFunctionRef(const FunctionRef* node) override;
absl::Status HandleZeroMacro(const ZeroMacro* node) override;
absl::Status HandleAllOnesMacro(const AllOnesMacro* node) override;

absl::Status HandleIndex(const Index* node) override;
absl::Status HandleSlice(const Index* node, Slice* slice);
absl::Status HandleWidthSlice(const Index* node, WidthSlice* width_slice);

absl::Status HandleInvocation(const Invocation* node) override;
absl::Status HandleLet(const Let* node) override;
absl::Status HandleMatch(const Match* node) override;
Expand Down
53 changes: 7 additions & 46 deletions xls/dslx/bytecode/bytecode_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1253,54 +1253,15 @@ absl::Status BytecodeInterpreter::EvalShr(const Bytecode& bytecode) {
}

absl::Status BytecodeInterpreter::EvalSlice(const Bytecode& bytecode) {
XLS_ASSIGN_OR_RETURN(InterpValue limit, Pop());
XLS_ASSIGN_OR_RETURN(InterpValue length, Pop());
XLS_ASSIGN_OR_RETURN(InterpValue start, Pop());
XLS_ASSIGN_OR_RETURN(InterpValue basis, Pop());
XLS_ASSIGN_OR_RETURN(int64_t basis_bit_count, basis.GetBitCount());
XLS_ASSIGN_OR_RETURN(int64_t start_bit_count, start.GetBitCount());

InterpValue zero = InterpValue::MakeSBits(start_bit_count, 0);
InterpValue basis_length =
InterpValue::MakeSBits(start_bit_count, basis_bit_count);

XLS_ASSIGN_OR_RETURN(InterpValue start_lt_zero, start.Lt(zero));
if (start_lt_zero.IsTrue()) {
// Remember, start is negative if we're here.
XLS_ASSIGN_OR_RETURN(start, basis_length.Add(start));
// If start is _still_ less than zero, then we clamp to zero.
XLS_ASSIGN_OR_RETURN(start_lt_zero, start.Lt(zero));
if (start_lt_zero.IsTrue()) {
start = zero;
}
}

XLS_ASSIGN_OR_RETURN(InterpValue limit_lt_zero, limit.Lt(zero));
if (limit_lt_zero.IsTrue()) {
// Ditto.
XLS_ASSIGN_OR_RETURN(limit, basis_length.Add(limit));
XLS_ASSIGN_OR_RETURN(limit_lt_zero, limit.Lt(zero));
if (limit_lt_zero.IsTrue()) {
limit = zero;
}
}

// If limit extends past the basis, then we truncate limit.
XLS_ASSIGN_OR_RETURN(InterpValue limit_ge_basis_length,
limit.Ge(basis_length));
if (limit_ge_basis_length.IsTrue()) {
limit =
InterpValue::MakeSBits(start_bit_count, basis.GetBitCount().value());
}
XLS_ASSIGN_OR_RETURN(InterpValue length, limit.Sub(start));

// At this point, both start and length must be nonnegative, so we force them
// to UBits, since Slice expects that.
XLS_ASSIGN_OR_RETURN(int64_t start_value, start.GetBitValueViaSign());
XLS_ASSIGN_OR_RETURN(int64_t length_value, length.GetBitValueViaSign());
XLS_RET_CHECK_GE(start_value, 0);
XLS_RET_CHECK_GE(length_value, 0);
start = InterpValue::MakeBits(/*is_signed=*/false, start.GetBitsOrDie());
length = InterpValue::MakeBits(/*is_signed=*/false, length.GetBitsOrDie());
XLS_RET_CHECK(length.IsUBits())
<< "Slice length is not unsigned bits: " << length.ToString();
XLS_RET_CHECK(start.IsUBits())
<< "Slice start is not unsigned bits: " << start.ToString();
XLS_RET_CHECK(basis.IsUBits())
<< "Slice basis is not unsigned bits: " << basis.ToString();
XLS_ASSIGN_OR_RETURN(InterpValue result, basis.Slice(start, length));
stack_.Push(result);
return absl::OkStatus();
Expand Down
22 changes: 22 additions & 0 deletions xls/dslx/bytecode/bytecode_interpreter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ TEST_F(BytecodeInterpreterTest, DupLiteral) {
EXPECT_EQ(result.ToString(), "u32:42");
}

// Note: this test case spews a stack trace as the `!stack_.empty()` error comes
// from a `XLS_RET_CHECK`, this is really an internals-flags-an-error test not a
// behavioral test.
TEST_F(BytecodeInterpreterTest, DupEmptyStack) {
std::vector<Bytecode> bytecodes;
bytecodes.emplace_back(kFakeSpan, Bytecode::Op::kDup);
Expand Down Expand Up @@ -1318,6 +1321,25 @@ fn negative_end_slice() -> u16 {
EXPECT_EQ(int_value, 0xbeef);
}

// https://github.com/google/xls/issues/1784 -- note that the size of a slice
// can never be negative, but the bytecode interpreter can have inflated
// expectations that don't line up with what the type checker will accept. This
// test is to ensure that we don't crash when we encounter this case.
TEST_F(BytecodeInterpreterTest, NegativeSizeSlice) {
constexpr std::string_view kProgram = R"(
fn negative_size_slice() -> bits[0] {
(u32:0x42)[5:3]
}
)";

XLS_ASSERT_OK_AND_ASSIGN(InterpValue value,
Interpret(kProgram, "negative_size_slice"));
ASSERT_TRUE(value.IsUBits());
const Bits& bits = value.GetBitsOrDie();
EXPECT_EQ(bits.bit_count(), 0);
EXPECT_EQ(bits.ToUint64().value(), 0);
}

TEST_F(BytecodeInterpreterTest, WidthSlice) {
constexpr std::string_view kProgram = R"(
fn width_slice() -> s16 {
Expand Down
11 changes: 7 additions & 4 deletions xls/dslx/interp_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,11 @@ bool InterpValue::operator==(const InterpValue& rhs) const { return Eq(rhs); }
const InterpValue& lhs, const InterpValue& rhs, CompareF ucmp,
CompareF scmp) {
if (lhs.tag_ != rhs.tag_) {
return absl::InvalidArgumentError(absl::StrFormat(
"Same tag is required for a comparison operation: lhs %s rhs %s",
TagToString(lhs.tag_), TagToString(rhs.tag_)));
return absl::InvalidArgumentError(
absl::StrFormat("Same tag is required for a comparison operation: lhs "
"tag: %s, rhs tag: %s, lhs value: %s, rhs value: %s",
TagToString(lhs.tag_), TagToString(rhs.tag_),
lhs.ToString(), rhs.ToString()));
}
switch (lhs.tag_) {
case InterpValueTag::kUBits:
Expand Down Expand Up @@ -692,7 +694,8 @@ absl::StatusOr<Bits> InterpValue::GetBits() const {
return std::get<EnumData>(payload_).value;
}

return absl::InvalidArgumentError("Value does not contain bits.");
return absl::InvalidArgumentError(
absl::StrFormat("Value %s does not contain bits.", ToString()));
}

const Bits& InterpValue::GetBitsOrDie() const {
Expand Down