Skip to content

Commit

Permalink
[XLA:GPU] Support complex numbers in materialize & insert op lowering
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675123278
  • Loading branch information
vwbaker authored and Google-ML-Automation committed Sep 16, 2024
1 parent 58cce1a commit d8a0f51
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 11 deletions.
54 changes: 43 additions & 11 deletions xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,16 @@ struct RewriteXlaGpuLoop : mlir::OpRewritePattern<LoopOp> {
};

mlir::VectorType getThreadLevelVectorType(IndexedVectorType indexed_vector) {
auto data_type = indexed_vector.getElementType();
SmallVector<int64_t> vector_dims;
if (auto complex = mlir::dyn_cast<mlir::ComplexType>(data_type)) {
vector_dims.push_back(2);
data_type = complex.getElementType();
}
IndexingMap map = indexed_vector.getIndexingMapAttr().getIndexingMap();
for (auto bound : map.GetSymbolBounds()) {
vector_dims.push_back(bound.GetLoopTripCount());
}
auto data_type = indexed_vector.getElementType();
return mlir::VectorType::get(vector_dims, data_type);
}

Expand All @@ -260,9 +264,12 @@ struct RewriteMaterialize : mlir::OpRewritePattern<MaterializeOp> {
mlir::LogicalResult matchAndRewrite(
MaterializeOp op, mlir::PatternRewriter& rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto i0 = b.create<mlir::arith::ConstantIndexOp>(0);
auto i1 = b.create<mlir::arith::ConstantIndexOp>(1);

auto data_type = op.getResult().getType().getElementType();
auto vec_type = getThreadLevelVectorType(op.getResult().getType());
auto maybe_complex_data_type = op.getResult().getType().getElementType();
auto data_type = vec_type.getElementType();
Value init_vec;
if (mlir::isa<mlir::IntegerType>(data_type)) {
init_vec = b.create<mlir::arith::ConstantOp>(mlir::DenseElementsAttr::get(
Expand All @@ -280,13 +287,24 @@ struct RewriteMaterialize : mlir::OpRewritePattern<MaterializeOp> {
ValueRange iter_args) {
auto args = SmallVector<Value, 4>(op.getInput());
args.insert(args.end(), map_results.begin(), map_results.end());
SmallVector<mlir::Type, 1> types{data_type};
auto call =
b.create<PureCallOp>(op.getCalleeAttr(), ValueRange{args}, types);
SmallVector<mlir::Type, 1> types{maybe_complex_data_type};
auto call_result =
b.create<PureCallOp>(op.getCalleeAttr(), ValueRange{args}, types)
.getResult(0);
SmallVector<mlir::OpFoldResult> offset(ivs);
auto old_vec = iter_args.back();
Value new_vec = b.create<mlir::vector::InsertOp>(call.getResult(0),
old_vec, offset);
Value new_vec;
if (mlir::isa<mlir::ComplexType>(call_result.getType())) {
auto real = b.create<mlir::complex::ReOp>(call_result);
auto imag = b.create<mlir::complex::ImOp>(call_result);
offset.insert(offset.begin(), i0.getResult());
new_vec = b.create<mlir::vector::InsertOp>(real, old_vec, offset);
offset.front() = i1.getResult();
new_vec = b.create<mlir::vector::InsertOp>(imag, new_vec, offset);
} else {
new_vec =
b.create<mlir::vector::InsertOp>(call_result, old_vec, offset);
}
b.create<YieldOp>(new_vec);
});
auto convert = b.create<mlir::UnrealizedConversionCastOp>(
Expand All @@ -303,6 +321,8 @@ struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {
mlir::LogicalResult matchAndRewrite(
InsertOp op, mlir::PatternRewriter& rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto i0 = b.create<mlir::arith::ConstantIndexOp>(0);
auto i1 = b.create<mlir::arith::ConstantIndexOp>(1);
auto convert =
b.create<mlir::UnrealizedConversionCastOp>(
getThreadLevelVectorType(op.getSource().getType()), op.getSource())
Expand All @@ -318,13 +338,25 @@ struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {
[&](OpBuilder&, Location, ValueRange ivs, ValueRange map_results,
ValueRange iter_args) {
SmallVector<mlir::OpFoldResult> vector_offset(ivs);
auto scalar =
b.create<mlir::vector::ExtractOp>(convert, vector_offset);
Value scalar;
if (auto complex = mlir::dyn_cast<mlir::ComplexType>(
op.getSource().getType().getElementType())) {
vector_offset.insert(vector_offset.begin(), i0.getResult());
auto real =
b.create<mlir::vector::ExtractOp>(convert, vector_offset);
vector_offset.front() = i1.getResult();
auto imag =
b.create<mlir::vector::ExtractOp>(convert, vector_offset);
scalar = b.create<mlir::complex::CreateOp>(complex, real, imag)
.getResult();
} else {
scalar = b.create<mlir::vector::ExtractOp>(convert, vector_offset)
.getResult();
}
auto tensor_indices = b.create<ApplyIndexingOp>(
map_results, ValueRange(), op.getMap().getIndexingMap());
Value new_tensor = b.create<mlir::tensor::InsertOp>(
scalar.getResult(), iter_args.back(),
tensor_indices.getResults());
scalar, iter_args.back(), tensor_indices.getResults());
b.create<YieldOp>(new_tensor);
});
rewriter.replaceOp(op, loop->getResults());
Expand Down
67 changes: 67 additions & 0 deletions xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,70 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index,
func.return %1 : tensor<32x64xf32>
}
// CHECK-NOT: unrealized_conversion_cast

// -----

func.func private @exp(%p0: tensor<32x64xcomplex<f32>>, %i: index, %j: index) -> complex<f32>

#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 8],
s0 in [0, 2], s1 in [0, 3], is_simplified: false>
#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 2],
s0 in [0, 2], s1 in [0, 3], is_simplified: false>
func.func @materialize_complex(
%input: tensor<32x64xcomplex<f32>>,
%output: tensor<32x64xcomplex<f32>>,
%d0: index,
%d1: index) -> !xla_gpu.indexed_vector<32x3x4xcomplex<f32>, #map1> {

%0 = xla_gpu.materialize @exp(%input) at #map(%d0, %d1)
: (tensor<32x64xcomplex<f32>>)
-> !xla_gpu.indexed_vector<32x3x4xcomplex<f32>, #map1>
func.return %0 : !xla_gpu.indexed_vector<32x3x4xcomplex<f32>, #map1>
}

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: xla_gpu.loop ({{.*}})[%[[I:.*]], %[[J:.*]]]
// CHECK-SAME: iter_args(%[[ITER:.*]] = {{.*}})
// CHECK: %[[PURE_CALL:.*]] = xla_gpu.pure_call
// CHECK-SAME: complex<f32>
// CHECK: %[[REAL:.*]] = complex.re %[[PURE_CALL]]
// CHECK: %[[IMAG:.*]] = complex.im %[[PURE_CALL]]
// CHECK: %[[TEMP:.*]] = vector.insert %[[REAL]], %[[ITER]] [%[[C0]], %[[I]], %[[J]]]
// CHECK: %[[FINAL:.*]] = vector.insert %[[IMAG]], %[[TEMP]] [%[[C1]], %[[I]], %[[J]]]
// CHECK: xla_gpu.yield %[[FINAL]] : vector<2x3x4xf32>

// -----

#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1),
domain: d0 in [0, 32], d1 in [0, 2],
s0 in [0, 2], s1 in [0, 3], is_simplified: false>
#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1),
domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false>
func.func @insert_complex(
%input: !xla_gpu.indexed_vector<32x3x4xcomplex<f32>, #map1>,
%output: tensor<32x64xcomplex<f32>>,
%d0: index,
%d1: index) -> tensor<32x64xcomplex<f32>> {

%1 = xla_gpu.insert %input(%d0, %d1) into %output at #map2
: !xla_gpu.indexed_vector<32x3x4xcomplex<f32>, #map1>
-> tensor<32x64xcomplex<f32>>
func.return %1 : tensor<32x64xcomplex<f32>>
}

// CHECK-LABEL: @insert_complex
// CHECK-SAME: %[[INPUT:.*]]: !xla_gpu.indexed_vector<32x3x4xcomplex<f32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VECTOR:.*]] = builtin.unrealized_conversion_cast %[[INPUT]]
// CHECK-SAME: to vector<2x3x4xf32>
// CHECK: xla_gpu.loop ({{.*}})[%[[I:.*]], %[[J:.*]]]
// CHECK-SAME: iter_args(%[[ITER:.*]] = {{.*}})
// CHECK: %[[REAL:.*]] = vector.extract %[[VECTOR]][%[[C0]], %[[I]], %[[J]]]
// CHECK: %[[IMAG:.*]] = vector.extract %[[VECTOR]][%[[C1]], %[[I]], %[[J]]]
// CHECK: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]]
// CHECK: %[[INSERTED:.*]] = tensor.insert %[[COMPLEX]] into %[[ITER]]
// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex<f32>>

0 comments on commit d8a0f51

Please sign in to comment.