Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use the tensor_product function to replace generated functions
Browse files Browse the repository at this point in the history
KnutAM committed Jan 1, 2024
1 parent e4cb2be commit b402509
Showing 4 changed files with 160 additions and 308 deletions.
381 changes: 94 additions & 287 deletions src/tensor_products.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,94 @@
# dcontract, dot, tdot, otimes, cross
# dcontract, dot, tdot, otimes, crossa:\:TT{

@foreach for dim in 1:3
@foreach for TA in (Tensor, SymmetricTensor)
@foreach for TB in (Tensor, SymmetricTensor)
# dcontract with both tensors of even order
@tensor_product(@inline @inbounds function dcontract(A::TA{2,dim}, B::TB{2,dim})
C = A[i,j]*B[i,j]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{4,dim}, B::TB{2,dim})
C[i,j] = A[i,j,k,l]*B[k,l]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{2,dim}, B::TB{4,dim})
C[k,l] = A[i,j]*B[i,j,k,l]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{4,dim}, B::TB{4,dim})
C[i,j,k,l] = A[i,j,m,n]*B[m,n,k,l]
end, muladd)

# otimes between 2nd order tensors
@tensor_product(@inline @inbounds function otimes(A::TA{2,dim}, B::TB{2,dim})
C[i,j,k,l] = A[i,j]*B[k,l]
end)

# dot between two tensors with even order
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::TB{2,dim})
C[i,j] = A[i,k]*B[k,j]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{4,dim}, B::TB{2,dim})
C[i,j,k,l] = A[i,j,k,m]*B[m,l]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::TB{4,dim})
C[i,j,k,l] = A[i,m]*B[m,j,k,l]
end)
end

# dcontract with 3rd order tensors
@tensor_product(@inline @inbounds function dcontract(A::TA{2,dim}, B::Tensor{3,dim})
C[i] = A[k,l]*B[k,l,i]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::Tensor{3,dim}, B::TA{2,dim})
C[i] = A[i,k,l]*B[k,l]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{4,dim}, B::Tensor{3,dim})
C[i,j,m] = A[i,j,k,l]*B[k,l,m]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::Tensor{3,dim}, B::TA{4,dim})
C[i,m,n] = A[i,k,l]*B[k,l,m,n]
end, muladd)

# otimes where one argument has an odd order, and one has even order
@tensor_product(@inline @inbounds function otimes(A::Tensor{1,dim}, B::TA{2,dim})
C[i,j,k] = A[i]*B[j,k]
end)
@tensor_product(@inline @inbounds function otimes(A::TA{2,dim}, B::Tensor{1,dim})
C[i,j,k] = A[i,j]*B[k]
end)

# dot where one argument has odd order, and one has even order
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::Tensor{1,dim})
C[i] = A[i,j]*B[j]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{1,dim}, B::TA{2,dim})
C[j] = A[i]*B[i,j]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{3,dim}, B::TA{2,dim})
C[i,j,k] = A[i,j,m]*B[m,k]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::Tensor{3,dim})
C[i,j,k] = A[i,m]*B[m,j,k]
end)

end
# otimes where both tensors have odd orders
@tensor_product(@inline @inbounds function otimes(A::Tensor{1,dim}, B::Tensor{1,dim})
C[i,j] = A[i]*B[j]
end)
# Defining {3}⊗{1} and {1}⊗{3} = {4} would also be valid...

# dot where both tensors have odd orders
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{1,dim}, B::Tensor{1,dim})
C = A[i]*B[i]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{3,dim}, B::Tensor{1,dim})
C[i,j] = A[i,j,k]*B[k]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{1,dim}, B::Tensor{3,dim})
C[i,j] = A[k]*B[k,i,j]
end)
end

"""
dcontract(::SecondOrderTensor, ::SecondOrderTensor)
dcontract(::SecondOrderTensor, ::FourthOrderTensor)
@@ -22,136 +112,7 @@ julia> A ⊡ B
1.9732018397544984
```
"""
@generated function dcontract(S1::SecondOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j) = compute_index(get_base(S2), i, j)
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for i in 1:dim, j in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j))]) for i in 1:dim, j in 1:dim][:]
exp = reducer(ex1, ex2)
return quote
$(Expr(:meta, :inline))
@inbounds return $exp
end
end

@generated function dcontract(S1::SecondOrderTensor{dim}, S2::FourthOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for i in 1:dim, j in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j, k, l))]) for i in 1:dim, j in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::FourthOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, k, l))]) for k in 1:dim, l in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(k, l))]) for k in 1:dim, l in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::SecondOrderTensor{dim}, S2::Tensor{3,dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for k in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for i in 1:dim, j in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j, k))]) for i in 1:dim, j in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps) # TODO: Required?
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::Tensor{3,dim}, S2::SecondOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, k, l))]) for k in 1:dim, l in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(k, l))]) for k in 1:dim, l in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::Tensor{3,dim}, S2::FourthOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, m, n))]) for m in 1:dim, n in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, n, j, k))]) for m in 1:dim, n in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::FourthOrderTensor{dim}, S2::Tensor{3,dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m, n))]) for m in 1:dim, n in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, n, k))]) for m in 1:dim, n in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::FourthOrderTensor{dim}, S2::FourthOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m, n))]) for m in 1:dim, n in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, n, k, l))]) for m in 1:dim, n in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end
function dcontract end

const = dcontract

@@ -187,23 +148,8 @@ julia> A ⊗ B
0.654957 0.48365
```
"""
@inline function otimes(S1::Vec{dim}, S2::Vec{dim}) where {dim}
return Tensor{2, dim}(@inline function(i,j) @inbounds S1[i] * S2[j]; end)
end

@inline function otimes(S1::Vec{dim}, S2::SecondOrderTensor{dim}) where {dim}
return Tensor{3, dim}(@inline function(i,j,k) @inbounds S1[i] * S2[j,k]; end)
end
@inline function otimes(S1::SecondOrderTensor{dim}, S2::Vec{dim}) where {dim}
return Tensor{3, dim}(@inline function(i,j,k) @inbounds S1[i,j] * S2[k]; end)
end

@inline function otimes(S1::SecondOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
TensorType = getreturntype(otimes, get_base(typeof(S1)), get_base(typeof(S2)))
TensorType(@inline function(i,j,k,l) @inbounds S1[i,j] * S2[k,l]; end)
end
function otimes end

# Defining {3}⊗{1} and {1}⊗{3} = {4} would also be valid...

@inline otimes(S1::Number, S2::Number) = S1*S2
@inline otimes(S1::AbstractTensor, S2::Number) = S1*S2
@@ -338,146 +284,7 @@ julia> A ⋅ B
1.0018368881367576
```
"""
@generated function LinearAlgebra.dot(S1::Vec{dim}, S2::Vec{dim}) where {dim}
ex1 = Expr[:(get_data(S1)[$i]) for i in 1:dim]
ex2 = Expr[:(get_data(S2)[$i]) for i in 1:dim]
exp = reducer(ex1, ex2)
quote
$(Expr(:meta, :inline))
@inbounds return $exp
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::Vec{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
exps = Expr(:tuple)
for i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for j in 1:dim]
ex2 = Expr[:(get_data(S2)[$j]) for j in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Vec{dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Vec{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for j in 1:dim
ex1 = Expr[:(get_data(S1)[$i]) for i in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j))]) for i in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Vec{dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, k))]) for k in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(k, j))]) for k in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{2, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::FourthOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, k, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, l))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{4, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::FourthOrderTensor{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, j, k, l))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{4, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Tensor{3,dim}, S2::Vec{dim}) where {dim}
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$m]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{2, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Vec{dim}, S2::Tensor{3,dim}) where {dim}
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$m]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, i, j))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{2, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::Tensor{3,dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, j, k))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{3, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Tensor{3,dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, k))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{3, dim}($exps)
end
end
LinearAlgebra.dot(::AbstractTensor, ::AbstractTensor)

"""
dot(::SymmetricTensor{2})
79 changes: 66 additions & 13 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -415,19 +415,72 @@ macro tensor_product(expr, args...)
tensor_product!(expr, args...)
end

# Generate a few test cases, just to check that it works.
function m_dcontract end
function m_otimes end
function m_dot end
# Define the "foreach" macro, to allow generating functions in a loop
function getrange(expr)
@assert expr.head === :call
@assert expr.args[1] === :(:)
@assert all(x->isa(x,Number), expr.args[2:end])
@assert length(expr.args) (3,4)
if length(expr.args) == 3 # from:to range
return expr.args[2]:expr.args[3]
else #length(expr.args) == 4 # from:step:to range
return expr.args[2]:expr.args[3]:expr.args[4]
end
end

@tensor_product (function m_dcontract(A::Tensor{2,3}, B::Tensor{2,3})
C = A[i,j]*B[i,j]
end)
function getiterable(expr)
if expr.head === :call && expr.args[1] === :(:)
return getrange(expr)
elseif expr.head === :(tuple)
return (a for a in expr.args)
else
error("Don't know what to do with $(expr.head)")
end
end

@tensor_product (function m_otimes(A::Tensor{1,3}, B::Tensor{1,3})
C[i,j] = A[i]*B[j]
end)
function loop_over_cases(loopsym, cases, expr)
exprs = Expr(:tuple)
for loopvar in getiterable(cases)
tmpexpr = deepcopy(expr)
f(s::Symbol) = (s === loopsym ? loopvar : s)
Tensors.replace_args!(f, tmpexpr.args)
push!(exprs.args, esc(tmpexpr))
end
return exprs
end

@tensor_product (function m_dot(A::Tensor{2,3}, B::Tensor{2,3})
C[i,j] = A[i,k]*B[k,j]
end)
function foreach(expr)
@assert expr.head === :for
loopsym = expr.args[1].args[1]
isa(loopsym, Symbol) || error("Can only loop over one variable")
cases = expr.args[1].args[2]
codeblock = expr.args[2]::Expr
@assert codeblock.head === :block
return loop_over_cases(loopsym, cases, codeblock)
end

"""
@foreach expr
Given an expression of the form
```julia
for <val> in <range_or_tuple>
<any code>
end
```
Return one expression for each item in `<range_or_tuple>`, in which all instances of `<val>`
in `<any code>` is replaced by the value in `<range_or_tuple>`. `<range_or_tuple>` must be
hard-coded. Example
```julia
@foreach for dim in 1:3
@foreach for TT in (Tensor, SymmetricTensor)
Tensors.@tensor_product(@inline @inbounds function my_dot(A::TT{2,dim}, B::TT{2,dim})
C[i,j] = A[i,k]*B[k,j]
end)
end
end
```
"""
macro foreach(expr)
return foreach(expr)
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@ end

reset_timer!()

include("test_dotmacro.jl")
include("F64.jl")
include("test_misc.jl")
include("test_ops.jl")
7 changes: 0 additions & 7 deletions test/test_dotmacro.jl

This file was deleted.

0 comments on commit b402509

Please sign in to comment.