Skip to content

Commit

Permalink
Constant fusing support for BF16, F32, and S32.
Browse files Browse the repository at this point in the history
  • Loading branch information
Elfie Guo committed Apr 9, 2024
1 parent 4fe4157 commit 634d9a4
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 6 deletions.
1 change: 1 addition & 0 deletions cudnn_frontend
Submodule cudnn_frontend added at 4a9846
16 changes: 11 additions & 5 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ def _tf_repositories():
# curl -L <url> | sha256sum
# and update the sha256 with the result.

tf_http_archive(
# tf_http_archive(
# name = "cudnn_frontend_archive",
# build_file = "//third_party:cudnn_frontend.BUILD",
# patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
# sha256 = "1bb309af98fe9aad81b6a14fd52acbd6566aacfd322fc5803f9a1b77fc681a27",
# strip_prefix = "cudnn-frontend-1.2.1",
# urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.2.1.zip"),
# )

native.new_local_repository(
name = "cudnn_frontend_archive",
build_file = "//third_party:cudnn_frontend.BUILD",
patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
sha256 = "1bb309af98fe9aad81b6a14fd52acbd6566aacfd322fc5803f9a1b77fc681a27",
strip_prefix = "cudnn-frontend-1.2.1",
urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.2.1.zip"),
path = "./cudnn_frontend",
)

tf_http_archive(
Expand Down
60 changes: 59 additions & 1 deletion xla/service/gpu/cudnn_fusion_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>
#include <vector>
#include <cuda_fp16.h>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
Expand Down Expand Up @@ -173,7 +174,7 @@ int FusionLevel(const HloInstruction& hlo) {
class GemmDimensionAdapter {
explicit GemmDimensionAdapter(const HloDotInstruction& dot,
TritonFusionAnalysis analysis)
: analysis_(std::move(analysis)), dot_(dot) {};
: analysis_(std::move(analysis)), dot_(dot){};

public:
const TritonFusionAnalysis analysis_;
Expand Down Expand Up @@ -311,6 +312,40 @@ class GemmDimensionAdapter {
const HloDotInstruction& dot_;
};

// Lowers a scalar constant HLO into a cuDNN graph tensor carrying its value.
// Returns std::nullopt when the constant is non-scalar or its element type is
// not one of the supported types (BF16, F32, S32).
std::optional<std::shared_ptr<graph::Tensor_attributes>>
HandleConstantHloToCudnnGraph(const HloInstruction* hlo, graph::Graph& graph) {
  CHECK(hlo->IsConstant()) << "HLO is not a constant: " << hlo->ToShortString();
  if (!ShapeUtil::IsScalar(hlo->shape())) {
    VLOG(3) << "Currently only support fusing scalar in the graph";
    return std::nullopt;
  }
  const PrimitiveType constant_type = hlo->shape().element_type();
  if (constant_type == BF16) {
    // cuDNN FE has no BF16 scalar overload taking the XLA native type
    // directly, so convert through the CUDA __nv_bfloat16 wrapper.
    return graph.tensor(__nv_bfloat16(
        hlo->literal()
            .data<primitive_util::PrimitiveTypeToNative<BF16>::type>()[0]));
  }
  if (constant_type == F32) {
    return graph.tensor(
        hlo->literal()
            .data<primitive_util::PrimitiveTypeToNative<F32>::type>()[0]);
  }
  if (constant_type == S32) {
    return graph.tensor(
        hlo->literal()
            .data<primitive_util::PrimitiveTypeToNative<S32>::type>()[0]);
  }
  // F16 is intentionally unsupported until the cuDNN F16 numerical issue is
  // resolved: https://nvbugspro.nvidia.com/bug/4508897
  VLOG(3) << "Unsupported constant type: " << PrimitiveType_Name(constant_type);
  return std::nullopt;
}

// Traverses fusion computations and creates cuDNN graphs out of them.
absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
const HloFusionInstruction& fusion) {
Expand Down Expand Up @@ -371,6 +406,14 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
if (hlo->opcode() == HloOpcode::kParameter) {
CHECK(hlo_to_cudnn.contains(hlo));
continue;
} else if (FusionLevel(fusion) >= 2 &&
hlo->opcode() == HloOpcode::kConstant) {
if (const auto const_tensor = HandleConstantHloToCudnnGraph(hlo, graph);
const_tensor.has_value()) {
hlo_to_cudnn[hlo] = const_tensor.value();
} else {
return std::nullopt;
}
} else if (hlo->opcode() == HloOpcode::kReshape ||
hlo->opcode() == HloOpcode::kBitcast ||
hlo->opcode() == HloOpcode::kTranspose ||
Expand All @@ -395,6 +438,21 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
.set_compute_data_type(compute_dtype.value());
if (hlo->operand_count() == 1) {
hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
// Sets the dimensions for IDENTITY ops for cuDNN FE to infer its
// inputs' shapes.
if (mode.value() == fe::PointwiseMode_t::IDENTITY) {
const auto scope = adapter->analysis_.QueryInstructionScope(hlo);
std::vector<int64_t> dimensions;
std::vector<int64_t> strides;
if (!scope.has_value() ||
!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
strides)) {
VLOG(3) << "Unsupported hlo for querying dimensions: "
<< hlo->ToShortString();
} else {
hlo_to_cudnn[hlo]->set_dim(dimensions);
}
}
} else if (hlo->operand_count() == 2) {
hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
} else if (hlo->operand_count() == 3) {
Expand Down
26 changes: 26 additions & 0 deletions xla/service/gpu/fusions/cudnn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,32 @@ ENTRY e {
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

// Checks that scalar constants fused into a cuDNN fusion execute with correct
// numerics: a BF16 constant used directly (via broadcast) and an S32 constant
// routed through a convert, both feeding elementwise ops ahead of the dot.
// NOTE(review): the names x_add/y_add label `minimum` ops — presumably leftover
// from an earlier version of the HLO snippet.
TEST_F(CuDnnFusionLevel2Test, FuseConstantExecutesCorrectly) {
EXPECT_TRUE(RunAndCompare(R"(
fusion1 {
x = bf16[16,32] parameter(0)
y = bf16[32,16] parameter(1)
x_const = bf16[] constant(-1)
y_const = s32[] constant(-2)
x_const_bcast = bf16[16,32] broadcast(x_const), dimensions={}
y_const_bcast = s32[32,16] broadcast(y_const), dimensions={}
y_const_convert = bf16[32,16] convert(y_const_bcast)
x_add = bf16[16,32] minimum(x, x_const_bcast)
y_add = bf16[32,16] minimum(y, y_const_convert)
dot_a = f32[16,16] dot(x_add, y_add), lhs_contracting_dims={1}, rhs_contracting_dims={0}
c = f32[] constant(0)
c_bcast = f32[16,16] broadcast(c), dimensions={}
ROOT out = f32[16,16] maximum(dot_a, c_bcast)
}
ENTRY e {
p0 = bf16[16,32] parameter(0)
p1 = bf16[32,16] parameter(1)
ROOT _ = f32[16,16] fusion(p0, p1), kind=kCustom, calls=fusion1,
backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
})",
ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-1}));
}

class CuDnnFusionLevel3Test : public CuDnnFusionExecutionTest {
public:
DebugOptions GetDebugOptionsForTest() override {
Expand Down
11 changes: 11 additions & 0 deletions xla/service/gpu/triton_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,17 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion(
return absl::OkStatus();
}

std::optional<TritonFusionAnalysis::Scope> TritonFusionAnalysis::QueryInstructionScope(
const HloInstruction* hlo) const {
for (const Scope scope : {Scope::LHS, Scope::RHS, Scope::OUTPUT}) {
if (iter_specs_.at(scope).count(hlo) > 0) {
return scope;
}
}
LOG(WARNING) << "No scope for hlo: " << hlo->ToString();
return std::nullopt;
}

const TensorIterationSpec::DimIterationSpec* TritonFusionAnalysis::IterSpec(
const TritonFusionAnalysis::Scope scope, const HloInstruction* hlo,
const int dimension) const {
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/triton_fusion_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class TritonFusionAnalysis {
return parameters_.at(scope);
}

// Returns the given instruction's scope. Returns nullopt if no scope is
// found.
std::optional<Scope> QueryInstructionScope(const HloInstruction* hlo) const;

std::string ToString() const;

private:
Expand Down

0 comments on commit 634d9a4

Please sign in to comment.