Skip to content

Commit

Permalink
Merge pull request #168 from ValeevGroup/kmp5/debug/fix_flatten
Browse files Browse the repository at this point in the history
Fix the issue found in Flatten
  • Loading branch information
evaleev authored Jan 24, 2024
2 parents 86e4a21 + bc22242 commit bef8f22
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 100 deletions.
10 changes: 10 additions & 0 deletions btas/generic/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ namespace btas {
epsilon = t.get_fit(max_iter);
return;
}

/// Generic fallback for set_norm: does nothing for any type that has no
/// dedicated overload. Both arguments are accepted and left untouched.
/// \param[in] t Object that would receive the norm; not modified.
/// \param[in] norm Norm value; ignored by this overload.
template <typename T>
void set_norm(T &t, double norm) {
  // Intentionally empty: silence unused-parameter diagnostics only.
  (void)t;
  (void)norm;
}

template <typename Tensor>
void set_norm(FitCheck<Tensor> &t, double norm){
t.set_norm(norm);
}
} // namespace detail

/** \brief Base class to compute the Canonical Product (CP) decomposition of an order-N
Expand Down
67 changes: 45 additions & 22 deletions btas/generic/cp_als.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ namespace btas {
if (tmp != i) {
A[i] = A[tmp];
} else if (dir) {
direct(i, rank, fast_pI, matlab, converge_test);
direct(i, rank, fast_pI, matlab, converge_test, tensor_ref);
} else {
update_w_KRP(i, rank, fast_pI, matlab, converge_test);
}
Expand Down Expand Up @@ -658,7 +658,7 @@ namespace btas {
double lambda = 0) {
Tensor temp(A[n].extent(0), rank);
Tensor an(A[n].range());

#ifdef BTAS_HAS_INTEL_MKL

// Computes the Khatri-Rao product intermediate
Expand Down Expand Up @@ -692,9 +692,32 @@ namespace btas {
swap_to_first(tensor_ref, n, true);

#else // BTAS_HAS_CBLAS
// without MKL program cannot perform the swapping algorithm, must compute
// flattened intermediate
gemm(blas::Op::NoTrans, blas::Op::NoTrans, this->one, flatten(tensor_ref, n), this->generate_KRP(n, rank, true), this->zero, temp);
// Computes the Khatri-Rao product intermediate
auto KhatriRao = this->generate_KRP(n, rank, true);

// moves mode n of the reference tensor to the front to simplify contraction
std::vector<ind_t> tref_indices, KRP_dims, An_indices;

// resize the Khatri-Rao product to the proper dimensions
for (size_t i = 0; i < ndim; i++) {
tref_indices.push_back(i);
if(i == n)
continue;
KRP_dims.push_back(tensor_ref.extent(i));
}
KRP_dims.push_back(rank);
KhatriRao.resize(KRP_dims);
KRP_dims.clear();

An_indices.push_back(n);
An_indices.push_back(ndim);
for (size_t i = 0; i < ndim; i++) {
if(i == n)
continue;
KRP_dims.push_back(i);
}
KRP_dims.push_back(ndim);
contract(this->one, tensor_ref, tref_indices, KhatriRao, KRP_dims, this->zero, temp, An_indices);
#endif

if(lambda != 0){
Expand Down Expand Up @@ -737,52 +760,52 @@ namespace btas {
/// return if \c matlab was successful
/// \param[in, out] converge_test Test to see if the ALS is converged; holds the value of fit.

void direct(size_t n, ind_t rank, bool &fast_pI, bool &matlab, ConvClass &converge_test, double lambda = 0.0) {
void direct(size_t n, ind_t rank, bool &fast_pI, bool &matlab, ConvClass &converge_test, Tensor& target, double lambda = 0.0) {
// Determine if n is the last mode, if it is first contract with first mode
// and transpose the product
bool last_dim = n == ndim - 1;
// product of all dimensions
ord_t LH_size = size;
size_t contract_dim = last_dim ? 0 : ndim - 1;
ind_t offset_dim = tensor_ref.extent(n);
ind_t offset_dim = target.extent(n);
ind_t pseudo_rank = rank;

// Store the dimensions which are available to hadamard contract
std::vector<ind_t> dimensions;
for (size_t i = last_dim ? 1 : 0; i < (last_dim ? ndim : ndim - 1); i++) {
dimensions.push_back(tensor_ref.extent(i));
dimensions.push_back(target.extent(i));
}

// Modifying the dimension of tensor_ref so store the range here to resize
Range R = tensor_ref.range();
// Modifying the dimension of target so store the range here to resize
Range R = target.range();
//Tensor an(A[n].range());

// Resize the tensor which will store the product of tensor_ref and the first factor matrix
Tensor temp = Tensor(size / tensor_ref.extent(contract_dim), rank);
tensor_ref.resize(
Range{Range1{last_dim ? tensor_ref.extent(contract_dim) : size / tensor_ref.extent(contract_dim)},
Range1{last_dim ? size / tensor_ref.extent(contract_dim) : tensor_ref.extent(contract_dim)}});
// Resize the tensor which will store the product of target and the first factor matrix
Tensor temp = Tensor(size / target.extent(contract_dim), rank);
target.resize(
Range{Range1{last_dim ? target.extent(contract_dim) : size / target.extent(contract_dim)},
Range1{last_dim ? size / target.extent(contract_dim) : target.extent(contract_dim)}});

// contract tensor ref and the first factor matrix
gemm((last_dim ? blas::Op::Trans : blas::Op::NoTrans), blas::Op::NoTrans, this->one , (last_dim? tensor_ref.conj():tensor_ref), A[contract_dim].conj(), this->zero,
gemm((last_dim ? blas::Op::Trans : blas::Op::NoTrans), blas::Op::NoTrans, this->one , (last_dim? target.conj():target), A[contract_dim].conj(), this->zero,
temp);

// Resize tensor_ref
tensor_ref.resize(R);
// Resize target
target.resize(R);
// Remove the dimension which was just contracted out
LH_size /= tensor_ref.extent(contract_dim);
LH_size /= target.extent(contract_dim);

// n tells which dimension not to contract, and contract_dim says which dimension I am trying to contract.
// If n == contract_dim then that mode is skipped.
// if n == ndim - 1, my contract_dim = 0. The gemm transposes to make rank = ndim - 1, so I
// move the pointer that preserves the last dimension to n = ndim -2.
// In all cases I want to walk through the orders in tensor_ref backward so contract_dim = ndim - 2
// In all cases I want to walk through the orders in target backward so contract_dim = ndim - 2
n = last_dim ? ndim - 2 : n;
contract_dim = ndim - 2;

while (contract_dim > 0) {
// Now temp is three index object where temp has size
// (size of tensor_ref/product of dimension contracted, dimension to be
// (size of target/product of dimension contracted, dimension to be
// contracted, rank)
ord_t idx2 = dimensions[contract_dim],
idx1 = LH_size / idx2;
Expand All @@ -793,7 +816,7 @@ namespace btas {
//contract_tensor.fill(0.0);
const auto &a = A[(last_dim ? contract_dim + 1 : contract_dim)];
// If the middle dimension is the mode not being contracted, I will move
// it to the right hand side temp((size of tensor_ref/product of
// it to the right hand side temp((size of target/product of
// dimension contracted, rank * mode n dimension)
if (n == contract_dim) {
pseudo_rank *= offset_dim;
Expand Down
2 changes: 1 addition & 1 deletion btas/generic/cp_rals.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ namespace btas {
A[i] = A[tmp];
lambda[i] = lambda[tmp];
} else if (dir) {
this->direct(i, rank, fast_pI, matlab, converge_test, lambda[i]);
this->direct(i, rank, fast_pI, matlab, converge_test, tensor_ref, lambda[i]);
} else {
update_w_KRP(i, rank, fast_pI, matlab, converge_test, lambda[i]);
}
Expand Down
96 changes: 22 additions & 74 deletions btas/generic/flatten.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,86 +9,34 @@ namespace btas {
/// \f[ A(I_1, I_2, I_3, ..., I_{mode}, ..., I_N) -> A(I_{mode}, J)\f]
/// where \f$J = I_1 * I_2 * ...I_{mode-1} * I_{mode+1} * ... * I_N.\f$
/// \return Matrix with dimension \f$(I_{mode}, J)\f$

/// Builds the matrix X of extents (A.extent(mode), A.range().area() / A.extent(mode))
/// and fills it through the recursive helper fill(), starting at depth 0.
/// Invokes BTAS_EXCEPTION when \c mode is not smaller than A.rank().
template<typename Tensor>
Tensor flatten(const Tensor &A, size_t mode) {
using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;

// Reject a mode index that A does not have.
if (mode >= A.rank()) BTAS_EXCEPTION("Cannot flatten along mode outside of A.rank()");

// make X the correct size
Tensor X(A.extent(mode), A.range().area() / A.extent(mode));

// Row (indexi) and column (indexj) positions in X; both start at zero and
// are passed by value into fill().
ord_t indexi = 0, indexj = 0;
size_t ndim = A.rank();
// J is the new step size found by removing the mode of interest
// For i != mode, J[i] is the product of A.extent(m) over all m < i with
// m != mode; J[mode] keeps its initial value of 1.
std::vector<ord_t> J(ndim, 1);
for (size_t i = 0; i < ndim; ++i)
if (i != mode)
for (size_t m = 0; m < i; ++m)
if (m != mode)
J[i] *= A.extent(m);

// Iterator over A, handed to fill() by reference.
auto tensor_itr = A.begin();

// Fill X with the correct values
fill(A, 0, X, mode, indexi, indexj, J, tensor_itr);

// return the flattened matrix
return X;
}

/// following the formula for flattening laid out by Kolda and Bader
/// <a href=http://epubs.siam.org/doi/pdf/10.1137/07070111X> See reference. </a>
/// Recursive method utilized by flatten.\n **Important** if you want to flatten a tensor
/// call flatten, not fill.

/// \param[in] A The reference tensor to be flattened
/// \param[in] depth The recursion depth. Should not exceed the A.rank()
/// \param[in, out] X In: An empty matrix to be filled with correct
/// elements of \c A flattened on the \c mode fiber. Should be size \f$ (I_{mode}, J)\f$
/// Out: The flattened A matrix along the \c mode fiber \param[in]
/// mode The mode which A is to be flattened. \param[in] indexi The row index of
/// matrix X \param[in] indexj The column index of matrix X \param[in] J The
/// step size for the row dimension of X \param[in] tensor_itr An iterator of \c A.
/// The value of the iterator is placed in the correct position of X using
/// recursive calls of fill().

template<typename Tensor, typename iterator, typename ord_t>
void fill(const Tensor &A, size_t depth, Tensor &X, size_t mode,
ord_t indexi, ord_t indexj, const std::vector<ord_t> &J, iterator &tensor_itr) {
template<typename Tensor>
Tensor flatten(Tensor A, size_t mode) {
using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
using ind_t = typename Tensor::range_type::index_type::value_type;
size_t ndim = A.rank();
if (depth < ndim) {
// We are going to first make the order N tensor into an order 3 tensor with
// (modes before `mode`, `mode`, modes after `mode`)

// Creates a for loop based on the number of modes A has
for (ind_t i = 0; i < A.extent(depth); ++i) {
auto dim_mode = A.extent(mode);
Tensor flat(dim_mode, A.range().area() / dim_mode);
size_t ndim = A.rank();
ord_t dim1 = 1, dim3 = 1;
for (ind_t i = 0; i < ndim; ++i) {
if (i < mode)
dim1 *= A.extent(i);
else if (i > mode)
dim3 *= A.extent(i);
}

A.resize(Range{Range1{dim1}, Range1{dim_mode}, Range1{dim3}});

// use the for loop to find the column dimension index
if (depth != mode) {
indexj += i * J[depth]; // column matrix index
for (ord_t i = 0; i < dim1; ++i) {
for (ind_t j = 0; j < dim_mode; ++j) {
for (ord_t k = 0; k < dim3; ++k) {
flat(j, i * dim3 + k) = A(i,j,k);
}

// if this depth is the mode being flattened use the for loop to find the
// row dimension
else {
indexi = i; // row matrix index
}

fill(A, depth + 1, X, mode, indexi, indexj, J, tensor_itr);

// remove the indexing from earlier in this loop.
if (depth != mode)
indexj -= i * J[depth];
}
}

// When depth steps out of the number of dimensions, set X to be the correct
// value from the iterator then increment the iterator.
else {
X(indexi, indexj) = *tensor_itr;
tensor_itr++;
}
return flat;
}

} // namespace btas
Expand Down
1 change: 1 addition & 0 deletions btas/generic/linear_algebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ Tensor pseudoInverse(Tensor & A, bool & fast_pI) {
// Compute the matrix A^-1 from the inverted singular values and the U and
// V^T provided by the SVD
gemm(blas::Op::NoTrans, blas::Op::NoTrans, 1.0, U, s_inv, 0.0, s_);
U = Tensor(Range{Range1{row}, Range1{col}});
gemm(blas::Op::NoTrans, blas::Op::NoTrans, 1.0, s_, Vt, 0.0, U);

return U;
Expand Down
2 changes: 1 addition & 1 deletion btas/generic/tuck_cp_als.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ namespace btas{
count++;
this->num_ALS++;
for (size_t i = 0; i < ndim; i++) {
this->direct(i, rank, fast_pI, matlab, converge_test, lambda[i]);
this->direct(i, rank, fast_pI, matlab, converge_test, tensor_ref, lambda[i]);
// Compute the value s after normalizing the columns
auto & ai = A[i];
this->s = helper(i, ai);
Expand Down
4 changes: 2 additions & 2 deletions unittest/ztensor_cp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ TEST_CASE("ZCP") {
SECTION("RALS MODE = 4, Finite rank") {
CP_RALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_rank(65, conv, 1, true, 65);
double diff = A1.compute_rank(67, conv, 1, true, 65);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("RALS MODE = 4, Finite error"){
CP_RALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_error(conv, 1e-2, 1, 67, true, 65);
double diff = A1.compute_error(conv, 1e-5, 1, 67, true, 65);
CHECK(std::abs(diff) <= epsilon);
}
#if BTAS_ENABLE_TUCKER_CP_UT
Expand Down

0 comments on commit bef8f22

Please sign in to comment.