[Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6 #25605

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -847,6 +847,7 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> {
];
let constructor = "::mlir::tpu::createInferVectorLayoutPass()";
let options = [
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];
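The new option is declared the same way as the existing lane/sublane options, so with the standard option plumbing that the pass TableGen backend generates, it should be settable from a textual pipeline spec. A hypothetical fragment (the mlir-opt-style syntax is assumed, not shown in this PR; the other two options carry their declared defaults):

    func.func(tpu-infer-vector-layout{hardware-generation=6 sublane-count=8 lane-count=128})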
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -79,6 +79,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
int hardware_generation = -1);

std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128});

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
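The factory now takes the generation first, defaulting to -1; as the guard added to infer_vector_layout.cc below shows, leaving it at -1 makes the pass fail, so C++ callers must thread a real value through. A minimal sketch of a caller, assuming a pipeline-building helper that already knows the target generation (the helper name and surrounding setup are illustrative, not from this PR):

    #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"

    // Illustrative helper: schedule layout inference for a known TPU generation.
    void addInferLayoutPass(mlir::PassManager &pm, int hardware_generation) {
      pm.addNestedPass<mlir::func::FuncOp>(mlir::tpu::createInferVectorLayoutPass(
          hardware_generation, /*target_shape=*/{8, 128}));
    }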
5 changes: 4 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -61,7 +61,10 @@ int getTilingFactor(const int num_lanes, const int hardware_generation,
if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
return sublane_count * 4;
}
if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
// 16-bit values can generally be relaid out on the fly on v6,
// so we allow large 2nd minor tiling whenever possible.
if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor ||
hardware_generation >= 6)) {
return sublane_count * 2;
}
return sublane_count;
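Concretely, with the default sublane count of 8, a 16-bit buffer on v6 now gets a tiling factor of 16 (sublane_count * 2) even when the use_x16_large_second_minor flag is off, i.e. a (16, 128) leading tile instead of (8, 128). A standalone restatement of just this branch of the decision (TpuTilingFlags is mocked with only the field this hunk touches):

    // Sketch of the 16-bit branch of getTilingFactor after this change.
    struct TpuTilingFlagsSketch {
      bool use_x16_large_second_minor = false;
    };

    int tilingFactorFor16Bit(int sublane_count, int hardware_generation,
                             const TpuTilingFlagsSketch &flags) {
      // On v6 the flag no longer gates the large 2nd minor tiling.
      if (flags.use_x16_large_second_minor || hardware_generation >= 6) {
        return sublane_count * 2;  // e.g. 8 sublanes -> (16, 128) tiles
      }
      return sublane_count;        // e.g. (8, 128) tiles
    }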
33 changes: 24 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -108,8 +108,10 @@ LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim,
// have corresponding native instructions.
class VectorLayoutInferer {
public:
explicit VectorLayoutInferer(std::array<int64_t, 2> target_shape)
: target_shape_({target_shape[0], target_shape[1]}),
explicit VectorLayoutInferer(int hardware_generation,
std::array<int64_t, 2> target_shape)
: hardware_generation_(hardware_generation),
target_shape_({target_shape[0], target_shape[1]}),
default_tiling_(target_shape) {}

#define TPU_CHECK_OP(cond, msg) \
@@ -1709,7 +1711,12 @@
"Only 32-bit truncation supported");
}
auto &layout = *some_layout;
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
// TPUv6 has good support for compute in 16-bit and cheap retiling between
// large 2nd minor and the default tiling, so we bias towards large tiles.
bool select_native =
(hardware_generation_ >= 6 && dst_ty.getElementTypeBitWidth() == 16)
? true
: allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
auto dst_layout = VectorLayout(
@@ -2064,29 +2071,36 @@
default_tiling_[1]};
}

int hardware_generation_;
std::array<int64_t, 2> target_shape_;
std::array<int64_t, 2> default_tiling_;

// TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully
// remove the restriction that offsets must fall within the first tile.
bool force_first_tile_offsets_ = false;

// Address alignment requirement, counted in 32-bit increments.
static constexpr int64_t kVmemAlignment32 = 128;
// TODO(apaszke): This is not really native on newer generations of TPUs.
// Get rid of this temporary stopgap.
static constexpr int8_t kNativeBitwidth = 32;
};

struct InferVectorLayoutPass
: public impl::InferVectorLayoutPassBase<InferVectorLayoutPass> {
InferVectorLayoutPass(std::array<int64_t, 2> target_shape) {
InferVectorLayoutPass(int hardware_generation,
std::array<int64_t, 2> target_shape) {
this->sublane_count = target_shape[0];
this->lane_count = target_shape[1];
this->hardware_generation = hardware_generation;
}
void runOnOperation() override {
// Fail if hardware_generation has not been set from the default value.
if (hardware_generation < 0) {
getOperation().emitOpError("hardware_generation must be set");
signalPassFailure();
return;
}
func::FuncOp func = getOperation();
VectorLayoutInferer run({sublane_count, lane_count});
VectorLayoutInferer run(hardware_generation, {sublane_count, lane_count});
if (run.infer(func).failed()) {
signalPassFailure();
}
@@ -2096,8 +2110,9 @@ struct InferVectorLayoutPass
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
std::array<int64_t, 2> target_shape) {
return std::make_unique<InferVectorLayoutPass>(target_shape);
int hardware_generation, std::array<int64_t, 2> target_shape) {
return std::make_unique<InferVectorLayoutPass>(hardware_generation,
target_shape);
}

} // namespace mlir::tpu
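Two things are worth pulling out of this file's hunks. First, the truncation hunk replaces the purely use-driven tiling choice with a v6 bias: 16-bit destinations on v6 always get native tiling, and everything else still defers to allUsersRequireNativeTiling. A minimal restatement of that decision (standalone sketch; the boolean parameter stands in for the member-function call of the same name, which is assumed here, not defined):

    // Sketch of the select_native choice from the truncation hunk.
    bool selectNativeTiling(int hardware_generation, int dst_bitwidth,
                            bool all_users_require_native_tiling) {
      // v6 handles 16-bit compute well and retiles cheaply between large
      // 2nd minor and default tiling, so bias towards native (large) tiles.
      if (hardware_generation >= 6 && dst_bitwidth == 16) {
        return true;
      }
      return all_users_require_native_tiling;
    }

Second, the new runOnOperation guard makes the dependency explicit: since the default of -1 would silently disable every hardware_generation >= 6 check, the pass now fails fast unless the caller sets a real generation.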