Merge pull request #20 from gabrieledangeli/master
Implementation of PatchAmplitudeEmbedding
emapuljak authored Jan 8, 2025
2 parents cf77367 + 6f60372 commit a81e7fb
Showing 1 changed file with 158 additions and 5 deletions.
163 changes: 158 additions & 5 deletions tn4ml/embeddings.py
@@ -17,7 +17,7 @@ class Embedding:
Attributes
----------
-    dype: :class:`numpy.dype`
+    dtype: :class:`numpy.dtype`
Data Type
"""
def __init__(self, dtype=onp.float32):
@@ -44,7 +44,7 @@ class ComplexEmbedding:
Attributes
----------
-    dype: :class:`numpy.dype`
+    dtype: :class:`numpy.dtype`
Data Type
"""
def __init__(self, dtype=onp.float32):
@@ -77,7 +77,7 @@ class StateVectorToMPSEmbedding:
A class to convert a statevector into a Matrix Product State (MPS).
"""
def __init__(self, dtype=onp.float32, max_bond=None):
-        self.dype = dtype
+        self.dtype = dtype
self.max_bond = max_bond

@property
@@ -550,7 +550,7 @@ def create_statevector(self, x: jnp.ndarray) -> jnp.ndarray:
# Number of pixels (N = 16 for a 4x4 image)
N = len(x)
# Number of address qubits is log2(N) = 4
-        n_address_qubits = int(np.log2(N))
+        n_address_qubits = int(np.ceil(np.log2(N)))
# One color qubit
n_color_qubit = 1
# Total number of qubits = address qubits + 1 color qubit
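        # Example: for N = 16 pixels log2(N) is exactly 4, but for N that is not
        # a power of two (say N = 5, log2(5) ≈ 2.32) np.ceil rounds up to 3
        # address qubits, so that 2**3 = 8 >= 5 amplitudes fit in the statevector.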
@@ -672,6 +672,159 @@ def __call__(self, x: jnp.ndarray) -> qtn.MatrixProductState:
self.mps = qtn.MatrixProductState(new_arrays, shape='lrp')
return self.mps

class PatchAmplitudeEmbedding(StateVectorToMPSEmbedding):
def __init__(self, k = 2, **kwargs):
"""
        Initialize the PatchAmplitudeEmbedding class.
Parameters
----------
k: int
            The kernel size k of the k x k patch window.
Returns
-------
None
"""
super().__init__(**kwargs)
self.k = k
self.mps = None

@property
def dims(self) -> list:
return list([tensor.shape for tensor in self.mps.tensors])

def create_statevector(self, x: jnp.ndarray) -> jnp.ndarray:
"""
        Create a statevector representation of an input array (vector-like).
        Parameters
        ----------
        x: :class:`jax.numpy.ndarray`
            An array of patch pixel intensities, flattened from the original image.
Returns
-------
:class:`jax.numpy.ndarray`
A statevector representation of the input array.
"""

        # Number of pixels in the flattened patch (N = k*k, e.g. N = 4 for a 2x2 patch)
        N = len(x)
        # Number of qubits is ceil(log2(N)), e.g. 2 qubits for a 2x2 patch
        n_qubits = int(np.ceil(np.log2(N)))
# Create the state vector and fill it with square roots of the pixel values
statevector = jnp.sqrt(x)

# Normalize the statevector
statevector /= jnp.linalg.norm(statevector)


# Pad or truncate to fixed size
fixed_size = 2**n_qubits
padded_statevector = self.pad_or_truncate_statevector(statevector.flatten(), fixed_size)
return padded_statevector, n_qubits
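        # Worked example (illustrative): a flattened 2x2 patch x = [1, 2, 3, 4]
        # gives sqrt(x) = [1.000, 1.414, 1.732, 2.000]; normalizing by
        # sqrt(1 + 2 + 3 + 4) = sqrt(10) yields [0.316, 0.447, 0.548, 0.632],
        # and with N = 4 no padding is needed since 2**2 = 4.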

def pad_or_truncate_statevector(self, statevector: jnp.ndarray, target_size: int) -> jnp.ndarray:
"""
Pad or truncate the statevector to a target size.
Parameters
----------
statevector: :class:`jax.numpy.ndarray`
The input statevector.
target_size: int
The desired size of the statevector.
Returns
-------
:class:`jax.numpy.ndarray`
A statevector of the target size.
"""
current_size = statevector.shape[0]

# Pad or truncate
if current_size < target_size:
# Pad with zeros if smaller than target size
padding = [(0, target_size - current_size)]
statevector = jnp.pad(statevector, padding, mode='constant')
else:
# Truncate if larger than target size
statevector = statevector[:target_size]

return statevector
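        # For example, a statevector of length 5 is zero-padded to length 8
        # (target 2**3), while one of length 10 would be truncated to 8.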

    def combine_mps_patches(self, mps_patches: list, n_qubits: int) -> list:
"""
Combine arrays of each MPS patch into a single MPS.
Parameters
----------
        mps_patches: list
            List of MPS patches (nested lists of tensor arrays).
n_qubits: int
Number of qubits.
Returns
-------
        list
            A list of tensor arrays for the combined MPS.
"""
new_arrays = []
number_interval = 0

for patch in mps_patches:
for i, arr in enumerate(patch):
# Check if current array index matches the start or end of an interval
if i == number_interval * n_qubits and len(arr.shape) == 2:
# Add a new axis at the beginning (dim=0)
new_arrays.append(jnp.expand_dims(arr, axis=0))
elif i == ((number_interval + 1) * n_qubits - 1) and len(arr.shape) == 2:
# Add a new axis at the end (dim=-1)
new_arrays.append(jnp.expand_dims(arr, axis=-1))
number_interval += 1
else:
# Add the array as is
new_arrays.append(arr)

return new_arrays

def __call__(self, x: jnp.ndarray) -> qtn.MatrixProductState:
"""
        Convert a 2D input image into a Matrix Product State (MPS) by amplitude-embedding its k x k patches.
        Parameters
        ----------
        x: :class:`jax.numpy.ndarray`
            A 2D input image (square array of pixel intensities).
        Returns
        -------
        :class:`quimb.tensor.MatrixProductState`
            A Matrix Product State representation of the input image.
"""

        H, W = x.shape  # H: height, W: width of the input image
        if H != W:  # TODO: discuss support for rectangular images (they could be supported, at least in principle)
raise ValueError("Only square matrix input is supported.")
if self.k > H:
raise ValueError(f"Patch dimension k = {self.k} is too large for {H}x{W} images.")

patches = u.divide_into_patches(x, self.k)
mps_patches = []

for patch in patches:
statevector, n_qubits = self.create_statevector(patch.ravel())
mps_arrays = u.from_dense_to_mps(statevector, n_qubits, self.max_bond)

mps_patches.append(mps_arrays)

new_arrays = self.combine_mps_patches(mps_patches, n_qubits)

# Recreate the MPS with the reshaped arrays
self.mps = qtn.MatrixProductState(new_arrays, shape='lrp')
return self.mps
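    # Illustrative sizing (assuming divide_into_patches tiles the image into
    # non-overlapping k x k blocks): a 4x4 image with k = 2 gives 4 patches,
    # each flattened to N = 4 pixels, hence n_qubits = ceil(log2(4)) = 2 sites
    # per patch and a combined MPS with 4 * 2 = 8 sites of physical dimension 2.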

def embed(x: onp.ndarray, phi: Union[Embedding, ComplexEmbedding, StateVectorToMPSEmbedding], **mps_opts):
"""Creates a product state from a vector of features `x`.
    Works only if the features are separable and uncorrelated (you need to check this yourself).
@@ -707,4 +707,4 @@ def embed(x: onp.ndarray, phi: Union[Embedding, ComplexEmbedding, StateVectorToM
norm = mps.norm()
for tensor in mps.tensors:
tensor.modify(data=tensor.data / a.do("power", norm, 1 / len(mps.tensors)))
-    return mps
+    return mps
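A minimal usage sketch of the new embedding (parameter values are illustrative, and it assumes the helpers u.divide_into_patches and u.from_dense_to_mps used above are available in the installed package):

import jax.numpy as jnp
from tn4ml.embeddings import PatchAmplitudeEmbedding

# 4x4 grayscale image with pixel intensities in [0, 1]
image = jnp.arange(16, dtype=jnp.float32).reshape(4, 4) / 15.0

# Amplitude-embed non-overlapping 2x2 patches, capping the bond dimension at 4
phi = PatchAmplitudeEmbedding(k=2, max_bond=4)
mps = phi(image)  # combined quimb MatrixProductState over all patches

print(mps.L, mps.norm())  # number of sites and the state norm

Since PatchAmplitudeEmbedding subclasses StateVectorToMPSEmbedding, it should also be accepted by embed(), which additionally normalizes the resulting MPS.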
