Skip to content

Commit

Permalink
Allow FieldTimeSeries to pass keyword arguments to jldopen (#3739)
Browse files Browse the repository at this point in the history
Co-authored-by: Simone Silvestri <[email protected]>
  • Loading branch information
ali-ramadhan and simone-silvestri authored Oct 29, 2024
1 parent 9ffbee3 commit 97d7344
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 70 deletions.
24 changes: 16 additions & 8 deletions src/OutputReaders/field_dataset.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
struct FieldDataset{F, M, P}
fields :: F
metadata :: M
filepath :: P
struct FieldDataset{F, M, P, KW}
fields :: F
metadata :: M
filepath :: P
reader_kw :: KW
end

"""
Expand All @@ -22,17 +23,24 @@ linearly.
`file["metadata"]`.
- `grid`: May be specified to override the grid used in the JLD2 file.
- `reader_kw`: A dictionary of keyword arguments to pass to the reader (currently only JLD2)
to be used when opening files.
"""
function FieldDataset(filepath;
architecture=CPU(), grid=nothing, backend=InMemory(), metadata_paths=["metadata"])
architecture = CPU(),
grid = nothing,
backend = InMemory(),
metadata_paths = ["metadata"],
reader_kw = Dict{Symbol, Any}())

file = jldopen(filepath)
file = jldopen(filepath; reader_kw...)

field_names = keys(file["timeseries"])
filter!(k -> k != "t", field_names) # Time is not a field.

ds = Dict{String, FieldTimeSeries}(
name => FieldTimeSeries(filepath, name; architecture, backend, grid)
name => FieldTimeSeries(filepath, name; architecture, backend, grid, reader_kw)
for name in field_names
)

Expand All @@ -44,7 +52,7 @@ function FieldDataset(filepath;

close(file)

return FieldDataset(ds, metadata, abspath(filepath))
return FieldDataset(ds, metadata, abspath(filepath), reader_kw)
end

Base.getindex(fds::FieldDataset, inds...) = Base.getindex(fds.fields, inds...)
Expand Down
84 changes: 50 additions & 34 deletions src/OutputReaders/field_time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ period = t[end] - t[1] + Δt
"""
struct Cyclical{FT}
period :: FT
end
end

Cyclical() = Cyclical(nothing)

Expand Down Expand Up @@ -164,7 +164,7 @@ Nt = 5
backend = InMemory(4, 3) # so we have (4, 5, 1)
n = 1 # so, the right answer is m̃ = 3
m = 1 - (4 - 1) # = -2
m̃ = mod1(-2, 5) # = 3 ✓
m̃ = mod1(-2, 5) # = 3 ✓
```
# Another shifting + wrapping example
Expand Down Expand Up @@ -213,7 +213,7 @@ Base.length(backend::PartlyInMemory) = backend.length
##### FieldTimeSeries
#####

mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: AbstractField{LX, LY, LZ, G, ET, 4}
mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4}
data :: D
grid :: G
backend :: K
Expand All @@ -223,16 +223,18 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
path :: P
name :: N
time_indexing :: TI

reader_kw :: KW

function FieldTimeSeries{LX, LY, LZ}(data::D,
grid::G,
backend::K,
bcs::B,
indices::I,
indices::I,
times,
path,
name,
time_indexing) where {LX, LY, LZ, K, D, G, B, I}
time_indexing,
reader_kw) where {LX, LY, LZ, K, D, G, B, I}

ET = eltype(data)

Expand All @@ -250,7 +252,7 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A

times = on_architecture(architecture(grid), times)
end

if time_indexing isa Cyclical{Nothing} # we have to infer the period
Δt = @allowscalar times[end] - times[end-1]
period = @allowscalar times[end] - times[1] + Δt
Expand All @@ -261,23 +263,25 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
TI = typeof(time_indexing)
P = typeof(path)
N = typeof(name)
KW = typeof(reader_kw)

return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N}(data, grid, backend, bcs,
indices, times, path, name,
time_indexing)
return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW}(data, grid, backend, bcs,
indices, times, path, name,
time_indexing, reader_kw)
end
end

on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} =
on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} =
FieldTimeSeries{LX, LY, LZ}(on_architecture(to, fts.data),
on_architecture(to, fts.grid),
on_architecture(to, fts.backend),
on_architecture(to, fts.bcs),
on_architecture(to, fts.indices),
on_architecture(to, fts.indices),
on_architecture(to, fts.times),
on_architecture(to, fts.path),
on_architecture(to, fts.name),
on_architecture(to, fts.time_indexing))
on_architecture(to, fts.time_indexing),
on_architecture(to, fts.reader_kw))

#####
##### Minimal implementation of FieldTimeSeries for use in GPU kernels
Expand All @@ -290,7 +294,7 @@ struct GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K, ET, D, χ} <: AbstractField{
times :: χ
backend :: K
time_indexing :: TI

function GPUAdaptedFieldTimeSeries{LX, LY, LZ}(data::D,
times:,
backend::K,
Expand All @@ -313,7 +317,7 @@ const FTS{LX, LY, LZ, TI, K} = FieldTimeSeries{LX, LY, LZ, TI, K} w
const GPUFTS{LX, LY, LZ, TI, K} = GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K} where {LX, LY, LZ, TI, K}

const FlavorOfFTS{LX, LY, LZ, TI, K} = Union{GPUFTS{LX, LY, LZ, TI, K},
FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K}
FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K}

const InMemoryFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:AbstractInMemoryBackend}
const OnDiskFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:OnDisk}
Expand Down Expand Up @@ -345,7 +349,7 @@ instantiate(T::Type) = T()
new_data(FT, grid, loc, indices, ::Nothing) = nothing

# Apparently, not explicitly specifying Int64 in here makes this function
# fail on x86 processors where `Int` is implied to be `Int32`
# fail on x86 processors where `Int` is implied to be `Int32`
# see ClimaOcean commit 3c47d887659d81e0caed6c9df41b7438e1f1cd52 at https://github.com/CliMA/ClimaOcean.jl/actions/runs/8804916198/job/24166354095)
function new_data(FT, grid, loc, indices, Nt::Union{Int, Int64})
space_size = total_size(grid, loc, indices)
Expand All @@ -360,12 +364,13 @@ time_indices_length(backend::PartlyInMemory, times) = length(backend)
time_indices_length(::OnDisk, times) = nothing

function FieldTimeSeries(loc, grid, times=();
indices = (:, :, :),
indices = (:, :, :),
backend = InMemory(),
path = nothing,
path = nothing,
name = nothing,
time_indexing = Linear(),
boundary_conditions = nothing)
boundary_conditions = nothing,
reader_kw = Dict{Symbol, Any}())

LX, LY, LZ = loc

Expand All @@ -376,9 +381,9 @@ function FieldTimeSeries(loc, grid, times=();
isnothing(path) && error(ArgumentError("Must provide the keyword argument `path` when `backend=OnDisk()`."))
isnothing(name) && error(ArgumentError("Must provide the keyword argument `name` when `backend=OnDisk()`."))
end
return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions,
indices, times, path, name, time_indexing)

return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices,
times, path, name, time_indexing, reader_kw)
end

"""
Expand All @@ -405,10 +410,16 @@ end
struct UnspecifiedBoundaryConditions end

"""
FieldTimeSeries(path, name, backend = InMemory();
FieldTimeSeries(path, name;
backend = InMemory(),
architecture = nothing,
grid = nothing,
location = nothing,
boundary_conditions = UnspecifiedBoundaryConditions(),
time_indexing = Linear(),
iterations = nothing,
times = nothing)
times = nothing,
reader_kw = Dict{Symbol, Any}())
Return a `FieldTimeSeries` containing a time-series of the field `name`
load from JLD2 output located at `path`.
Expand All @@ -427,6 +438,9 @@ Keyword arguments
- `times`: Save times to load, as determined through an approximate floating point
comparison to recorded save times. Defaults to times associated with `iterations`.
Takes precedence over `iterations` if `times` is specified.
- `reader_kw`: A dictionary of keyword arguments to pass to the reader (currently only JLD2)
to be used when opening files.
"""
function FieldTimeSeries(path::String, name::String;
backend = InMemory(),
Expand All @@ -436,9 +450,10 @@ function FieldTimeSeries(path::String, name::String;
boundary_conditions = UnspecifiedBoundaryConditions(),
time_indexing = Linear(),
iterations = nothing,
times = nothing)
times = nothing,
reader_kw = Dict{Symbol, Any}())

file = jldopen(path)
file = jldopen(path; reader_kw...)

# Defaults
isnothing(iterations) && (iterations = parse.(Int, keys(file["timeseries/t"])))
Expand Down Expand Up @@ -520,8 +535,8 @@ function FieldTimeSeries(path::String, name::String;
Nt = time_indices_length(backend, times)
data = new_data(eltype(grid), grid, loc, indices, Nt)

time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions,
indices, times, path, name, time_indexing)
time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices,
times, path, name, time_indexing, reader_kw)

set!(time_series, path, name)

Expand All @@ -533,7 +548,8 @@ end
grid = nothing,
architecture = nothing,
indices = (:, :, :),
boundary_conditions = nothing)
boundary_conditions = nothing,
reader_kw = Dict{Symbol, Any}())
Load a field called `name` saved in a JLD2 file at `path` at `iter`ation.
Unless specified, the `grid` is loaded from `path`.
Expand All @@ -542,7 +558,8 @@ function Field(location, path::String, name::String, iter;
grid = nothing,
architecture = nothing,
indices = (:, :, :),
boundary_conditions = nothing)
boundary_conditions = nothing,
reader_kw = Dict{Symbol, Any}())

# Default to CPU if neither architecture nor grid is specified
if isnothing(architecture)
Expand All @@ -552,9 +569,9 @@ function Field(location, path::String, name::String, iter;
architecture = Architectures.architecture(grid)
end
end

# Load the grid and data from file
file = jldopen(path)
file = jldopen(path; reader_kw...)

isnothing(grid) && (grid = file["serialized/grid"])
raw_data = file["timeseries/$name/$iter"]
Expand All @@ -565,7 +582,7 @@ function Field(location, path::String, name::String, iter;
grid = on_architecture(architecture, grid)
raw_data = on_architecture(architecture, raw_data)
data = offset_data(raw_data, grid, location, indices)

return Field(location, grid; boundary_conditions, indices, data)
end

Expand Down Expand Up @@ -625,4 +642,3 @@ function fill_halo_regions!(fts::InMemoryFTS)

return nothing
end

23 changes: 11 additions & 12 deletions src/OutputReaders/field_time_series_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import Oceananigans.Fields: interpolate
# Cyclical implementation if out-of-bounds (wrap around the time-series)
@inline function interpolating_time_indices(ti::Cyclical, times, t)
Nt = length(times)
= first(times)
= first(times)
tᴺ = last(times)

T = ti.period
Expand All @@ -32,14 +32,14 @@ import Oceananigans.Fields: interpolate
uncycled_indices = (ñ, n₁, n₂)

return ifelse(cycling, cycled_indices, uncycled_indices)
end
end

# Clamp mode if out-of-bounds, i.e get the neareast neighbor
@inline function interpolating_time_indices(::Clamp, times, t)
n, n₁, n₂ = time_index_binary_search(times, t)

beyond_indices = (0, n₂, n₂) # Beyond the last time: return n₂
before_indices = (0, n₁, n₁) # Before the first time: return n₁
before_indices = (0, n₁, n₁) # Before the first time: return n₁
unclamped_indices = (n, n₁, n₂) # Business as usual

Nt = length(times)
Expand All @@ -53,13 +53,13 @@ end
@inline function time_index_binary_search(times, t)
Nt = length(times)

# n₁ and n₂ are the index to interpolate inbetween and
# n₁ and n₂ are the index to interpolate inbetween and
# n is a fractional index where 0 ≤ n ≤ 1
n₁, n₂ = index_binary_search(times, t, Nt)

@inbounds begin
t₁ = times[n₁]
t₂ = times[n₂]
t₁ = times[n₁]
t₂ = times[n₂]
end

# "Fractional index" ñ ∈ (0, 1)
Expand All @@ -79,7 +79,7 @@ import Base: getindex
function getindex(fts::OnDiskFTS, n::Int)
# Load data
arch = architecture(fts)
file = jldopen(fts.path)
file = jldopen(fts.path; fts.reader_kw...)
iter = keys(file["timeseries/t"])[n]
raw_data = on_architecture(arch, file["timeseries/$(fts.name)/$iter"])
close(file)
Expand Down Expand Up @@ -117,7 +117,7 @@ const YZFTS = FlavorOfFTS{Nothing, <:Any, <:Any, <:Any, <:Any}

@inline function interpolating_getindex(fts, i, j, k, time_index)
ñ, n₁, n₂ = interpolating_time_indices(fts.time_indexing, fts.times, time_index.time)

@inbounds begin
ψ₁ = getindex(fts, i, j, k, n₁)
ψ₂ = getindex(fts, i, j, k, n₂)
Expand Down Expand Up @@ -229,14 +229,14 @@ end
##### FieldTimeSeries updating
#####

# Let's make sure `times` is available on the CPU. This is always the case
# for ranges. if `times` is a vector that resides on the GPU, it has to be moved to the CPU for safe indexing.
# Let's make sure `times` is available on the CPU. This is always the case
# for ranges. if `times` is a vector that resides on the GPU, it has to be moved to the CPU for safe indexing.
# TODO: Copying the whole array is a bit unclean, maybe find a way that avoids the penalty of allocating and copying memory.
# This would require refactoring `FieldTimeSeries` to include a cpu-allocated times array
cpu_interpolating_time_indices(::CPU, times, time_indexing, t, arch) = interpolating_time_indices(time_indexing, times, t)
cpu_interpolating_time_indices(::CPU, times::AbstractVector, time_indexing, t) = interpolating_time_indices(time_indexing, times, t)

function cpu_interpolating_time_indices(::GPU, times::AbstractVector, time_indexing, t)
function cpu_interpolating_time_indices(::GPU, times::AbstractVector, time_indexing, t)
cpu_times = on_architecture(CPU(), times)
return interpolating_time_indices(time_indexing, cpu_times, t)
end
Expand Down Expand Up @@ -279,4 +279,3 @@ function getindex(fts::InMemoryFTS, n::Int)

return Field(location(fts), fts.grid; data, fts.boundary_conditions, fts.indices)
end

Loading

0 comments on commit 97d7344

Please sign in to comment.