Skip to content

Commit

Permalink
Specialize triu/tril for StaticMatrix (#1241)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Oct 2, 2024
1 parent 431d57a commit 7708986
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ using LinearAlgebra
import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr,
kron, diag, norm, dot, diagm, lu, svd, svdvals, pinv,
factorize, ishermitian, issymmetric, isposdef, issuccess, normalize,
normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \
normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \,
triu, tril
using LinearAlgebra: checksquare

using PrecompileTools
Expand Down
34 changes: 34 additions & 0 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,37 @@ end
# Some shimming for special linear algebra matrix types
@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, uplo))
@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Hermitian{eltype(A),typeof(A)}(A, uplo))

# triu/tril
function triu(S::StaticMatrix, k::Int=0)
if length(S) <= 32
C = CartesianIndices(S)
t = Tuple(S)
for (linind, CI) in enumerate(C)
i, j = Tuple(CI)
if j-i < k
t = Base.setindex(t, zero(t[linind]), linind)
end
end
similar_type(S)(t)
else
M = triu!(copyto!(similar(S), S), k)
similar_type(S)(M)
end
end
function tril(S::StaticMatrix, k::Int=0)
if length(S) <= 32
C = CartesianIndices(S)
t = Tuple(S)
for (linind, CI) in enumerate(C)
i, j = Tuple(CI)
if j-i > k
t = Base.setindex(t, zero(t[linind]), linind)
end
end
similar_type(S)(t)
else
M = tril!(copyto!(similar(S), S), k)
similar_type(S)(M)
end
end
10 changes: 10 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,4 +471,14 @@ end
m23 = SA[1 2 3; 4 5 6]
@test_inlined checksquare(m23) false
end

@testset "triu/tril" begin
for S in (SMatrix{7,5}(1:35), MMatrix{4,6}(1:24), SizedArray{Tuple{2,2}}([1 2; 3 4]))
M = Matrix(S)
for k in -10:10
@test triu(S, k) == triu(M, k)
@test tril(S, k) == tril(M, k)
end
end
end
end

0 comments on commit 7708986

Please sign in to comment.