Skip to content

Commit

Permalink
low level flags for hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsignorelli committed Aug 9, 2024
1 parent e6b4993 commit c1696e3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GTPSA"
uuid = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
authors = ["Matt Signorelli"]
version = "1.1.0"
version = "1.1.1"

[deps]
GTPSA_jll = "a4739e29-4b97-5c0b-bbcf-46f08034c990"
Expand Down
31 changes: 20 additions & 11 deletions src/getset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ end


"""
GTPSA.hessian!(result, t::TPS; include_params=false)
GTPSA.hessian!(result, t::TPS; include_params=false, tmp_mono::Union{Nothing,Vector{UInt8}}=nothing, unsafe_fast::Bool=false)
Extracts the second-order partial derivatives (evaluated at 0) from the TPS
and fills the `result` matrix in-place. The partial derivatives wrt the parameters will
Expand All @@ -503,11 +503,13 @@ in the TPS.
### Input
- `t` -- `TPS`/`ComplexTPS64` to extract the Hessian from
- `include_params` -- (Optional) Extract partial derivatives wrt parameters. Default is false
- `tmp_mono` -- (Optional) `Vector{UInt8}` to store the monomial, when different orders of truncation are used
- `unsafe_fast` -- (Optional) Flag to specify that "fast" indexing should be used without checking. This will give incorrect results if any variable has a TO < 2. Default is `false`.
### Output
- `result` -- Matrix to fill with the Hessian of the TPS, must be 1-based indexing
"""
function hessian!(result, t::TPS; include_params=false)
function hessian!(result, t::TPS; include_params=false, tmp_mono::Union{Nothing,Vector{UInt8}}=nothing, unsafe_fast::Bool=false)
Base.require_one_based_indexing(result)
desc = unsafe_load(t.d)
n = desc.nv
Expand All @@ -520,16 +522,18 @@ function hessian!(result, t::TPS; include_params=false)

# If all variables/variable+parameters have truncation order > 2, then
# the indexing is known beforehand and we can do it faster
check = true
i = 1
while check && i <= n
if unsafe_load(desc.no, i) < 0x2
check = false
fast = true
if !unsafe_fast
i = 1
while fast && i <= n
if unsafe_load(desc.no, i) < 0x2
fast = false # use "slow" indexing
end
i += 1
end
i += 1
end
#check=false
if check

if fast
idx = desc.nv+desc.np
endidx = floor(n*(n+1)/2)+nn
curdiag = 1
Expand Down Expand Up @@ -559,7 +563,12 @@ function hessian!(result, t::TPS; include_params=false)
# of the Hessian itself (this is just a getter)
idx = desc.nv+desc.np # start at 2nd order
v = Ref{eltype(t)}()
mono = Vector{UInt8}(undef, nn)
if isnothing(tmp_mono)
mono = Vector{UInt8}(undef, nn)
else
length(tmp_mono) == nn || error("length(tmp_mono) must be $nn, received $(length(tmp_mono))")
mono = tmp_mono
end
idx = cycle!(t, idx, nn, mono, v)
while idx > 0
if sum(mono) > 0x2
Expand Down

0 comments on commit c1696e3

Please sign in to comment.