Skip to content

Commit

Permalink
Introduce type oracle that maps from optimized IR nodes to unoptimize…
Browse files Browse the repository at this point in the history
…d ones (#1295)

* Improve handling of dead blocks and lambdas

* Introduce map from optimized IR nodes to unoptimized ones

* Almost there

* Fix elimination of dead statics

* Add comment
  • Loading branch information
kasperl authored Dec 21, 2022
1 parent 8767a9c commit c273d25
Show file tree
Hide file tree
Showing 13 changed files with 260 additions and 178 deletions.
12 changes: 4 additions & 8 deletions src/compiler/byte_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ void ByteGen::_generate_call(Call* node,

if (node->range().is_valid()) {
int bytecode_position = emitter()->position();
method_mapper_.register_call(bytecode_position, node->range());
method_mapper_.register_call(node, bytecode_position);
}

if (is_for_effect()) __ pop(1);
Expand Down Expand Up @@ -811,9 +811,7 @@ void ByteGen::visit_Typecheck(Typecheck* node) {
int height = local_height(target->as_Local()->index());
bytecode_position = __ typecheck_local(height, typecheck_index);
}
method_mapper_.register_as_check(bytecode_position,
node->range(),
node->type_name().c_str());
method_mapper_.register_as_check(node, bytecode_position);
return;
}

Expand All @@ -828,9 +826,7 @@ void ByteGen::visit_Typecheck(Typecheck* node) {

if (is_as_check) {
int bytecode_position = emitter()->position();
method_mapper_.register_as_check(bytecode_position,
node->range(),
node->type_name().c_str());
method_mapper_.register_as_check(node, bytecode_position);
}
if (is_for_effect()) __ pop(1);
}
Expand Down Expand Up @@ -974,7 +970,7 @@ void ByteGen::visit_ReferenceGlobal(ReferenceGlobal* node) {

__ load_global_var(node->target()->global_id(), is_lazy);
int bytecode_position = emitter()->position();
method_mapper_.register_call(bytecode_position, node->range());
method_mapper_.register_call(node, bytecode_position);

if (is_for_effect()) __ pop(1);
}
Expand Down
24 changes: 19 additions & 5 deletions src/compiler/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,7 @@ static void check_sdk(const std::string& constraint, Diagnostics* diagnostics) {

toit::Program* construct_program(ir::Program* ir_program,
SourceMapper* source_mapper,
TypeOracle* oracle,
TypeDatabase* propagated_types,
bool run_optimizations) {
source_mapper->register_selectors(ir_program->classes());
Expand All @@ -1515,9 +1516,21 @@ toit::Program* construct_program(ir::Program* ir_program,

ASSERT(_sorted_by_inheritance(ir_program->classes()));

if (run_optimizations) optimize(ir_program, propagated_types);
if (run_optimizations) optimize(ir_program, oracle);
tree_shake(ir_program);

// It is important that we seed and finalize the oracle in the same
// state, so the IR nodes used to produce the somewhat unoptimized
// program that we propagate types through can be matched up to the
// corresponding IR nodes for the fully optimized version.
if (propagated_types) {
oracle->finalize(ir_program, propagated_types);
optimize(ir_program, oracle);
tree_shake(ir_program);
} else {
oracle->seed(ir_program);
}

// We assign the field ids very late in case we can inline field-accesses.
assign_field_indexes(ir_program->classes());
// Similarly, assign the global ids at the end, in case they can be tree
Expand Down Expand Up @@ -1606,7 +1619,8 @@ Pipeline::Result Pipeline::run(List<const char*> source_paths, bool propagate) {

SourceMapper unoptimized_source_mapper(source_manager());
auto source_mapper = &unoptimized_source_mapper;
auto program = construct_program(ir_program, source_mapper, null, run_optimizations);
TypeOracle oracle(source_mapper);
auto program = construct_program(ir_program, source_mapper, &oracle, null, run_optimizations);

SourceMapper optimized_source_mapper(source_manager());
if (run_optimizations && configuration_.optimization_level >= 2) {
Expand All @@ -1621,14 +1635,14 @@ Pipeline::Result Pipeline::run(List<const char*> source_paths, bool propagate) {
// to behave the same way for the output to be correct.
check_types_and_deprecations(ir_program, quiet);
ASSERT(!diagnostics()->encountered_error());
TypeDatabase* types = TypeDatabase::compute(program, source_mapper);
TypeDatabase* types = TypeDatabase::compute(program);
source_mapper = &optimized_source_mapper;
program = construct_program(ir_program, source_mapper, types, true);
program = construct_program(ir_program, source_mapper, &oracle, types, true);
delete types;
}

if (propagate) {
TypeDatabase* types = TypeDatabase::compute(program, null);
TypeDatabase* types = TypeDatabase::compute(program);
auto json = types->as_json();
printf("%s", json.c_str());
delete types;
Expand Down
23 changes: 10 additions & 13 deletions src/compiler/optimizations/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {
}
};

DeadCodeEliminator(TypeDatabase* propagated_types)
: propagated_types_(propagated_types)
explicit DeadCodeEliminator(TypeOracle* oracle)
: oracle_(oracle)
, terminator_(null, Symbol::invalid()) {}

Expression* visit(Expression* node, bool* terminates) {
Expand Down Expand Up @@ -254,9 +254,7 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {

Node* visit_ReferenceGlobal(ReferenceGlobal* node) {
Global* global = node->target();
if (global->is_dead()) {
return is_for_effect() ? terminate(null) : terminate(_new Nop(node->range()));
}
if (global->is_dead()) return terminate(null);
return (global->is_lazy() || is_for_value()) ? node : null;
}

Expand Down Expand Up @@ -304,8 +302,7 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {
int used = 0;
while (used < length && !terminates) {
Expression* result = visit_for_value(arguments[used], &terminates);
arguments[used] = result;
used++;
if (result) arguments[used++] = result;
}

Expression* result = node;
Expand All @@ -321,11 +318,11 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {
}
result = _new Sequence(arguments.sublist(0, used), node->range());
ASSERT(terminates);
} else if (propagated_types_ != null && !node->is_CallBuiltin()) {
} else if (oracle_ != null && !node->is_CallBuiltin()) {
// If we have propagated type information, we might know that
// this call does not return. If so, we make sure to tag the
// result correctly, so we drop code that follows the call.
terminates = propagated_types_->does_not_return(node);
terminates = oracle_->does_not_return(node);
}
return tag(result, terminates);
}
Expand All @@ -350,7 +347,7 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {
if (terminates) break;
}
if (index == 0) {
return is_for_effect() ? terminate(null) : terminate(_new Nop(node->range()));
return terminate(null);
} else {
return terminate(_new Sequence(arguments.sublist(0, index), node->range()));
}
Expand Down Expand Up @@ -428,7 +425,7 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {
Node* visit_FieldStub(FieldStub* node) { return visit_Method(node); }

private:
TypeDatabase* propagated_types_;
TypeOracle* const oracle_;
bool is_for_value_ = false;

bool is_for_value() const { return is_for_value_; }
Expand All @@ -450,8 +447,8 @@ class DeadCodeEliminator : public ReturningVisitor<Node*> {
}
};

void eliminate_dead_code(Method* method, TypeDatabase* propagated_types) {
DeadCodeEliminator eliminator(propagated_types);
void eliminate_dead_code(Method* method, TypeOracle* oracle) {
DeadCodeEliminator eliminator(oracle);
Expression* body = method->body();
if (body == null) return;

Expand Down
2 changes: 1 addition & 1 deletion src/compiler/optimizations/dead_code.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace toit {
namespace compiler {

void eliminate_dead_code(ir::Method* method, TypeDatabase* propagated_types);
void eliminate_dead_code(ir::Method* method, TypeOracle* oracle);

} // namespace toit::compiler
} // namespace toit
Expand Down
35 changes: 14 additions & 21 deletions src/compiler/optimizations/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,34 +33,27 @@ using namespace ir;

class KillerVisitor : public TraversingVisitor {
public:
KillerVisitor(TypeDatabase* propagated_types)
: propagated_types_(propagated_types) {}
explicit KillerVisitor(TypeOracle* oracle)
: oracle_(oracle) {}

void visit_Method(Method* node) {
TraversingVisitor::visit_Method(node);
if (propagated_types_ && propagated_types_->is_dead(node)) {
node->kill();
}
if (oracle_->is_dead(node)) node->kill();
}

void visit_Code(Code* node) {
TraversingVisitor::visit_Code(node);
if (propagated_types_ && propagated_types_->is_dead(node)) {
node->kill();
}
if (oracle_->is_dead(node)) node->kill();
}

void visit_Global(Global* node) {
TraversingVisitor::visit_Method(node);
mark_if_eager(node);
if (!node->is_lazy()) return;
if (propagated_types_ && propagated_types_->is_dead(node)) {
node->kill();
}
if (node->is_lazy() && oracle_->is_dead(node)) node->kill();
}

private:
TypeDatabase* const propagated_types_;
TypeOracle* const oracle_;

void mark_if_eager(Global* global) {
// This runs after the constant propagation phase, so it is
Expand All @@ -83,10 +76,10 @@ class KillerVisitor : public TraversingVisitor {

class OptimizationVisitor : public ReplacingVisitor {
public:
OptimizationVisitor(TypeDatabase* propagated_types,
OptimizationVisitor(TypeOracle* oracle,
const UnorderedMap<Class*, QueryableClass> queryables,
const UnorderedSet<Symbol>& field_names)
: propagated_types_(propagated_types)
: oracle_(oracle)
, holder_(null)
, method_(null)
, queryables_(queryables)
Expand All @@ -95,9 +88,9 @@ class OptimizationVisitor : public ReplacingVisitor {
Node* visit_Method(Method* node) {
if (node->is_dead()) return node;
method_ = node;
eliminate_dead_code(node, null);
eliminate_dead_code(node, oracle_);
Node* result = ReplacingVisitor::visit_Method(node);
eliminate_dead_code(node, propagated_types_);
eliminate_dead_code(node, oracle_);
method_ = null;
return result;
}
Expand Down Expand Up @@ -134,19 +127,19 @@ class OptimizationVisitor : public ReplacingVisitor {
void set_class(Class* klass) { holder_ = klass; }

private:
TypeDatabase* const propagated_types_;
TypeOracle* const oracle_;

Class* holder_; // Null, if not in class (or a static method/field).
Method* method_;
UnorderedMap<Class*, QueryableClass> queryables_;
UnorderedSet<Symbol> field_names_;
};

void optimize(Program* program, TypeDatabase* propagated_types) {
void optimize(Program* program, TypeOracle* oracle) {
// The constant propagation runs independently, as it builds up its own
// dependency graph.
propagate_constants(program);
KillerVisitor killer(propagated_types);
KillerVisitor killer(oracle);
killer.visit(program);

auto classes = program->classes();
Expand Down Expand Up @@ -176,7 +169,7 @@ void optimize(Program* program, TypeDatabase* propagated_types) {
}
}

OptimizationVisitor visitor(propagated_types, queryables, field_names);
OptimizationVisitor visitor(oracle, queryables, field_names);

for (auto klass : classes) {
visitor.set_class(klass);
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/optimizations/optimizations.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace toit {
namespace compiler {

// Optimizes the program by combining all available sub-optimizations.
void optimize(ir::Program* program, TypeDatabase* propagated_types);
void optimize(ir::Program* program, TypeOracle* oracle);

} // namespace toit::compiler
} // namespace toit
Loading

0 comments on commit c273d25

Please sign in to comment.