Skip to content

Commit

Permalink
[XLA:GPU] Fix InsertOp's lowering when applying indices
Browse files Browse the repository at this point in the history
InsertOp is supposed to take the map results of the vector as inputs into its map and have no symbols as it doesn't create a loop for this.  Add a verifier to check that InsertOp is lowerable given these constraints.

PiperOrigin-RevId: 675113893
  • Loading branch information
vwbaker authored and Google-ML-Automation committed Sep 16, 2024
1 parent 4630020 commit ffe3055
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 14 deletions.
34 changes: 34 additions & 0 deletions xla/service/gpu/fusions/ir/tests/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,37 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>,
: (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
}

// -----

#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1],
is_simplified: false>
#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 mod 16 + s0, d1),
domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1],
is_simplified: false>

func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>,
%i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> {
// expected-error @+1 {{insert_op map must not have any symbols}}
%0 = xla_gpu.insert %input(%i, %j) into %output at #map1
: !xla_gpu.indexed_vector<32x64xf32, #map> -> tensor<32x64xf32>
func.return %0 : tensor<32x64xf32>
}

// -----

#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1],
is_simplified: false>
#map1 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 mod 16, d1, d2),
domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5],
is_simplified: false>

func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>,
%i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> {
// expected-error @+1 {{source map result count must equal insert_op's map's dimension count}}
%0 = xla_gpu.insert %input(%i, %j) into %output at #map1
: !xla_gpu.indexed_vector<32x64xf32, #map> -> tensor<32x64xf32>
func.return %0 : tensor<32x64xf32>
}
13 changes: 10 additions & 3 deletions xla/service/gpu/fusions/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,15 @@ func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32
#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1),
domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32],
is_simplified: false>
#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1),
domain: d0 in [0, 32], d1 in [0, 2],
is_simplified: false>

func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index,
%j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> {
%0 = xla_gpu.materialize @exp(%input) at #map(%i, %j)
: (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
%1 = xla_gpu.insert %0(%i, %j) into %output at #map1
%1 = xla_gpu.insert %0(%i, %j) into %output at #map2
: !xla_gpu.indexed_vector<32x64xf32, #map1> -> tensor<32x64xf32>
func.return %1 : tensor<32x64xf32>
}
Expand All @@ -175,6 +178,10 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index,
// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]
// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1)
// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]
// CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1)
// CHECK-SAME: d0 in [0, 32], d1 in [0, 2],
// CHECK-LABEL: @materialize_and_insert
// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at #[[$MAP]](%{{.*}}, %{{.*}})
// CHECK: xla_gpu.insert %[[MATERIALIZED]](%{{.*}}, %{{.*}}) into %{{.*}} at #[[$MAP1]]
// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at
// CHECK-SAME: #[[$MAP]](%{{.*}}, %{{.*}})
// CHECK: xla_gpu.insert %[[MATERIALIZED]](%{{.*}}, %{{.*}}) into
// CHECK-SAME: at #[[$MAP2]] : <32x64xf32, #[[$MAP1]]>
17 changes: 17 additions & 0 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,23 @@ LogicalResult MaterializeOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

LogicalResult InsertOp::verify() {
if (!getMap().getRangeVars().empty()) {
return emitOpError() << "insert_op map must not have any symbols";
}
int64_t vector_map_num_results =
getSource().getType().getIndexingMapAttr().getNumResults();
if (vector_map_num_results != getMap().getDimVars().size()) {
return emitOpError() << "source map result count must equal insert_op's "
"map's dimension count";
}
return success();
}

} // namespace gpu
} // namespace xla

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def XLAGPU_InsertOp : XLAGPU_Op<"insert", [TypesMatchWith<
AnyRankedTensor:$dest,
XLAGPU_IndexingMapAttr:$map);
let results = (outs AnyRankedTensor:$result);
let hasVerifier = 1;
let assemblyFormat = [{
$source `(` $indices `)` `into` $dest `at` $map attr-dict `:` type($source) `->` type($result)
}];
Expand Down
6 changes: 3 additions & 3 deletions xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {
// indexed_vector index -> tensor index.
// We get indexed_vector index by using its encoding map (source_map).
// So we loop over indexed_vector encoding map and use the results as the
// symbols for InsertOp's map in order to get the final tensor index.
// dimensions for InsertOp's map in order to get the final tensor index.
auto source_map = op.getSource().getType().getIndexingMapAttr();
auto loop = b.create<LoopOp>(
source_map, op.getIndices(), ValueRange{op.getDest()},
Expand All @@ -321,10 +321,10 @@ struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {
auto scalar =
b.create<mlir::vector::ExtractOp>(convert, vector_offset);
auto tensor_indices = b.create<ApplyIndexingOp>(
op.getIndices(), map_results, op.getMap().getIndexingMap());
map_results, ValueRange(), op.getMap().getIndexingMap());
Value new_tensor = b.create<mlir::tensor::InsertOp>(
scalar.getResult(), iter_args.back(),
tensor_indices->getResults());
tensor_indices.getResults());
b.create<YieldOp>(new_tensor);
});
rewriter.replaceOp(op, loop->getResults());
Expand Down
19 changes: 11 additions & 8 deletions xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index)
#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1],
is_simplified: false>
#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1],
#map1 = #xla_gpu.indexing_map<(d0, d1) -> (d0 mod 16, d1),
domain: d0 in [0, 32], d1 in [0, 2],
is_simplified: false>

func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>,
Expand All @@ -167,7 +167,7 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>,
func.return %0 : tensor<32x64xf32>
}
// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1)
// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1)
// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 mod 16, d1)

// CHECK: @insert(%[[INPUT:.*]]: !xla_gpu.indexed_vector<32x64xf32, #[[$MAP]]>,
// CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index,
Expand All @@ -177,9 +177,12 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>,
// CHECK-SAME: (%[[MAP_RESULT1:.*]], %[[MAP_RESULT2:.*]]) in #[[$MAP]]
// CHECK-SAME: iter_args(%[[TENSOR:.*]] = %[[OUTPUT]])

// CHECK: %[[SCALAR:.*]] = vector.extract %{{.*}}[%[[S0]], %[[S1]]] : f32 from vector<2x2xf32>
// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing #[[$MAP1]](%[[I]], %[[J]])[%[[MAP_RESULT1]], %[[MAP_RESULT2]]]
// CHECK: %[[NEW_TENSOR:.*]] = tensor.insert %[[SCALAR]] into %[[TENSOR]][%[[MAP1_RESULT]]#0, %[[MAP1_RESULT]]#1]
// CHECK: %[[SCALAR:.*]] = vector.extract %{{.*}}[%[[S0]], %[[S1]]]
// CHECK-SAME: : f32 from vector<2x2xf32>
// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing
// CHECK-SAME: #[[$MAP1]](%[[MAP_RESULT1]], %[[MAP_RESULT2]])
// CHECK: %[[NEW_TENSOR:.*]] = tensor.insert %[[SCALAR]]
// CHECK-SAME: into %[[TENSOR]][%[[MAP1_RESULT]]#0, %[[MAP1_RESULT]]#1]
// CHECK: xla_gpu.yield %[[NEW_TENSOR]]

// -----
Expand All @@ -192,8 +195,8 @@ func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32
#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1],
is_simplified: false>
#map2 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1),
domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1],
#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1),
domain: d0 in [0, 32], d1 in [0, 2],
is_simplified: false>

func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index,
Expand Down

0 comments on commit ffe3055

Please sign in to comment.