diff --git a/src/basis_structure.jl b/src/basis_structure.jl index 70120bc..961567b 100644 --- a/src/basis_structure.jl +++ b/src/basis_structure.jl @@ -273,19 +273,34 @@ function BasisMatrix{N,BT,T2}(::Type{T2}, basis::Basis{N,BT}, ::Tensor, BasisMatrix{Tensor,val_type}(out_order, vals) end +# When the user doesn't supply a ABSR, we pick one for them. +# for x::AbstractMatrix we pick direct +# for x::TensorX we pick Tensor +function BasisMatrix{T2}(::Type{T2}, basis::Basis, x::AbstractArray, order=0) + BasisMatrix(T2, basis, Direct(), x, order) +end + +function BasisMatrix{T2}(::Type{T2}, basis::Basis, x::TensorX, order=0) + BasisMatrix(T2, basis, Tensor(), x, order) +end + # method to allow passing types instead of instances of ABSR function BasisMatrix{BST<:ABSR,T2}(::Type{T2}, basis, ::Type{BST}, x, order=0) BasisMatrix(T2, basis, BST(), x, order) end -# default method without intermediate types +function BasisMatrix{BST<:ABSR}(basis, ::Type{BST}, x, order=0) + BasisMatrix(basis, BST(), x, order) +end + +# method without vals eltypes function BasisMatrix{TBM<:ABSR}(basis::Basis, tbm::TBM, x, order=0) BasisMatrix(Void, basis, tbm, x, order) end -function BasisMatrix{BST<:ABSR}(basis, ::Type{BST}, x, order=0) - BasisMatrix(basis, BST(), x, order) +function BasisMatrix(basis::Basis, x, order=0) + BasisMatrix(Void, basis, x, order) end # method without x diff --git a/test/basis_structure.jl b/test/basis_structure.jl index 3d8a461..090d2c3 100644 --- a/test/basis_structure.jl +++ b/test/basis_structure.jl @@ -206,4 +206,20 @@ @test eltype(bs3.vals) == SplineSparse{Float32,Int} end + @testset "no ABSR" begin + bmd = BasisMatrix(mb, X) + bmt = BasisMatrix(mb, x123) + @test isa(bmd, BasisMatrix{Direct}) + @test isa(bmt, BasisMatrix{Tensor}) + + bmd2 = BasisMatrix(SplineSparse, mb, X) + bmt2 = BasisMatrix(SplineSparse, mb, x123) + @test isa(bmd2.vals[1], SplineSparse) + @test isa(bmd2.vals[2], Matrix) + @test isa(bmd2.vals[3], SplineSparse) + @test isa(bmt2.vals[1], SplineSparse) + @test isa(bmt2.vals[2], Matrix) + @test isa(bmt2.vals[3], SplineSparse) + end + end # testset