Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use the tensor_product function to replace generated functions
Browse files Browse the repository at this point in the history
KnutAM committed Jan 1, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent e4cb2be commit b402509
Showing 4 changed files with 160 additions and 308 deletions.
381 changes: 94 additions & 287 deletions src/tensor_products.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,94 @@
# dcontract, dot, tdot, otimes, cross
# dcontract, dot, tdot, otimes, crossa:\:TT{

@foreach for dim in 1:3
@foreach for TA in (Tensor, SymmetricTensor)
@foreach for TB in (Tensor, SymmetricTensor)
# dcontract with both tensors of even order
@tensor_product(@inline @inbounds function dcontract(A::TA{2,dim}, B::TB{2,dim})
C = A[i,j]*B[i,j]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{4,dim}, B::TB{2,dim})
C[i,j] = A[i,j,k,l]*B[k,l]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{2,dim}, B::TB{4,dim})
C[k,l] = A[i,j]*B[i,j,k,l]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{4,dim}, B::TB{4,dim})
C[i,j,k,l] = A[i,j,m,n]*B[m,n,k,l]
end, muladd)

# otimes between 2nd order tensors
@tensor_product(@inline @inbounds function otimes(A::TA{2,dim}, B::TB{2,dim})
C[i,j,k,l] = A[i,j]*B[k,l]
end)

# dot between two tensors with even order
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::TB{2,dim})
C[i,j] = A[i,k]*B[k,j]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{4,dim}, B::TB{2,dim})
C[i,j,k,l] = A[i,j,k,m]*B[m,l]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::TB{4,dim})
C[i,j,k,l] = A[i,m]*B[m,j,k,l]
end)
end

# dcontract with 3rd order tensors
@tensor_product(@inline @inbounds function dcontract(A::TA{2,dim}, B::Tensor{3,dim})
C[i] = A[k,l]*B[k,l,i]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::Tensor{3,dim}, B::TA{2,dim})
C[i] = A[i,k,l]*B[k,l]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::TA{4,dim}, B::Tensor{3,dim})
C[i,j,m] = A[i,j,k,l]*B[k,l,m]
end, muladd)
@tensor_product(@inline @inbounds function dcontract(A::Tensor{3,dim}, B::TA{4,dim})
C[i,m,n] = A[i,k,l]*B[k,l,m,n]
end, muladd)

# otimes where one argument has an odd order, and one has even order
@tensor_product(@inline @inbounds function otimes(A::Tensor{1,dim}, B::TA{2,dim})
C[i,j,k] = A[i]*B[j,k]
end)
@tensor_product(@inline @inbounds function otimes(A::TA{2,dim}, B::Tensor{1,dim})
C[i,j,k] = A[i,j]*B[k]
end)

# dot where one argument has odd order, and one has even order
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::Tensor{1,dim})
C[i] = A[i,j]*B[j]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{1,dim}, B::TA{2,dim})
C[j] = A[i]*B[i,j]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{3,dim}, B::TA{2,dim})
C[i,j,k] = A[i,j,m]*B[m,k]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::TA{2,dim}, B::Tensor{3,dim})
C[i,j,k] = A[i,m]*B[m,j,k]
end)

end
# otimes where both tensors have odd orders
@tensor_product(@inline @inbounds function otimes(A::Tensor{1,dim}, B::Tensor{1,dim})
C[i,j] = A[i]*B[j]
end)
# Defining {3}⊗{1} and {1}⊗{3} = {4} would also be valid...

# dot where both tensors have odd orders
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{1,dim}, B::Tensor{1,dim})
C = A[i]*B[i]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{3,dim}, B::Tensor{1,dim})
C[i,j] = A[i,j,k]*B[k]
end)
@tensor_product(@inline @inbounds function LinearAlgebra.dot(A::Tensor{1,dim}, B::Tensor{3,dim})
C[i,j] = A[k]*B[k,i,j]
end)
end

"""
dcontract(::SecondOrderTensor, ::SecondOrderTensor)
dcontract(::SecondOrderTensor, ::FourthOrderTensor)
@@ -22,136 +112,7 @@ julia> A ⊡ B
1.9732018397544984
```
"""
@generated function dcontract(S1::SecondOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j) = compute_index(get_base(S2), i, j)
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for i in 1:dim, j in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j))]) for i in 1:dim, j in 1:dim][:]
exp = reducer(ex1, ex2)
return quote
$(Expr(:meta, :inline))
@inbounds return $exp
end
end

@generated function dcontract(S1::SecondOrderTensor{dim}, S2::FourthOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for i in 1:dim, j in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j, k, l))]) for i in 1:dim, j in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::FourthOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, k, l))]) for k in 1:dim, l in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(k, l))]) for k in 1:dim, l in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::SecondOrderTensor{dim}, S2::Tensor{3,dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for k in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for i in 1:dim, j in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j, k))]) for i in 1:dim, j in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps) # TODO: Required?
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::Tensor{3,dim}, S2::SecondOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, k, l))]) for k in 1:dim, l in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(k, l))]) for k in 1:dim, l in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::Tensor{3,dim}, S2::FourthOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, m, n))]) for m in 1:dim, n in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, n, j, k))]) for m in 1:dim, n in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::FourthOrderTensor{dim}, S2::Tensor{3,dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m, n))]) for m in 1:dim, n in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, n, k))]) for m in 1:dim, n in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end

@generated function dcontract(S1::FourthOrderTensor{dim}, S2::FourthOrderTensor{dim}) where {dim}
TensorType = getreturntype(dcontract, get_base(S1), get_base(S2))
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m, n))]) for m in 1:dim, n in 1:dim][:]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, n, k, l))]) for m in 1:dim, n in 1:dim][:]
push!(exps.args, reducer(ex1, ex2, true))
end
expr = remove_duplicates(TensorType, exps)
quote
$(Expr(:meta, :inline))
@inbounds return $TensorType($expr)
end
end
function dcontract end

const = dcontract

@@ -187,23 +148,8 @@ julia> A ⊗ B
0.654957 0.48365
```
"""
@inline function otimes(S1::Vec{dim}, S2::Vec{dim}) where {dim}
return Tensor{2, dim}(@inline function(i,j) @inbounds S1[i] * S2[j]; end)
end

@inline function otimes(S1::Vec{dim}, S2::SecondOrderTensor{dim}) where {dim}
return Tensor{3, dim}(@inline function(i,j,k) @inbounds S1[i] * S2[j,k]; end)
end
@inline function otimes(S1::SecondOrderTensor{dim}, S2::Vec{dim}) where {dim}
return Tensor{3, dim}(@inline function(i,j,k) @inbounds S1[i,j] * S2[k]; end)
end

@inline function otimes(S1::SecondOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
TensorType = getreturntype(otimes, get_base(typeof(S1)), get_base(typeof(S2)))
TensorType(@inline function(i,j,k,l) @inbounds S1[i,j] * S2[k,l]; end)
end
function otimes end

# Defining {3}⊗{1} and {1}⊗{3} = {4} would also be valid...

@inline otimes(S1::Number, S2::Number) = S1*S2
@inline otimes(S1::AbstractTensor, S2::Number) = S1*S2
@@ -338,146 +284,7 @@ julia> A ⋅ B
1.0018368881367576
```
"""
@generated function LinearAlgebra.dot(S1::Vec{dim}, S2::Vec{dim}) where {dim}
ex1 = Expr[:(get_data(S1)[$i]) for i in 1:dim]
ex2 = Expr[:(get_data(S2)[$i]) for i in 1:dim]
exp = reducer(ex1, ex2)
quote
$(Expr(:meta, :inline))
@inbounds return $exp
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::Vec{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
exps = Expr(:tuple)
for i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j))]) for j in 1:dim]
ex2 = Expr[:(get_data(S2)[$j]) for j in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Vec{dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Vec{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for j in 1:dim
ex1 = Expr[:(get_data(S1)[$i]) for i in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(i, j))]) for i in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Vec{dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, k))]) for k in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(k, j))]) for k in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{2, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::FourthOrderTensor{dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j, k, l) = compute_index(get_base(S1), i, j, k, l)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, k, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, l))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{4, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::FourthOrderTensor{dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k, l) = compute_index(get_base(S2), i, j, k, l)
exps = Expr(:tuple)
for l in 1:dim, k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, j, k, l))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{4, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Tensor{3,dim}, S2::Vec{dim}) where {dim}
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$m]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{2, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Vec{dim}, S2::Tensor{3,dim}) where {dim}
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$m]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, i, j))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{2, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::SecondOrderTensor{dim}, S2::Tensor{3,dim}) where {dim}
idxS1(i, j) = compute_index(get_base(S1), i, j)
idxS2(i, j, k) = compute_index(get_base(S2), i, j, k)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, j, k))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{3, dim}($exps)
end
end

@generated function LinearAlgebra.dot(S1::Tensor{3,dim}, S2::SecondOrderTensor{dim}) where {dim}
idxS1(i, j, k) = compute_index(get_base(S1), i, j, k)
idxS2(i, j) = compute_index(get_base(S2), i, j)
exps = Expr(:tuple)
for k in 1:dim, j in 1:dim, i in 1:dim
ex1 = Expr[:(get_data(S1)[$(idxS1(i, j, m))]) for m in 1:dim]
ex2 = Expr[:(get_data(S2)[$(idxS2(m, k))]) for m in 1:dim]
push!(exps.args, reducer(ex1, ex2))
end
quote
$(Expr(:meta, :inline))
@inbounds return Tensor{3, dim}($exps)
end
end
LinearAlgebra.dot(::AbstractTensor, ::AbstractTensor)

"""
dot(::SymmetricTensor{2})
79 changes: 66 additions & 13 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -415,19 +415,72 @@ macro tensor_product(expr, args...)
tensor_product!(expr, args...)
end

# Generate a few test cases, just to check that it works.
function m_dcontract end
function m_otimes end
function m_dot end
# Define the "foreach" macro, to allow generating functions in a loop
function getrange(expr)
@assert expr.head === :call
@assert expr.args[1] === :(:)
@assert all(x->isa(x,Number), expr.args[2:end])
@assert length(expr.args) (3,4)
if length(expr.args) == 3 # from:to range
return expr.args[2]:expr.args[3]
else #length(expr.args) == 4 # from:step:to range
return expr.args[2]:expr.args[3]:expr.args[4]
end
end

@tensor_product (function m_dcontract(A::Tensor{2,3}, B::Tensor{2,3})
C = A[i,j]*B[i,j]
end)
function getiterable(expr)
if expr.head === :call && expr.args[1] === :(:)
return getrange(expr)
elseif expr.head === :(tuple)
return (a for a in expr.args)
else
error("Don't know what to do with $(expr.head)")
end
end

@tensor_product (function m_otimes(A::Tensor{1,3}, B::Tensor{1,3})
C[i,j] = A[i]*B[j]
end)
function loop_over_cases(loopsym, cases, expr)
exprs = Expr(:tuple)
for loopvar in getiterable(cases)
tmpexpr = deepcopy(expr)
f(s::Symbol) = (s === loopsym ? loopvar : s)
Tensors.replace_args!(f, tmpexpr.args)
push!(exprs.args, esc(tmpexpr))
end
return exprs
end

@tensor_product (function m_dot(A::Tensor{2,3}, B::Tensor{2,3})
C[i,j] = A[i,k]*B[k,j]
end)
function foreach(expr)
@assert expr.head === :for
loopsym = expr.args[1].args[1]
isa(loopsym, Symbol) || error("Can only loop over one variable")
cases = expr.args[1].args[2]
codeblock = expr.args[2]::Expr
@assert codeblock.head === :block
return loop_over_cases(loopsym, cases, codeblock)
end

"""
@foreach expr
Given an expression of the form
```julia
for <val> in <range_or_tuple>
<any code>
end
```
Return one expression for each item in `<range_or_tuple>`, in which all instances of `<val>`
in `<any code>` is replaced by the value in `<range_or_tuple>`. `<range_or_tuple>` must be
hard-coded. Example
```julia
@foreach for dim in 1:3
@foreach for TT in (Tensor, SymmetricTensor)
Tensors.@tensor_product(@inline @inbounds function my_dot(A::TT{2,dim}, B::TT{2,dim})
C[i,j] = A[i,k]*B[k,j]
end)
end
end
```
"""
macro foreach(expr)
return foreach(expr)
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@ end

reset_timer!()

include("test_dotmacro.jl")
include("F64.jl")
include("test_misc.jl")
include("test_ops.jl")
7 changes: 0 additions & 7 deletions test/test_dotmacro.jl

This file was deleted.

0 comments on commit b402509

Please sign in to comment.