diff --git a/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc index cd428a5bfacc5..f0197709d8a83 100644 --- a/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc +++ b/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc @@ -245,12 +245,16 @@ struct RewriteXlaGpuLoop : mlir::OpRewritePattern { }; mlir::VectorType getThreadLevelVectorType(IndexedVectorType indexed_vector) { + auto data_type = indexed_vector.getElementType(); SmallVector vector_dims; + if (auto complex = mlir::dyn_cast(data_type)) { + vector_dims.push_back(2); + data_type = complex.getElementType(); + } IndexingMap map = indexed_vector.getIndexingMapAttr().getIndexingMap(); for (auto bound : map.GetSymbolBounds()) { vector_dims.push_back(bound.GetLoopTripCount()); } - auto data_type = indexed_vector.getElementType(); return mlir::VectorType::get(vector_dims, data_type); } @@ -260,9 +264,12 @@ struct RewriteMaterialize : mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( MaterializeOp op, mlir::PatternRewriter& rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto i0 = b.create(0); + auto i1 = b.create(1); - auto data_type = op.getResult().getType().getElementType(); auto vec_type = getThreadLevelVectorType(op.getResult().getType()); + auto maybe_complex_data_type = op.getResult().getType().getElementType(); + auto data_type = vec_type.getElementType(); Value init_vec; if (mlir::isa(data_type)) { init_vec = b.create(mlir::DenseElementsAttr::get( @@ -280,13 +287,24 @@ struct RewriteMaterialize : mlir::OpRewritePattern { ValueRange iter_args) { auto args = SmallVector(op.getInput()); args.insert(args.end(), map_results.begin(), map_results.end()); - SmallVector types{data_type}; - auto call = - b.create(op.getCalleeAttr(), ValueRange{args}, types); + SmallVector types{maybe_complex_data_type}; + auto call_result = + b.create(op.getCalleeAttr(), ValueRange{args}, types) + .getResult(0); SmallVector offset(ivs); auto old_vec = iter_args.back(); - Value new_vec = b.create(call.getResult(0), - old_vec, offset); + Value new_vec; + if (mlir::isa(call_result.getType())) { + auto real = b.create(call_result); + auto imag = b.create(call_result); + offset.insert(offset.begin(), i0.getResult()); + new_vec = b.create(real, old_vec, offset); + offset.front() = i1.getResult(); + new_vec = b.create(imag, new_vec, offset); + } else { + new_vec = + b.create(call_result, old_vec, offset); + } b.create(new_vec); }); auto convert = b.create( @@ -303,6 +321,8 @@ struct RewriteInsert : mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( InsertOp op, mlir::PatternRewriter& rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto i0 = b.create(0); + auto i1 = b.create(1); auto convert = b.create( getThreadLevelVectorType(op.getSource().getType()), op.getSource()) @@ -318,13 +338,25 @@ struct RewriteInsert : mlir::OpRewritePattern { [&](OpBuilder&, Location, ValueRange ivs, ValueRange map_results, ValueRange iter_args) { SmallVector vector_offset(ivs); - auto scalar = - b.create(convert, vector_offset); + Value scalar; + if (auto complex = mlir::dyn_cast( + op.getSource().getType().getElementType())) { + vector_offset.insert(vector_offset.begin(), i0.getResult()); + auto real = + b.create(convert, vector_offset); + vector_offset.front() = i1.getResult(); + auto imag = + b.create(convert, vector_offset); + scalar = b.create(complex, real, imag) + .getResult(); + } else { + scalar = b.create(convert, vector_offset) + .getResult(); + } auto tensor_indices = b.create( map_results, ValueRange(), op.getMap().getIndexingMap()); Value new_tensor = b.create( - scalar.getResult(), iter_args.back(), - tensor_indices.getResults()); + scalar, iter_args.back(), tensor_indices.getResults()); b.create(new_tensor); }); rewriter.replaceOp(op, loop->getResults()); diff --git a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index 55b704e85d8ef..8cc6b652efecc 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -208,3 +208,70 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.return %1 : tensor<32x64xf32> } // CHECK-NOT: unrealized_conversion_cast + +// ----- + +func.func private @exp(%p0: tensor<32x64xcomplex>, %i: index, %j: index) -> complex + +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), + domain: d0 in [0, 32], d1 in [0, 8], + s0 in [0, 2], s1 in [0, 3], is_simplified: false> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), + domain: d0 in [0, 32], d1 in [0, 2], + s0 in [0, 2], s1 in [0, 3], is_simplified: false> +func.func @materialize_complex( + %input: tensor<32x64xcomplex>, + %output: tensor<32x64xcomplex>, + %d0: index, + %d1: index) -> !xla_gpu.indexed_vector<32x3x4xcomplex, #map1> { + + %0 = xla_gpu.materialize @exp(%input) at #map(%d0, %d1) + : (tensor<32x64xcomplex>) + -> !xla_gpu.indexed_vector<32x3x4xcomplex, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x3x4xcomplex, #map1> +} + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: xla_gpu.loop ({{.*}})[%[[I:.*]], %[[J:.*]]] +// CHECK-SAME: iter_args(%[[ITER:.*]] = {{.*}}) +// CHECK: %[[PURE_CALL:.*]] = xla_gpu.pure_call +// CHECK-SAME: complex +// CHECK: %[[REAL:.*]] = complex.re %[[PURE_CALL]] +// CHECK: %[[IMAG:.*]] = complex.im %[[PURE_CALL]] +// CHECK: %[[TEMP:.*]] = vector.insert %[[REAL]], %[[ITER]] [%[[C0]], %[[I]], %[[J]]] +// CHECK: %[[FINAL:.*]] = vector.insert %[[IMAG]], %[[TEMP]] [%[[C1]], %[[I]], %[[J]]] +// CHECK: xla_gpu.yield %[[FINAL]] : vector<2x3x4xf32> + +// ----- + +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), + domain: d0 in [0, 32], d1 in [0, 2], + s0 in [0, 2], s1 in [0, 3], is_simplified: false> +#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), + domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false> +func.func @insert_complex( + %input: !xla_gpu.indexed_vector<32x3x4xcomplex, #map1>, + %output: tensor<32x64xcomplex>, + %d0: index, + %d1: index) -> tensor<32x64xcomplex> { + + %1 = xla_gpu.insert %input(%d0, %d1) into %output at #map2 + : !xla_gpu.indexed_vector<32x3x4xcomplex, #map1> + -> tensor<32x64xcomplex> + func.return %1 : tensor<32x64xcomplex> +} + +// CHECK-LABEL: @insert_complex +// CHECK-SAME: %[[INPUT:.*]]: !xla_gpu.indexed_vector<32x3x4xcomplex +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[VECTOR:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] +// CHECK-SAME: to vector<2x3x4xf32> +// CHECK: xla_gpu.loop ({{.*}})[%[[I:.*]], %[[J:.*]]] +// CHECK-SAME: iter_args(%[[ITER:.*]] = {{.*}}) +// CHECK: %[[REAL:.*]] = vector.extract %[[VECTOR]][%[[C0]], %[[I]], %[[J]]] +// CHECK: %[[IMAG:.*]] = vector.extract %[[VECTOR]][%[[C1]], %[[I]], %[[J]]] +// CHECK: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]] +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[COMPLEX]] into %[[ITER]] +// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex> \ No newline at end of file