Skip to content

Commit

Permalink
Merge pull request #36 from Bihaqo/develop
Browse files Browse the repository at this point in the history
0.2.0
  • Loading branch information
Alexander Novikov authored Mar 23, 2017
2 parents 53e6345 + de08139 commit 52f04b6
Show file tree
Hide file tree
Showing 20 changed files with 2,379 additions and 423 deletions.
36 changes: 36 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Change Log
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).

## [Unreleased]

## [0.2.0] - 2017-03-23
### Added
- (Partial) support for batches of TT-tensors.
- Riemannian module (projection on the tangent space).
- op property and str method for TensorTrain
- concat_along_batch_dim
- expand_batch_dim
- gram_matrix
- Multiplication by a number

### Changed
- Fix add function for dtypes not equal tf.float32
- flat_inner and quadratic_form now return numbers (instead of 1 x 1 tensors)

## [0.1.0] - 2017-03-12
### Added
- Indexing (e.g. TensorTrain[:, 3, 2:4])
- Full (converting TT to dense)
- TT-SVD and rounding
- Basic arithmetic (add, multiply, matmul, flat_inner)
- Variables support
- Kronecker module (functions for TT-rank 1 TT-matrices)
- quadratic_form
- frobenius_norm

[Unreleased]: https://github.com/Bihaqo/t3f/compare/master...develop
[0.2.0]: https://github.com/Bihaqo/t3f/compare/0.1.0...0.2.0
[0.1.0]: https://github.com/Bihaqo/t3f/compare/f24409508...0.1.0
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TensorFlow implementation of the Tensor Train (TT) -Toolbox.

# Installation
First, [install TensorFlow](https://www.tensorflow.org/install/). Then simply run
First, [install TensorFlow](https://www.tensorflow.org/install/) v1 or higher. Then simply run
```bash
pip install t3f
```
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup

setup(name='t3f',
version='0.1.0',
version='0.2.0',
description='Tensor Train decomposition on TensorFlow',
url='https://github.com/Bihaqo/t3f',
author='Alexander Novikov',
Expand Down
6 changes: 5 additions & 1 deletion t3f/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from tensor_train import *
from tensor_train_base import TensorTrainBase
from tensor_train import TensorTrain
from tensor_train_batch import TensorTrainBatch
from variables import *
from ops import *
from batch_ops import *
from initializers import *
from regularizers import *
from riemannian import *
from shapes import *
from decompositions import *
97 changes: 97 additions & 0 deletions t3f/batch_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import tensorflow as tf

from tensor_train_base import TensorTrainBase
from tensor_train_batch import TensorTrainBatch
import ops


def concat_along_batch_dim(tt_list):
"""Concat all TensorTrainBatch objects along batch dimension.
Args:
tt_list: a list of TensorTrainBatch objects.
Returns:
TensorTrainBatch
"""
ndims = tt_list[0].ndims()

if isinstance(tt_list, TensorTrainBase):
# Not a list but just one element, nothing to concat.
return tt_list

for batch_idx in range(len(tt_list)):
if not isinstance(tt_list[batch_idx], TensorTrainBatch):
raise ValueError('All objects in the list should be TTBatch objects, got '
'%s' % tt_list[batch_idx])
for batch_idx in range(1, len(tt_list)):
if tt_list[batch_idx].get_raw_shape() != tt_list[0].get_raw_shape():
raise ValueError('Shapes of all TT-batch objects should coincide, got %s '
'and %s' % (tt_list[0].get_raw_shape(),
tt_list[batch_idx].get_raw_shape()))
if tt_list[batch_idx].get_tt_ranks() != tt_list[0].get_tt_ranks():
raise ValueError('TT-ranks of all TT-batch objects should coincide, got '
'%s and %s' % (tt_list[0].get_tt_ranks(),
tt_list[batch_idx].get_tt_ranks()))

res_cores = []
for core_idx in range(ndims):
curr_core = tf.concat([tt.tt_cores[core_idx] for tt in tt_list], axis=0)
res_cores.append(curr_core)

batch_size = sum([tt.batch_size for tt in tt_list])

return TensorTrainBatch(res_cores, tt_list[0].get_raw_shape(),
tt_list[0].get_tt_ranks(), batch_size)


def gram_matrix(tt_vectors, matrix=None):
"""Computes Gramian matrix of a batch of TT-vecors.
If matrix is None, computes
res[i, j] = t3f.flat_inner(tt_vectors[i], tt_vectors[j]).
If matrix is present, computes
res[i, j] = t3f.flat_inner(tt_vectors[i], t3f.matmul(matrix, tt_vectors[j]))
or more shorly
res[i, j] = tt_vectors[i]^T * matrix * tt_vectors[j]
Args:
tt_vectors: TensorTrainBatch.
matrix: None, or TensorTrain matrix.
Returns:
tf.tensor with the Gram matrix.
"""
ndims = tt_vectors.ndims()
if matrix is None:
curr_core = tt_vectors.tt_cores[0]
res = tf.einsum('paijb,qcijd->pqbd', curr_core, curr_core)
for core_idx in range(1, ndims):
curr_core = tt_vectors.tt_cores[core_idx]
res = tf.einsum('pqac,paijb,qcijd->pqbd', res, curr_core, curr_core)
else:
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
vectors_shape = tt_vectors.get_shape()
if vectors_shape[2] == 1 and vectors_shape[1] != 1:
# TODO: not very efficient, better to use different order in einsum.
tt_vectors = ops.transpose(tt_vectors)
vectors_shape = tt_vectors.get_shape()
if vectors_shape[1] != 1:
# TODO: do something so that in case the shape is undefined on compilation
# it still works.
raise ValueError('The tt_vectors argument should be vectors (not '
'matrices) with shape defined on compilation.')
curr_core = tt_vectors.tt_cores[0]
curr_matrix_core = matrix.tt_cores[0]
# We enumerate the dummy dimension (that takes 1 value) with `k`.
res = tf.einsum('pakib,cijd,qekjf->pqbdf', curr_core, curr_matrix_core,
curr_core)
for core_idx in range(1, ndims):
curr_core = tt_vectors.tt_cores[core_idx]
curr_matrix_core = matrix.tt_cores[core_idx]
res = tf.einsum('pqace,pakib,cijd,qekjf->pqbdf', res, curr_core,
curr_matrix_core, curr_core)

# Squeeze to make the result of size batch_size x batch_size instead of
# batch_size x batch_size x 1 x 1.
return tf.squeeze(res)
79 changes: 79 additions & 0 deletions t3f/batch_ops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import tensorflow as tf

from tensor_train import TensorTrain
from tensor_train_batch import TensorTrainBatch
import ops
import batch_ops
import initializers


class BatchOpsTest(tf.test.TestCase):

def testConcatMatrix(self):
# Test concating TTMatrix batches along batch dimension.
first = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1)
second = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4)
third = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3)
first_res = batch_ops.concat_along_batch_dim((first))
first_res = ops.full(first_res)
first_second_res = batch_ops.concat_along_batch_dim((first, second))
first_second_res = ops.full(first_second_res)
first_second_third_res = batch_ops.concat_along_batch_dim((first, second,
third))
first_second_third_res = ops.full(first_second_third_res)

first_full = ops.full(first)
second_full = ops.full(second)
third_full = ops.full(third)
first_desired = first_full
first_second_desired = tf.concat((first_full, second_full), axis=0)
first_second_third_desired = tf.concat((first_full, second_full, third_full),
axis=0)
with self.test_session() as sess:
res = sess.run((first_res, first_second_res, first_second_third_res,
first_desired, first_second_desired,
first_second_third_desired))
first_res_val = res[0]
first_second_res_val = res[1]
first_second_third_res_val = res[2]
first_desired_val = res[3]
first_second_desired_val = res[4]
first_second_third_desired_val = res[5]
self.assertAllClose(first_res_val, first_desired_val)
self.assertAllClose(first_second_res_val, first_second_desired_val)
self.assertAllClose(first_second_third_res_val, first_second_third_desired_val)

def testGramMatrix(self):
# Test Gram Matrix of a batch of TT vectors.
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5)
res_actual = batch_ops.gram_matrix(tt_vectors)
full_vectors = tf.reshape(ops.full(tt_vectors), (5, 6))
res_desired = tf.matmul(full_vectors, tf.transpose(full_vectors))
res_desired = tf.squeeze(res_desired)
with self.test_session() as sess:
res_actual_val, res_desired_val = sess.run((res_actual, res_desired))
self.assertAllClose(res_desired_val, res_actual_val)

def testGramMatrixWithMatrix(self):
# Test Gram Matrix of a batch of TT vectors with providing a matrix, so we
# should compute
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
tt_vectors = initializers.random_matrix_batch((None, (2, 3)), batch_size=4)
matrix = initializers.random_matrix(((2, 3), (2, 3)))
res_actual = batch_ops.gram_matrix(tt_vectors, matrix)
full_vectors = tf.reshape(ops.full(tt_vectors), (4, 6))
with self.test_session() as sess:
res = sess.run((res_actual, full_vectors, ops.full(matrix)))
res_actual_val, vectors_val, matrix_val = res
res_desired_val = np.zeros((4, 4))
for i in range(4):
for j in range(4):
curr_val = np.dot(vectors_val[i], matrix_val)
curr_val = np.dot(curr_val, vectors_val[j])
res_desired_val[i, j] = curr_val
self.assertAllClose(res_desired_val, res_actual_val, atol=1e-5, rtol=1e-5)

if __name__ == "__main__":
tf.test.main()

Loading

0 comments on commit 52f04b6

Please sign in to comment.