Skip to content

Commit

Permalink
Fix JLD2 ext
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Feb 7, 2024
2 parents 69bbf5f + b3f3772 commit 1046f25
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 18 deletions.
5 changes: 2 additions & 3 deletions ext/FluxJUDIExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
module FluxJUDIExt

isdefined(Base, :get_extension) ? (using JUDI) : (using ..JUDI)
using Flux
import JUDI: LazyPropagation, judiVector, eval_prop
isdefined(Base, :get_extension) ? (using Flux) : (using ..Flux)

Flux.Zygote.unbroadcast(x::AbstractArray, x̄::LazyPropagation) = Zygote.unbroadcast(x, eval_prop(x̄))
Flux.cpu(x::LazyPropagation) = Flux.cpu(eval_prop(x))
Flux.gpu(x::LazyPropagation) = Flux.gpu(eval_prop(x))
Flux.CUDA.cu(F::LazyPropagation) = Flux.CUDA.cu(eval_prop(F))
Expand Down
11 changes: 5 additions & 6 deletions ext/JLD2JUDIExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
module JLD2JUDIExt

isdefined(Base, :get_extension) ? (using JUDI) : (using ..JUDI)
using JLD2
using JUDI
isdefined(Base, :get_extension) ? (using JLD2) : (using ..JLD2)

JLD2.rconvert(::Type{Geometry}, x::JLD2.ReconstructedMutable{N, FN, NT}) where {N, FN, NT} = Geometry([JUDI.tof32(getproperty(x, f)) for f in FN]...)
JUDI.Geometry(x::JLD2.ReconstructedMutable{N, FN, NT}) where {N, FN, NT} = Geometry([JUDI.tof32(getproperty(x, f)) for f in FN]...)

function JLD2.rconvert(::Type{Geometry}, x::JLD2.ReconstructedMutable{N, FN, NT}) where {N, FN, NT}
args = [JUDI.tof32(getproperty(x, f)) for f in FN]
return Geometry(args...)
end

function JUDI.tof32(x::JLD2.ReconstructedStatic{N, FN, NT}) where {N, FN, NT}
# Drop "typed" signature
Expand Down
4 changes: 2 additions & 2 deletions ext/ZygoteJUDIExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module ZygoteJUDIExt

isdefined(Base, :get_extension) ? (using JUDI) : (using ..JUDI)
using Zygote
import JUDI: LazyPropagation, judiVector, eval_prop
isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote)

Zygote.unbroadcast(x::AbstractArray, x̄::LazyPropagation) = Zygote.unbroadcast(x, eval_prop(x̄))

Expand Down
9 changes: 4 additions & 5 deletions src/JUDI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,23 +187,22 @@ function __init__()

# Optional dependencies
@static if !isdefined(Base, :get_extension)

# JLD2 compat for loading older version of JUDI types
@require JLD2="033835bb-8acc-5ee8-8aae-3f567f8a3b3d" begin
@require JLD2="033835bb-8acc-5ee8-8aae-3f567f8a3819" begin
@info "JLD2 compat enabled"
include("../ext/JLD2JUDIExt.jl")
using JLD2JUDIExt
end

# Additional Zygote compat if in use
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
@info "Zygote compat enabled"
include("../ext/ZygoteJUDIExt.jl")
using ZygoteJUDIExt
end

# Additional Flux compat if in use
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
@info "Flux compat enabled"
include("../ext/FluxJUDIExt.jl")
using FluxJUDIExt
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/Types/GeometryStructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ get_nsrc(S::SeisCon) = length(S)
get_nsrc(S::Vector{SeisCon}) = length(S)
get_nsrc(S::SeisBlock) = length(unique(get_header(S, "FieldRecord")))

n_samples(g::GeometryOOC, nsrc::Integer) = sum([g.nrec[j]*get_nt(g, j) for j=1:nsrc])
n_samples(g::GeometryOOC, nsrc::Integer) = sum(g.nrec .* get_nt(g))
n_samples(g::GeometryIC, nsrc::Integer) = sum([length(g.xloc[j])*get_nt(g, j) for j=1:nsrc])
n_samples(g::Geometry) = n_samples(g, get_nsrc(g))

Expand Down
7 changes: 6 additions & 1 deletion src/TimeModeling/Utils/auxiliaryFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,8 @@ filter_none(x) = x

"""
_maybe_pad_t0(q, qGeom, data, dataGeom)
Pad zeros for data with non-zero t0, usually from a segy file so that time axis and array size match for the source and data.
"""
function _maybe_pad_t0(qIn::Matrix{T}, qGeom::Geometry, dObserved::Matrix{T}, dataGeom::Geometry) where T<:Number
if size(dObserved, 1) != size(qIn, 1)
Expand All @@ -823,4 +825,7 @@ function _maybe_pad_t0(qIn::Matrix{T}, qGeom::Geometry, dObserved::Matrix{T}, da
end
end
return qIn, dObserved
end
end

_maybe_pad_t0(qIn::judiVector{T, AT}, dObserved::judiVector{T, AT}) where{T<:Number, AT} =
_maybe_pad_t0(qIn.data, qIn.geometry, dObserved.data, dObserved.geometry)
5 changes: 5 additions & 0 deletions src/pysource/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from devito.data.allocators import ExternalAllocator
from devito.tools import as_tuple

try:
from devitopro import *
except ImportError:
pass


def wavefield(model, space_order, save=False, nt=None, fw=True, name='', t_sub=1):
"""
Expand Down
6 changes: 6 additions & 0 deletions src/pysource/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
from devito.data.allocators import ExternalAllocator
from devito.tools import as_tuple, memoized_func

try:
from devitopro import *
except ImportError:
pass


__all__ = ['Model']


Expand Down
5 changes: 5 additions & 0 deletions src/pysource/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from sensitivity import grad_expr, lin_src
from utils import opt_op

try:
from devitopro import *
except ImportError:
pass


def name(model):
if model.is_tti:
Expand Down
1 change: 1 addition & 0 deletions test/test_issues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end
# Load file with old judiVector type and julia <1.7 StepRangeLen
@load "$(datapath)backward_comp.jld" dat

@show dat.geometry
@test typeof(dat) == judiVector{Float32, Matrix{Float32}}
@test typeof(dat.geometry) == GeometryIC{Float32}
@test typeof(dat.geometry.xloc) == Vector{Vector{Float32}}
Expand Down

0 comments on commit 1046f25

Please sign in to comment.