diff --git a/Project.toml b/Project.toml index 2f1be08..ccd5460 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GTPSA" uuid = "b27dd330-f138-47c5-815b-40db9dd9b6e8" authors = ["Matt Signorelli"] -version = "1.1.0" +version = "1.1.1" [deps] GTPSA_jll = "a4739e29-4b97-5c0b-bbcf-46f08034c990" diff --git a/src/getset.jl b/src/getset.jl index 3f75c14..b259892 100644 --- a/src/getset.jl +++ b/src/getset.jl @@ -492,7 +492,7 @@ end """ - GTPSA.hessian!(result, t::TPS; include_params=false) + GTPSA.hessian!(result, t::TPS; include_params=false, tmp_mono::Union{Nothing,Vector{UInt8}}=nothing, unsafe_fast::Bool=false) Extracts the second-order partial derivatives (evaluated at 0) from the TPS and fills the `result` matrix in-place. The partial derivatives wrt the parameters will @@ -503,11 +503,13 @@ in the TPS. ### Input - `t` -- `TPS`/`ComplexTPS64` to extract the Hessian from - `include_params` -- (Optional) Extract partial derivatives wrt parameters. Default is false +- `tmp_mono` -- (Optional) `Vector{UInt8}` to store the monomial, when different orders of truncation are used +- `unsafe_fast` -- (Optional) Flag to specify that "fast" indexing should be used without checking. This will give incorrect results if any variable has a TO < 2. Default is `false`. ### Output - `result` -- Matrix to fill with the Hessian of the TPS, must be 1-based indexing """ -function hessian!(result, t::TPS; include_params=false) +function hessian!(result, t::TPS; include_params=false, tmp_mono::Union{Nothing,Vector{UInt8}}=nothing, unsafe_fast::Bool=false) Base.require_one_based_indexing(result) desc = unsafe_load(t.d) n = desc.nv @@ -520,16 +522,18 @@ function hessian!(result, t::TPS; include_params=false) # If all variables/variable+parameters have truncation order > 2, then # the indexing is known beforehand and we can do it faster - check = true - i = 1 - while check && i <= n - if unsafe_load(desc.no, i) < 0x2 - check = false + fast = true + if !unsafe_fast + i = 1 + while fast && i <= n + if unsafe_load(desc.no, i) < 0x2 + fast = false # use "slow" indexing + end + i += 1 end - i += 1 end - #check=false - if check + + if fast idx = desc.nv+desc.np endidx = floor(n*(n+1)/2)+nn curdiag = 1 @@ -559,7 +563,12 @@ function hessian!(result, t::TPS; include_params=false) # of the Hessian itself (this is just a getter) idx = desc.nv+desc.np # start at 2nd order v = Ref{eltype(t)}() - mono = Vector{UInt8}(undef, nn) + if isnothing(tmp_mono) + mono = Vector{UInt8}(undef, nn) + else + length(tmp_mono) == nn || error("length(tmp_mono) must be $nn, received $(length(tmp_mono))") + mono = tmp_mono + end idx = cycle!(t, idx, nn, mono, v) while idx > 0 if sum(mono) > 0x2