Skip to content

Commit a48de7e

Browse files
committed
Add einops-style rearrange to keras.ops.einops
1 parent ab3c8f5 commit a48de7e

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from keras.src.ops.core import unstack
3131
from keras.src.ops.core import vectorized_map
3232
from keras.src.ops.core import while_loop
33+
from keras.src.ops.einops import rearrange
3334
from keras.src.ops.linalg import cholesky
3435
from keras.src.ops.linalg import det
3536
from keras.src.ops.linalg import eig

keras/api/ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from keras.src.ops.core import unstack
3131
from keras.src.ops.core import vectorized_map
3232
from keras.src.ops.core import while_loop
33+
from keras.src.ops.einops import rearrange
3334
from keras.src.ops.linalg import cholesky
3435
from keras.src.ops.linalg import det
3536
from keras.src.ops.linalg import eig

keras/src/ops/einops.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import re
2+
3+
from keras.src.api_export import keras_export
4+
from keras.src.ops.core import shape
5+
from keras.src.ops.numpy import prod
6+
from keras.src.ops.numpy import reshape
7+
from keras.src.ops.numpy import transpose
8+
9+
10+
def __create_axes_map(axes, input_shape, axes_lengths):
11+
axes_map = {}
12+
13+
for axis, dim in zip(axes, input_shape):
14+
# Check for grouped axes pattern, e.g., "(h1 h)"
15+
grouped_axes = re.match(r"\(([\w\s]+)\)", axis)
16+
17+
if grouped_axes:
18+
inner_axes = grouped_axes.group(1).split()
19+
known_axes = [a for a in inner_axes if a in axes_lengths]
20+
inferred_axes = [a for a in inner_axes if a not in axes_lengths]
21+
22+
if inferred_axes:
23+
inferred_axis = inferred_axes[0]
24+
known_product = prod([axes_lengths[a] for a in known_axes])
25+
axes_lengths[inferred_axis] = dim // known_product
26+
27+
axes_map.update({a: axes_lengths[a] for a in inner_axes})
28+
else:
29+
axes_map[axis] = dim
30+
31+
return axes_map
32+
33+
34+
def __create_grouped_axes(axes):
35+
grouped_output_axes = []
36+
for axis in axes:
37+
grouped_axes = re.match(r"\(([\w\s]+)\)", axis)
38+
39+
if grouped_axes:
40+
inner_axes = grouped_axes.group(1).split()
41+
grouped_output_axes.append(inner_axes)
42+
else:
43+
grouped_output_axes.append([axis])
44+
45+
return grouped_output_axes
46+
47+
48+
def __flatten_group(axes):
49+
return [x for xs in axes for x in xs]
50+
51+
52+
def __get_transpose_order(from_shape, to_shape):
53+
flattened_from_shape = __flatten_group(__create_grouped_axes(from_shape))
54+
55+
return [flattened_from_shape.index(dim) for dim in to_shape]
56+
57+
58+
def __compute_output_shape(axes_map, grouped_axes):
59+
output_shape = []
60+
for group in grouped_axes:
61+
size = 1
62+
for axis in group:
63+
size *= axes_map[axis]
64+
output_shape.append(size)
65+
66+
return tuple(output_shape)
67+
68+
69+
def __compute_decomposed_shape(input_axes, axes_lengths, axes_map):
70+
reshaped_input_axes = []
71+
reshaped_sizes = []
72+
73+
for axis in input_axes:
74+
if "(" in axis: # Decomposed axis
75+
inner_axes = re.findall(r"\w+", axis)
76+
sizes = [axes_lengths[a] for a in inner_axes]
77+
reshaped_input_axes.extend(inner_axes)
78+
reshaped_sizes.extend(sizes)
79+
else:
80+
reshaped_input_axes.append(axis)
81+
reshaped_sizes.append(axes_map[axis])
82+
83+
return reshaped_sizes
84+
85+
86+
@keras_export("keras.ops.rearrange")
87+
def rearrange(tensor, pattern, **axes_lengths):
88+
"""
89+
Rearranges the axes of a Keras tensor according to a specified pattern.
90+
91+
Args:
92+
tensor (Tensor): Input Keras tensor.
93+
pattern (str): String describing the rearrangement in einops notation.
94+
**axes_lengths: Keyword arguments specifying lengths of axes
95+
when axes decomposition is used.
96+
97+
Returns:
98+
Tensor: A Keras tensor with rearranged axes.
99+
100+
Follows the logic:
101+
1. If decomposition is needed:
102+
- Reshape to match dimension decomposition.
103+
2. Permute axes to match the form of the output.
104+
3. Reshape to match the desired output shape.
105+
"""
106+
107+
# Split the input and output patterns
108+
input_pattern, output_pattern = re.split(r"\s*->\s*", pattern)
109+
input_axes = re.findall(r"\w+|\(.*?\)", input_pattern)
110+
output_axes = re.findall(r"\w+|\(.*?\)", output_pattern)
111+
input_shape = shape(tensor)
112+
113+
# Create axes map, and flattened output group
114+
axes_map = __create_axes_map(input_axes, input_shape, axes_lengths)
115+
grouped_output_axes = __create_grouped_axes(output_axes)
116+
flattened_output_axes = __flatten_group(grouped_output_axes)
117+
118+
# 1. Axes decomposition
119+
decomposed_shapes = __compute_decomposed_shape(
120+
input_axes, axes_lengths, axes_map
121+
)
122+
if decomposed_shapes != tensor.shape:
123+
tensor = reshape(tensor, decomposed_shapes)
124+
125+
# 2. Transpose to match target shape
126+
permute_order = __get_transpose_order(input_axes, flattened_output_axes)
127+
tensor = transpose(tensor, permute_order)
128+
129+
# 3. Reshape to final target shape
130+
output_shape = __compute_output_shape(axes_map, grouped_output_axes)
131+
tensor = reshape(tensor, output_shape)
132+
133+
return tensor

0 commit comments

Comments
 (0)