Skip to content

Commit

Permalink
Remove _dot_general_transpose_lhs / _dot_general_transpose_rhs
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Aug 21, 2024
1 parent a25f7fc commit 6e4fc6d
Showing 1 changed file with 57 additions and 14 deletions.
71 changes: 57 additions & 14 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import dataclasses
import itertools
import numpy as np
import warnings
from functools import partial
Expand All @@ -25,7 +26,7 @@
from jax import numpy as jnp
from jax._src import core
from jax._src import dtypes
from jax._src.lax.lax import _dot_general_transpose_lhs, _dot_general_transpose_rhs
from jax._src.typing import DTypeLike

try:
from jax._src import earray
Expand Down Expand Up @@ -167,9 +168,9 @@ def compute_amax_history(x, amax_history):
def quantize_and_update(
x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
):
is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32
is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32)
# convert fm32->f32 so we can do math
if is_fm32:
if is_fmax32:
amax_history = lax.convert_element_type(amax_history, jnp.float32)
scale = lax.convert_element_type(scale, jnp.float32)

Expand Down Expand Up @@ -403,29 +404,71 @@ def q_dot_dq_bwd(

q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)

grad_lhs = _dot_general_transpose_lhs(
def _dot_general_transpose_and_dequantize(
    g, x, y, scale, *, dimension_numbers, precision,
    preferred_element_type: DTypeLike | None = None, swap_ans=False
):
  """Transpose a `dot_general` cotangent and dequantize the result.

  Inlined replacement for jax's private ``_dot_general_transpose_lhs`` /
  ``_dot_general_transpose_rhs``: contracts the output cotangent ``g``
  against the other operand ``y``, permutes axes back into ``x``'s layout,
  and finally dequantizes with ``scale``.

  Args:
    g: Cotangent of the dot_general output (here: quantized grad).
    x: The operand being transposed against; only ``x.aval`` is read, so an
      ``UndefinedPrimal``-style object (with ``.aval``) is expected here.
    y: The other (concrete) operand of the original dot_general.
    scale: Dequantization scale applied to the final result.
    dimension_numbers: ``((x_contract, y_contract), (x_batch, y_batch))``
      of the original dot_general (swap x/y roles for the rhs transpose).
    precision: Precision forwarded to ``lax.dot_general``.
    preferred_element_type: Accumulation dtype for ``lax.dot_general``; also
      used as the dequantization target dtype. Defaults to None.
    swap_ans: True when transposing w.r.t. the rhs operand (the cotangent's
      kept dims then come from y's side first).

  Returns:
    The dequantized gradient with the shape and dtype of ``x.aval``.
  """
  def _remaining(original, *removed_lists):
    # Dims of `original` not present in any of `removed_lists`.
    removed = set(itertools.chain(*removed_lists))
    return [i for i in original if i not in removed]

  def _ranges_like(*xs):
    # Yield consecutive index ranges, one per input sequence, laid end to end.
    start = 0
    for seq in xs:
      seq_len = len(seq)
      yield range(start, start + seq_len)
      start += seq_len

  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  x_ndim = x.aval.ndim

  x_kept = _remaining(range(x_ndim), x_contract, x_batch)
  y_kept = _remaining(range(np.ndim(y)), y_contract, y_batch)

  # Locate, inside g's axis order, the batch dims and the dims that came
  # from y; those y-dims are what we contract against y's kept dims.
  if swap_ans:
    ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept)
  else:
    ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept)

  dims = ((ans_y, y_kept), (ans_batch, y_batch))
  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
  out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)

  # Contract the cotangent against y, then permute axes back to x's layout.
  # Public lax API instead of jax._src.lax.lax internals.
  result = lax.transpose(
      lax.dot_general(g, y, dims, precision=precision,
                      preferred_element_type=preferred_element_type),
      tuple(out_axes),
  )

  # Match the primal's dtype/weak_type if the dot produced something else.
  if result.dtype != x.aval.dtype:
    result = _convert_element_type(result, x.aval.dtype, x.aval.weak_type)

  return dequantize(result, preferred_element_type, scale)

grad_lhs = _dot_general_transpose_and_dequantize(
q_g,
lhs,
q_rhs,
new_lhs_scale * new_out_grad_scale,
dimension_numbers=dimension_numbers,
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)
grad_lhs = dequantize(
grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale
)

grad_rhs = _dot_general_transpose_rhs(
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
grad_rhs = _dot_general_transpose_and_dequantize(
q_g,
q_lhs,
rhs,
dimension_numbers=dimension_numbers,
q_lhs,
new_rhs_scale * new_out_grad_scale,
dimension_numbers=swapped_dimension_numbers,
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)

grad_rhs = dequantize(
grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale
swap_ans=True,
)

return (
Expand Down

0 comments on commit 6e4fc6d

Please sign in to comment.