
feat: simplify operations with sign #336

Draft: wants to merge 4 commits into main
Conversation

@avik-pal (Collaborator) commented Feb 11, 2025

Trying to simplify the following module:

module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<8x4xf32>, %arg1: tensor<16xf32>) -> tensor<8x4xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<32x4x8xf32>
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [2, 1] : (tensor<8x4xf32>) -> tensor<16x4x8xf32>
    %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<16xf32>) -> tensor<16x4x8xf32>
    %2 = stablehlo.multiply %0, %1 : tensor<16x4x8xf32>
    %3 = stablehlo.sine %2 : tensor<16x4x8xf32>
    %4 = stablehlo.cosine %2 : tensor<16x4x8xf32>
    %5 = stablehlo.concatenate %3, %4, dim = 0 : (tensor<16x4x8xf32>, tensor<16x4x8xf32>) -> tensor<32x4x8xf32>
    %6 = stablehlo.abs %5 : tensor<32x4x8xf32>
    %7 = stablehlo.add %6, %6 : tensor<32x4x8xf32>
    %8 = stablehlo.compare  GE, %5, %cst_0 : (tensor<32x4x8xf32>, tensor<32x4x8xf32>) -> tensor<32x4x8xi1>
    %9 = stablehlo.negate %7 : tensor<32x4x8xf32>
    %10 = stablehlo.select %8, %7, %9 : tensor<32x4x8xi1>, tensor<32x4x8xf32>
    %11 = stablehlo.slice %10 [0:16, 0:4, 0:8] : (tensor<32x4x8xf32>) -> tensor<16x4x8xf32>
    %12 = stablehlo.slice %10 [16:32, 0:4, 0:8] : (tensor<32x4x8xf32>) -> tensor<16x4x8xf32>
    %13 = stablehlo.negate %3 : tensor<16x4x8xf32>
    %14 = stablehlo.multiply %12, %13 : tensor<16x4x8xf32>
    %15 = stablehlo.multiply %11, %4 : tensor<16x4x8xf32>
    %16 = stablehlo.add %14, %15 : tensor<16x4x8xf32>
    %17 = stablehlo.multiply %16, %1 : tensor<16x4x8xf32>
    %18 = stablehlo.reduce(%17 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<16x4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
    %19 = stablehlo.reshape %18 : (tensor<4x8xf32>) -> tensor<8x4xf32>
    return %19 : tensor<8x4xf32>
  }
}
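To see why the abs / compare / select cluster above is redundant, here is a quick numeric sketch (hypothetical values, not part of the PR): for any finite x, select(x >= 0, |x| + |x|, -(|x| + |x|)) is just x + x.

```python
import numpy as np

# Hypothetical numeric check: the abs / compare / select cluster
# (%6, %7, %8, %9, %10 in the module above) reduces to a plain doubling.
rng = np.random.default_rng(0)
x = rng.standard_normal((32, 4, 8)).astype(np.float32)

doubled_abs = np.abs(x) + np.abs(x)                   # %6 (abs), %7 (add)
folded = np.where(x >= 0, doubled_abs, -doubled_abs)  # %8 (compare), %9 (neg), %10 (select)

assert np.allclose(folded, x + x)
```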

Most of the internal operations are actually no-ops (this comes from a chunk inside a transformer, so it is a fairly common pattern). After these passes the module becomes:

module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<8x4xf32>, %arg1: tensor<16xf32>) -> tensor<8x4xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [2, 1] : (tensor<8x4xf32>) -> tensor<16x4x8xf32>
    %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<16xf32>) -> tensor<16x4x8xf32>
    %2 = stablehlo.multiply %0, %1 : tensor<16x4x8xf32>
    %3 = stablehlo.sine %2 : tensor<16x4x8xf32>
    %4 = stablehlo.cosine %2 : tensor<16x4x8xf32>
    %5 = stablehlo.concatenate %3, %4, dim = 0 : (tensor<16x4x8xf32>, tensor<16x4x8xf32>) -> tensor<32x4x8xf32>
    %6 = stablehlo.multiply %5, %5 : tensor<32x4x8xf32>
    %7 = stablehlo.negate %6 : tensor<32x4x8xf32>
    %8 = stablehlo.slice %7 [0:16, 0:4, 0:8] : (tensor<32x4x8xf32>) -> tensor<16x4x8xf32>
    %9 = stablehlo.slice %7 [16:32, 0:4, 0:8] : (tensor<32x4x8xf32>) -> tensor<16x4x8xf32>
    %10 = stablehlo.multiply %9, %3 : tensor<16x4x8xf32>
    %11 = stablehlo.negate %10 : tensor<16x4x8xf32>
    %12 = stablehlo.multiply %8, %4 : tensor<16x4x8xf32>
    %13 = stablehlo.add %11, %12 : tensor<16x4x8xf32>
    %14 = stablehlo.multiply %13, %1 : tensor<16x4x8xf32>
    %15 = stablehlo.reduce(%14 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<16x4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
    %16 = stablehlo.reshape %15 : (tensor<4x8xf32>) -> tensor<8x4xf32>
    return %16 : tensor<8x4xf32>
  }
}

@avik-pal (Collaborator, Author)

Some of these need to be qualified with no_nan:

// (select (x > 0) z (neg z)) -> (mul (sign x) z)
// (select (x >= 0) z (neg z)) -> (mul (sign x) z)
// (select (x > 0) (neg z) z) -> (mul (sign x) (neg z))
// (select (x >= 0) (neg z) z) -> (mul (sign x) (neg z))
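A hedged numeric sketch (hypothetical, not from the PR) of why these rewrites need qualification: on nonzero finite inputs the select and sign forms agree, but at x == 0 they diverge, since select(x >= 0, z, -z) yields z while sign(0) * z yields 0 (and NaN inputs diverge similarly).

```python
import numpy as np

# Check the select -> (mul (sign x) ...) rewrites on nonzero finite inputs.
# The x == 0 and NaN cases are excluded: there the two forms disagree.
rng = np.random.default_rng(0)
x = rng.standard_normal(1000).astype(np.float32)
z = rng.standard_normal(1000).astype(np.float32)
x = np.where(x == 0, np.float32(1.0), x)  # sidestep the x == 0 corner case

lhs = np.where(x >= 0, z, -z)       # (select (x >= 0) z (neg z))
rhs = np.sign(x) * z                # (mul (sign x) z)
assert np.allclose(lhs, rhs)

lhs_neg = np.where(x >= 0, -z, z)   # (select (x >= 0) (neg z) z)
rhs_neg = np.sign(x) * -z           # (mul (sign x) (neg z))
assert np.allclose(lhs_neg, rhs_neg)
```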
Member

Is this actually simpler/faster?

Collaborator (Author)

This one mostly enables the other optimizations.

Member

A select is usually faster than a mul, so this rewrite in isolation makes things locally worse. Is it feasible for the downstream operations to work on the select-style forms instead?


// (mul (sign x) (abs x)) -> x
// (mul (abs x) (sign x)) -> x
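Unlike the select rewrites, this fold needs no qualifier for zero: sign(x) * |x| == x holds for every finite x, including x == 0. A hypothetical numpy sketch (not part of the PR):

```python
import numpy as np

# sign(x) * |x| == x (and the commuted form) for all finite x,
# including both signed zeros; only NaN inputs would need a no_nan guard.
x = np.array([-2.5, -0.0, 0.0, 3.0, 1e-30], dtype=np.float32)

assert np.array_equal(np.sign(x) * np.abs(x), x)  # (mul (sign x) (abs x)) -> x
assert np.array_equal(np.abs(x) * np.sign(x), x)  # (mul (abs x) (sign x)) -> x
```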
Member

This seems reasonable; can you split this out so it can land individually?


// (mul (neg x) (neg y)) -> (mul x y)
// (mul (neg x) y) -> (neg (mul x y))
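Both negation identities are exact under IEEE-754 float multiplication (negation only flips the sign bit, and the product's sign is the XOR of the operand signs), so a bitwise-equality check passes. A hypothetical sketch, not from the PR:

```python
import numpy as np

# Exact (not just approximate) equality of the negation-propagation rewrites:
# (-x) * (-y) == x * y  and  (-x) * y == -(x * y), element-wise.
rng = np.random.default_rng(0)
x = rng.standard_normal(100).astype(np.float32)
y = rng.standard_normal(100).astype(np.float32)

assert np.array_equal((-x) * (-y), x * y)   # (mul (neg x) (neg y)) -> (mul x y)
assert np.array_equal((-x) * y, -(x * y))   # (mul (neg x) y) -> (neg (mul x y))
```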
Member

The top one is always good. The single-negation variants we should separate out, since there's a separate question of whether we want to propagate negations up or down (e.g. given (mul (neg x) constant) we'd want to produce (mul x (neg constant))).


// This pattern only partially implements the following. We rely on transforming
// the op into a form that the patterns above can simplify further.
// (mul (sign x) (add (abs x) (abs x))) -> (mul x x)
Member

Longer term, I feel like this merits a broader sign analysis (alongside, perhaps, a transpose analysis).
