From b7ea835c7d4cae8879d16665bebaf9797c17f83d Mon Sep 17 00:00:00 2001
From: mloubout <mathias.louboutin@gmail.com>
Date: Thu, 29 Aug 2024 10:35:18 -0500
Subject: [PATCH 1/5] Fix loss to use local data

---
 src/JUDI.jl                                | 83 ++--------------------
 src/TimeModeling/Types/OptionsStructure.jl | 58 ++++++---------
 src/pysource/sensitivity.py                | 24 ++++---
 src/pysource/sources.py                    |  5 +-
 4 files changed, 45 insertions(+), 125 deletions(-)

diff --git a/src/JUDI.jl b/src/JUDI.jl
index e70b72cf8..516f04c8c 100644
--- a/src/JUDI.jl
+++ b/src/JUDI.jl
@@ -57,99 +57,25 @@ import PyCall.NpyArray
 import ChainRulesCore: rrule
 
 # Set python paths
-export devito, set_devito_config
+export devito
+
 const pm = PyNULL()
 const ac = PyNULL()
 const pyut = PyNULL()
 const devito = PyNULL()
 
-set_devito_config(key::String, val::String) = set!(devito."configuration", key, val)
-set_devito_config(key::String, val::Bool) = set!(devito."configuration", key, val)
-
-# Create a lock for pycall FOR THREAD/TASK SAFETY
-# See discussion at
-# https://github.com/JuliaPy/PyCall.jl/issues/882
-
-const PYLOCK = Ref{ReentrantLock}()
-
-# acquire the lock before any code calls Python
-pylock(f::Function) = Base.lock(PYLOCK[]) do
-    prev_gc = GC.enable(false)
-    try 
-        return f()
-    finally
-        GC.enable(prev_gc) # recover previous state
-    end
-end
-
-function rlock_pycall(meth, ::Type{T}, args...; kw...) where T
-    out::T = pylock() do
-        pycall(meth, T, args...; kw...)
-    end
-    return out
-end
-
 # Constants
-_serial = false
-get_serial() = _serial
-set_serial(x::Bool) = begin global _serial = x; end
-set_serial() = begin global _serial = true; end
-set_parallel() = begin global _serial = false; end
-
-function _worker_pool()
-    if _serial
-        return nothing
-    end
-    p = default_worker_pool()
-    pool = nworkers(p) < 2 ? nothing : p
-    return pool
-end
-
 nworkers(::Any) = length(workers())
 
 _TFuture = Future
 _verbose = false
 _devices = []
 
-# Utility for data loading
-JUDI_DATA = joinpath(JUDIPATH, "../data")
-ftp_data(ftp::String, name::String) = Base.Downloads().download("$(ftp)/$(name)", "$(JUDI.JUDI_DATA)/$(name)")
-ftp_data(ftp::String) = Base.Downloads().download(ftp, "$(JUDI.JUDI_DATA)/$(split(ftp, "/")[end])")
-
 # Some usefull types
 const RangeOrVec = Union{AbstractRange, Vector}
 
-set_verbosity(x::Bool) = begin global _verbose = x; end
-judilog(msg) = _verbose ? printstyled("JUDI: $(msg) \n", color=:magenta) : nothing
-
-function human_readable_time(t::Float64, decimals=2)
-    units = ["ns", "μs", "ms", "s", "min", "hour"]
-    scales = [1e-9, 1e-6, 1e-3, 1, 60, 3600]
-    if t < 1e-9
-        tr = round(t/1e-9; sigdigits=decimals)
-        return "$(tr) ns"
-    end
-
-    for i=2:6
-        if t < scales[i]
-            tr = round(t/scales[i-1]; sigdigits=decimals)
-            return "$(tr) $(units[i-1])"
-        end
-    end
-    tr1 = div(t, 3600)
-    tr2 = round(Int, rem(t, 3600))
-    return "$(tr1) h $(tr2) min"
-end 
-
-
-macro juditime(msg, ex)
-    return quote
-       local t
-       t = @elapsed $(esc(ex))
-       tr = human_readable_time(t)
-       judilog($(esc(msg))*": $(tr)")
-    end
-end
+# Utils
+include("utilities.jl")
 
 # JUDI time modeling
 include("TimeModeling/TimeModeling.jl")
@@ -170,6 +96,7 @@ include("compat.jl")
 # Automatic Differentiation
 include("rrules.jl")
 
+
 # Initialize
 function __init__()
     pushfirst!(PyVector(pyimport("sys")."path"), joinpath(JUDIPATH, "pysource"))
diff --git a/src/TimeModeling/Types/OptionsStructure.jl b/src/TimeModeling/Types/OptionsStructure.jl
index f7af882a6..24f3eb79e 100644
--- a/src/TimeModeling/Types/OptionsStructure.jl
+++ b/src/TimeModeling/Types/OptionsStructure.jl
@@ -99,41 +99,29 @@ All arguments are optional keyword arguments with the following default values:
             dt_comp=nothing, f0=0.015f0)
 
 """
-Options(;space_order=8,
-		 free_surface=false,
-         limit_m=false,
-		 buffer_size=1e3,
-		 save_data_to_disk=false,
-		 file_path="",
-		 file_name="shot",
-         sum_padding=false,
-		 optimal_checkpointing=false,
-		 num_checkpoints=nothing,
-		 checkpoints_maxmem=nothing,
-		 frequencies=[],
-		 isic=false,
-		 subsampling_factor=1,
-		 dft_subsampling_factor=1,
-         return_array=false,
-         dt_comp=nothing,
-         f0=0.015f0,
-         IC="as") =
-		 JUDIOptions(space_order,
-		 		 free_surface,
-		         limit_m,
-				 buffer_size,
-				 save_data_to_disk,
-				 file_path,
-				 file_name,
-				 sum_padding,
-				 optimal_checkpointing,
-				 frequencies,
-				 imcond(isic, IC),
-				 subsampling_factor,
-				 dft_subsampling_factor,
-                 return_array,
-                 dt_comp,
-                 f0)
+function Options(;space_order=8, free_surface=false,
+                  limit_m=false, buffer_size=1e3,
+		          save_data_to_disk=false, file_path="", file_name="shot",
+                  sum_padding=false,
+		          optimal_checkpointing=false,
+		          frequencies=[],
+		          subsampling_factor=1,
+		          dft_subsampling_factor=1,
+                  return_array=false,
+                  dt_comp=nothing,
+                  f0=0.015f0,
+                  IC="as",
+                  kw...)
+    
+    ic = imcond(get(kw, :isic, false), IC)
+    if optimal_checkpointing && get(ENV, "DEVITO_DECOUPLER", 0) != 0
+        @warn "Optimal checkpointing is not supported with the Decoupler, disabling"
+        optimal_checkpointing = false
+    end
+    return JUDIOptions(space_order, free_surface, limit_m, buffer_size, save_data_to_disk,
+                file_path, file_name, sum_padding, optimal_checkpointing, frequencies,
+                ic, subsampling_factor, dft_subsampling_factor, return_array, dt_comp, f0)
+end
 
 JUDIOptions(;kw...) = Options(kw...)
 
diff --git a/src/pysource/sensitivity.py b/src/pysource/sensitivity.py
index 445bcc212..a09da78ca 100644
--- a/src/pysource/sensitivity.py
+++ b/src/pysource/sensitivity.py
@@ -257,21 +257,23 @@ def Loss(dsyn, dobs, dt, is_residual=False, misfit=None):
     """
     if misfit is not None:
         if isinstance(dsyn, tuple):
-            f, r = misfit(dsyn[0].data, dobs[:] - dsyn[1].data[:])
-            dsyn[0].data[:] = r[:]
-            return dt * f, dsyn[0].data
+            f, r = misfit(dsyn[0].data._local, dobs[:] - dsyn[1].data._local[:])
+            dsyn[0].data._local[:] = r[:]
+            return dt * f, dsyn[0].data._local
         else:
-            f, r = misfit(dsyn.data, dobs)
-            dsyn.data[:] = r[:]
-            return dt * f, dsyn.data
+            f, r = misfit(dsyn.data._local, dobs)
+            dsyn.data._local[:] = r[:]
+            return dt * f, dsyn.data._local
 
     if not is_residual:
         if isinstance(dsyn, tuple):
-            dsyn[0].data[:] -= dobs[:] - dsyn[1].data[:]  # input is observed data
-            return .5 * dt * np.linalg.norm(dsyn[0].data)**2, dsyn[0].data
+            # input is observed data
+            dsyn[0].data._local[:] -= dobs[:] - dsyn[1].data._local[:]
+            phi = .5 * dt * np.linalg.norm(dsyn[0].data._local)**2
+            return phi, dsyn[0].data._local
         else:
-            dsyn.data[:] -= dobs[:]   # input is observed data
+            dsyn.data._local[:] -= dobs[:]   # input is observed data
     else:
-        dsyn.data[:] = dobs[:]
+        dsyn.data._local[:] = dobs[:]
 
-    return .5 * dt * np.linalg.norm(dsyn.data)**2, dsyn.data
+    return .5 * dt * np.linalg.norm(dsyn.data._local)**2, dsyn.data._local
diff --git a/src/pysource/sources.py b/src/pysource/sources.py
index cf7bb72ff..e0a67fbd6 100644
--- a/src/pysource/sources.py
+++ b/src/pysource/sources.py
@@ -93,6 +93,8 @@ class PointSource(SparseTimeFunction):
     initialised `data` array need to be provided.
     """
 
+    __rkwargs__ = list(SparseTimeFunction.__rkwargs__)
+
     @classmethod
     def __args_setup__(cls, *args, **kwargs):
         if 'nt' not in kwargs:
@@ -102,10 +104,11 @@ def __args_setup__(cls, *args, **kwargs):
                 kwargs['nt'] = kwargs.get('time').shape[0]
 
         # Either `npoint` or `coordinates` must be provided
-        npoint = kwargs.get('npoint')
+        npoint = kwargs.get('npoint', kwargs.get('npoint_global'))
         if npoint is None:
             coordinates = kwargs.get('coordinates', kwargs.get('coordinates_data'))
             if coordinates is None:
+                print(kwargs)
                 raise TypeError("Need either `npoint` or `coordinates`")
             kwargs['npoint'] = coordinates.shape[0]
 

From d0e1ad994983fbdc27fa8d46d60b81556ed6339e Mon Sep 17 00:00:00 2001
From: mloubout <mathias.louboutin@gmail.com>
Date: Thu, 29 Aug 2024 10:41:54 -0500
Subject: [PATCH 2/5] add some utilities

---
 src/utilities.jl | 101 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 src/utilities.jl

diff --git a/src/utilities.jl b/src/utilities.jl
new file mode 100644
index 000000000..d6714181a
--- /dev/null
+++ b/src/utilities.jl
@@ -0,0 +1,101 @@
+export set_devito_config, ftp_data, set_serial, set_parallel, set_verbosity
+export devito_omp, devito_icx, devito_acc, devito_nvc_host, devito_cuda, devito_sycl, devito_hip
+
+# Logging utilities
+set_verbosity(x::Bool) = begin global _verbose = x; end
+judilog(msg) = _verbose ? printstyled("JUDI: $(msg) \n", color=:magenta) : nothing
+
+function human_readable_time(t::Float64, decimals=2)
+    units = ["ns", "μs", "ms", "s", "min", "hour"]
+    scales = [1e-9, 1e-6, 1e-3, 1, 60, 3600]
+    if t < 1e-9
+        tr = round(t/1e-9; sigdigits=decimals)
+        return "$(tr) ns"
+    end
+
+    for i=2:6
+        if t < scales[i]
+            tr = round(t/scales[i-1]; sigdigits=decimals)
+            return "$(tr) $(units[i-1])"
+        end
+    end
+    tr1 = div(t, 3600)
+    tr2 = round(Int, rem(t, 3600))
+    return "$(tr1) h $(tr2) min"
+end 
+
+
+macro juditime(msg, ex)
+    return quote
+       local t
+       t = @elapsed $(esc(ex))
+       tr = human_readable_time(t)
+       judilog($(esc(msg))*": $(tr)")
+    end
+end
+
+
+# Utility for data loading
+JUDI_DATA = joinpath(JUDIPATH, "../data")
+ftp_data(ftp::String, name::String) = Base.Downloads().download("$(ftp)/$(name)", "$(JUDI.JUDI_DATA)/$(name)")
+ftp_data(ftp::String) = Base.Downloads().download(ftp, "$(JUDI.JUDI_DATA)/$(split(ftp, "/")[end])")
+
+
+# Parallelism
+_serial = false
+get_serial() = _serial
+set_serial(x::Bool) = begin global _serial = x; end
+set_serial() = begin global _serial = true; end
+set_parallel() = begin global _serial = false; end
+
+function _worker_pool()
+    if _serial
+        return nothing
+    end
+    p = default_worker_pool()
+    pool = nworkers(p) < 2 ? nothing : p
+    return pool
+end
+
+
+# Create a lock for pycall FOR THREAD/TASK SAFETY
+# See discussion at
+# https://github.com/JuliaPy/PyCall.jl/issues/882
+
+const PYLOCK = Ref{ReentrantLock}()
+
+# acquire the lock before any code calls Python
+pylock(f::Function) = Base.lock(PYLOCK[]) do
+    prev_gc = GC.enable(false)
+    try 
+        return f()
+    finally
+        GC.enable(prev_gc) # recover previous state
+    end
+end
+
+function rlock_pycall(meth, ::Type{T}, args...; kw...) where T
+    out::T = pylock() do
+        pycall(meth, T, args...; kw...)
+    end
+    return out
+end
+
+# Devito configuration
+set_devito_config(key::String, val::String) = set!(devito."configuration", key, val)
+set_devito_config(key::String, val::Bool) = set!(devito."configuration", key, val)
+
+set_devito_config(kw...) = begin
+    for (k, v) in kw
+        set_devito_config(k, v)
+    end
+end
+
+# Easy configurations setupes
+devito_omp() = set_devito_config("language", "openmp")
+devito_icx() = set_devito_config(language="openmp", compiler="icx")
+devito_acc() = set_devito_config(language="openacc", compiler="nvc", platform="nvidiaX")
+devito_nvc_host() = set_devito_config(language="openmp", compiler="nvc")
+devito_cuda() = set_devito_config(language="cuda", platform="nvidiaX")
+devito_sycl() = set_devito_config(language="sycl", platform="intelgpuX")
+devito_hip() = set_devito_config(language="hip", platform="amdgpuX")

From 9c29fcf50fd05a289e9927ecf1e61373cf2b19ec Mon Sep 17 00:00:00 2001
From: mloubout <mathias.louboutin@gmail.com>
Date: Thu, 19 Sep 2024 11:00:29 -0400
Subject: [PATCH 3/5] support float16 model params

---
 src/JUDI.jl                                   |   2 +-
 src/TimeModeling/LinearOperators/operators.jl |   4 +-
 src/TimeModeling/Modeling/propagation.jl      |   6 +-
 .../Preconditioners/DataPreconditioners.jl    |   6 +-
 src/TimeModeling/TimeModeling.jl              |   2 +
 src/TimeModeling/Types/ModelStructure.jl      |  57 +++---
 src/TimeModeling/Types/OptionsStructure.jl    |   9 +-
 src/pysource/fields_exprs.py                  |   8 +-
 src/pysource/geom_utils.py                    |  16 +-
 src/pysource/kernels.py                       |  40 ++--
 src/pysource/models.py                        | 179 +++++++++---------
 11 files changed, 189 insertions(+), 140 deletions(-)

diff --git a/src/JUDI.jl b/src/JUDI.jl
index 516f04c8c..2a8c2ba58 100644
--- a/src/JUDI.jl
+++ b/src/JUDI.jl
@@ -15,7 +15,7 @@ if !isdefined(Base, :get_extension)
 end
 
 # Dependencies
-using LinearAlgebra, Random
+using LinearAlgebra, Random, Printf
 using Distributed
 using DSP, FFTW, Dierckx
 using PyCall
diff --git a/src/TimeModeling/LinearOperators/operators.jl b/src/TimeModeling/LinearOperators/operators.jl
index 715ce6fc7..3f0130737 100644
--- a/src/TimeModeling/LinearOperators/operators.jl
+++ b/src/TimeModeling/LinearOperators/operators.jl
@@ -174,7 +174,7 @@ adjoint(L::LazyScal) = LazyScal(L.s, adjoint(L.P))
 *(F::judiPropagator{T, O}, q::judiMultiSourceVector{T}) where {T<:Number, O} = multi_src_propagate(F, q)
 *(F::judiPropagator{T, O}, q::AbstractVector{T}) where {T<:Number, O} = multi_src_propagate(F, q)
 *(F::judiPropagator{T, O}, q::DenseArray{T}) where {T<:Number, O} = multi_src_propagate(F, q)
-*(F::judiAbstractJacobian{T, O, FT}, q::dmType{T}) where {T<:Number, O, FT} = multi_src_propagate(F, q)
+*(F::judiAbstractJacobian{T, O, FT}, q::dmType{Tq}) where {T<:Number, Tq<:Pdtypes, O, FT} = multi_src_propagate(F, q)
 
 mul!(out::SourceType{T}, F::judiPropagator{T, O}, q::SourceType{T}) where {T<:Number, O} = begin y = F*q; copyto!(out, y) end
 mul!(out::SourceType{T}, F::judiAbstractJacobian{T, :born, FT}, q::Vector{T}) where {T<:Number, FT} = begin y = F*q[:]; copyto!(out, y) end
@@ -208,7 +208,7 @@ make_input(F::judiDataModeling, rhs::judiRHS) = (make_src(rhs)..., F.rInterpolat
 make_input(F::judiDataSourceModeling, q::SourceType{T}) where {T} = (make_src(q, F.qInjection)..., F.rInterpolation.data[1], nothing, nothing)
 make_input(F::judiDataSourceModeling, q::Matrix{T}) where {T} = (F.qInjection.data[1], q, F.rInterpolation.data[1], nothing, nothing)
 
-function make_input(J::judiJacobian{D, :born, FT}, q::dmType{D}) where {D<:Number, FT}
+function make_input(J::judiJacobian{D, :born, FT}, q::dmType{Dq}) where {D<:Number, Dq<:Pdtypes, FT}
     srcGeom, srcData = make_src(J.q, J.F.qInjection)
     return srcGeom, srcData, J.F.rInterpolation.data[1], nothing, reshape(q, size(J.model))
 end 
diff --git a/src/TimeModeling/Modeling/propagation.jl b/src/TimeModeling/Modeling/propagation.jl
index 90ed4b2fb..1ab812f6b 100644
--- a/src/TimeModeling/Modeling/propagation.jl
+++ b/src/TimeModeling/Modeling/propagation.jl
@@ -4,7 +4,7 @@
 Base propagation interfaces that calls the devito `mode` propagator (forward/adjoint/..)
 with `q` as a source. The return type is infered from `F`.
 """
-function propagate(F::judiPropagator{T, O}, q::AbstractArray{T}, illum::Bool) where {T, O}
+function propagate(F::judiPropagator{T, O}, q::AbstractArray{Tq}, illum::Bool) where {T, Tq, O}
     srcGeometry, srcData, recGeometry, recData, dm = make_input(F, q)
     return time_modeling(F.model, srcGeometry, srcData, recGeometry, recData, dm, O, F.options, _prop_fw(F), illum)
 end
@@ -54,7 +54,7 @@ _prop_fw(::judiPropagator{T, :adjoint}) where T = false
 _prop_fw(J::judiJacobian) = _prop_fw(J.F)
 
 
-src_i(::judiAbstractJacobian{T, :born, FT}, q::dmType{T}, ::Integer) where {T<:Number, FT} = q
+src_i(::judiAbstractJacobian{T, :born, FT}, q::dmType{Tq}, ::Integer) where {T<:Number, Tq<:Pdtypes, FT} = q
 src_i(::judiPropagator{T, O}, q::judiMultiSourceVector{T}, i::Integer) where {T, O} = q[i]
 src_i(::judiPropagator{T, O}, q::Vector{<:Array{T}}, i::Integer) where {T, O} = q[i]
 
@@ -68,7 +68,7 @@ get_nsrc(J::judiAbstractJacobian, ::dmType{T}) where T<:Number = J.q.nsrc
 Propagates the source `q` with the `F` propagator. The return type is infered from `F` and the 
 propagation kernel is defined by `O` (forward, adjoint, born or adjoint_born).
 """
-function multi_src_propagate(F::judiPropagator{T, O}, q::AbstractArray{T}) where {T<:Number, O}
+function multi_src_propagate(F::judiPropagator{T, O}, q::AbstractArray{Tq}) where {T<:Number, Tq<:Pdtypes, O}
     q = process_input_data(F, q)
     # Number of sources and init result
     nsrc = get_nsrc(F, q)
diff --git a/src/TimeModeling/Preconditioners/DataPreconditioners.jl b/src/TimeModeling/Preconditioners/DataPreconditioners.jl
index a23eee0a9..59346a280 100644
--- a/src/TimeModeling/Preconditioners/DataPreconditioners.jl
+++ b/src/TimeModeling/Preconditioners/DataPreconditioners.jl
@@ -198,9 +198,9 @@ transpose(I::FrequencyFilter{T}) where T = I
 
 function tracefilt!(x, y, ypad, filter)
     n = length(y)
-    ypad[1:n] .= y
-    ypad[n:end] .= view(y, n:-1:1)
-    x .= filtfilt(filter, ypad)[1:n]
+    ypad[n:end] .= y
+    ypad[1:n] .= view(y, n:-1:1)
+    x .= filtfilt(filter, ypad)[n:end]
     nothing
 end
 
diff --git a/src/TimeModeling/TimeModeling.jl b/src/TimeModeling/TimeModeling.jl
index 38686c2be..cd72482b1 100644
--- a/src/TimeModeling/TimeModeling.jl
+++ b/src/TimeModeling/TimeModeling.jl
@@ -9,6 +9,8 @@ include("LinearOperators/basics.jl")
 
 #############################################################################
 # Containers
+const Pdtypes = Union{Float32, Float16}
+
 include("Types/ModelStructure.jl")    # model container
 include("Types/GeometryStructure.jl") # source or receiver setup, recording time and sampling
 include("Types/OptionsStructure.jl")
diff --git a/src/TimeModeling/Types/ModelStructure.jl b/src/TimeModeling/Types/ModelStructure.jl
index 66819c29b..cedf675b7 100644
--- a/src/TimeModeling/Types/ModelStructure.jl
+++ b/src/TimeModeling/Types/ModelStructure.jl
@@ -274,41 +274,41 @@ end
 NpyArray(p::PhysicalParameter{T, N}, revdims::Bool) where {T<:Real, N} = NpyArray(p.data, revdims)
 
 ###################################################################################################
-const ModelParam{T, N} = Union{T, PhysicalParameter{T, N}}
+const ModelParam{N} = Union{<:Pdtypes, PhysicalParameter{<:Pdtypes, N}}
 
 # Acoustic
 struct IsoModel{T, N} <: AbstractModel{T, N}
     G::DiscreteGrid{T, N}
-    m::ModelParam{T, N}
-    rho::ModelParam{T, N}
+    m::ModelParam{N}
+    rho::ModelParam{N}
 end
 
 # VTI/TTI
 struct TTIModel{T, N} <: AbstractModel{T, N}
     G::DiscreteGrid{T, N}
-    m::ModelParam{T, N}
-    rho::ModelParam{T, N}
-    epsilon::ModelParam{T, N}
-    delta::ModelParam{T, N}
-    theta::ModelParam{T, N}
-    phi::ModelParam{T, N}
+    m::ModelParam{N}
+    rho::ModelParam{N}
+    epsilon::ModelParam{N}
+    delta::ModelParam{N}
+    theta::ModelParam{N}
+    phi::ModelParam{N}
 end
 
 # Elastic
 
 struct IsoElModel{T, N} <: AbstractModel{T, N}
     G::DiscreteGrid{T, N}
-    lam::ModelParam{T, N}
-    mu::ModelParam{T, N}
-    b::ModelParam{T, N}
+    lam::ModelParam{N}
+    mu::ModelParam{N}
+    b::ModelParam{N}
 end
 
 # Visco-acoustic
 struct ViscIsoModel{T, N} <: AbstractModel{T, N}
     G::DiscreteGrid{T, N}
-    m::ModelParam{T, N}
-    rho::ModelParam{T, N}
-    qp::ModelParam{T, N}
+    m::ModelParam{N}
+    rho::ModelParam{N}
+    qp::ModelParam{N}
 end
 
 _params(m::IsoModel) = ((:m, m.m), (:rho, m.rho))
@@ -355,7 +355,7 @@ function Model(d, o, m::Array{mT, N}; epsilon=nothing, delta=nothing, theta=noth
                phi=nothing, rho=nothing, qp=nothing, vs=nothing, nb=40) where {mT<:Real, N}
 
     # Currently force single precision
-    m = convert(Array{Float32, N}, m)
+    m = as_Pdtype(m)
     T = Float32
     # Convert dimension to internal types
     n = size(m)
@@ -371,9 +371,9 @@ function Model(d, o, m::Array{mT, N}; epsilon=nothing, delta=nothing, theta=noth
         if any(!isnothing(p) for p in [epsilon, delta, theta, phi])
             @warn "Thomsen parameters no supported for elastic (vs) ignoring them"
         end
-        lambda = PhysicalParameter(convert(Array{T, N}, (m.^(-1) .- T(2) .* vs.^2) .* rho), n, d, o)
-        mu = PhysicalParameter(convert(Array{T, N}, vs.^2 .* rho), n, d, o)
-        b = isa(rho, Array) ? PhysicalParameter(convert(Array{T, N}, 1 ./ rho), n, d, o) : _scalar(rho, T)
+        lambda = PhysicalParameter(as_Pdtype((m.^(-1) .- T(2) .* vs.^2) .* rho), n, d, o)
+        mu = PhysicalParameter(as_Pdtype(vs.^2 .* rho), n, d, o)
+        b = isa(rho, Array) ? PhysicalParameter(as_Pdtype(1 ./ rho), n, d, o) : _scalar(rho, T)
         return IsoElModel{T, N}(G, lambda, mu, b)
     end
 
@@ -382,9 +382,9 @@ function Model(d, o, m::Array{mT, N}; epsilon=nothing, delta=nothing, theta=noth
         if any(!isnothing(p) for p in [epsilon, delta, theta, phi])
             @warn "Thomsen parameters no supported for elastic (vs) ignoring them"
         end
-        qp = isa(qp, Array) ? PhysicalParameter(convert(Array{T, N}, qp), n, d, o)  : _scalar(qp, T)
+        qp = isa(qp, Array) ? PhysicalParameter(as_Pdtype(qp), n, d, o)  : _scalar(qp, T)
         m = PhysicalParameter(m, n, d, o)
-        rho = isa(rho, Array) ? PhysicalParameter(convert(Array{T, N}, rho), n, d, o) : _scalar(rho, T)
+        rho = isa(rho, Array) ? PhysicalParameter(as_Pdtype(rho), n, d, o) : _scalar(rho, T)
         return ViscIsoModel{T, N}(G, m, rho, qp)
     end
 
@@ -394,22 +394,25 @@ function Model(d, o, m::Array{mT, N}; epsilon=nothing, delta=nothing, theta=noth
             @warn "Elastic (vs) and attenuation (qp) not supported for TTI/VTI"
         end
         m = PhysicalParameter(m, n, d, o)
-        rho = isa(rho, Array) ? PhysicalParameter(convert(Array{T, N}, rho), n, d, o) : _scalar(rho, T)
-        epsilon = isa(epsilon, Array) ? PhysicalParameter(convert(Array{T, N}, epsilon), n, d, o) : _scalar(epsilon, T, 0)
-        delta = isa(delta, Array) ? PhysicalParameter(convert(Array{T, N}, delta), n, d, o) : _scalar(delta, T, 0)
+        rho = isa(rho, Array) ? PhysicalParameter(as_Pdtype(rho), n, d, o) : _scalar(rho, T)
+        epsilon = isa(epsilon, Array) ? PhysicalParameter(as_Pdtype(epsilon), n, d, o) : _scalar(epsilon, T, 0)
+        delta = isa(delta, Array) ? PhysicalParameter(as_Pdtype(delta), n, d, o) : _scalar(delta, T, 0)
         # For safety remove delta values unsupported (delta > epsilon)
         _clip_delta!(delta, epsilon)
-        theta = isa(theta, Array) ? PhysicalParameter(convert(Array{T, N}, theta), n, d, o) : _scalar(theta, T, 0)
-        phi = isa(phi, Array) ? PhysicalParameter(convert(Array{T, N}, phi), n, d, o) : _scalar(phi, T, 0)
+        theta = isa(theta, Array) ? PhysicalParameter(as_Pdtype(theta), n, d, o) : _scalar(theta, T, 0)
+        phi = isa(phi, Array) ? PhysicalParameter(as_Pdtype(phi), n, d, o) : _scalar(phi, T, 0)
         return TTIModel{T, N}(G, m, rho, epsilon, delta, theta, phi)
     end
 
     # None of the advanced models, return isotropic acoustic
     m = PhysicalParameter(m, n, d, o)
-    rho = isa(rho, Array) ? PhysicalParameter(convert(Array{T, N}, rho), n, d, o) : _scalar(rho, T)
+    rho = isa(rho, Array) ? PhysicalParameter(as_Pdtype(rho), n, d, o) : _scalar(rho, T)
     return IsoModel{T, N}(G, m, rho)
 end
 
+as_Pdtype(x::Array{T, N}) where {T<:Pdtypes, N} = x
+as_Pdtype(x::Array{T, N}) where {T, N} = convert(Array{Float32, N}, x)
+
 Model(n, d, o, m::Array, rho::Array; nb=40) = Model(d, o, reshape(m, n...); rho=reshape(rho, n...), nb=nb)
 Model(n, d, o, m::Array, rho::Array, qp::Array; nb=40) = Model(d, o, reshape(m, n...); rho=reshape(rho, n...), qp=reshape(qp, n...), nb=nb)
 Model(n, d, o, m::Array; kw...) = Model(d, o, reshape(m, n...); kw...)
diff --git a/src/TimeModeling/Types/OptionsStructure.jl b/src/TimeModeling/Types/OptionsStructure.jl
index 24f3eb79e..73f1dc06a 100644
--- a/src/TimeModeling/Types/OptionsStructure.jl
+++ b/src/TimeModeling/Types/OptionsStructure.jl
@@ -153,4 +153,11 @@ function imcond(isic::Bool, IC::String)
         return "isic"
     end
     return lowercase(IC)
-end
\ No newline at end of file
+end
+
+function Base.show(io::IO, options::JUDIOptions)
+    println(io, "JUDI Options : \n")
+    for f in fieldnames(JUDIOptions)
+        println(io, @sprintf("%-25s : %10s", f, getfield(options, f)))
+    end
+end
diff --git a/src/pysource/fields_exprs.py b/src/pysource/fields_exprs.py
index 0f49b110a..09ba14536 100644
--- a/src/pysource/fields_exprs.py
+++ b/src/pysource/fields_exprs.py
@@ -101,7 +101,7 @@ def extended_rec(model, wavelet, v):
     return [Inc(ws, model.grid.time_dim.spacing * wf * wt)]
 
 
-def freesurface(model, fields):
+def freesurface(model, fields, r_coeff=-1):
     """
     Generate the stencil that mirrors the field as a free surface modeling for
     the acoustic wave equation
@@ -122,8 +122,10 @@ def freesurface(model, fields):
         if u == 0:
             continue
 
-        eqs.extend([Eq(u._subs(z, - zfs), - u._subs(z, zfs)),
-                    Eq(u._subs(z, 0), 0)])
+        sh = 1 if z in as_tuple(u.staggered) else 0
+        eqs.extend([Eq(u._subs(z, - zfs), r_coeff * u._subs(z, zfs - sh))])
+        if z not in as_tuple(u.staggered):
+            eqs.append(Eq(u._subs(z, 0), 0))
 
     return eqs
 
diff --git a/src/pysource/geom_utils.py b/src/pysource/geom_utils.py
index c53677c91..5f8e49c06 100644
--- a/src/pysource/geom_utils.py
+++ b/src/pysource/geom_utils.py
@@ -4,6 +4,12 @@
 
 from sources import *
 
+try:
+    from recipes.utils import mirror_source
+except ImportError:
+    def mirror_source(src):
+        return src
+
 
 def src_rec(model, u, src_coords=None, rec_coords=None, wavelet=None, nt=None):
     nt = nt or wavelet.shape[0]
@@ -14,12 +20,12 @@ def src_rec(model, u, src_coords=None, rec_coords=None, wavelet=None, nt=None):
             src = wavelet
         else:
             src = PointSource(name="src%s" % namef, grid=model.grid, ntime=nt,
-                              coordinates=src_coords)
+                              coordinates=src_coords, interpolation='sinc', r=3)
             src.data[:] = wavelet.view(np.ndarray) if wavelet is not None else 0.
     rcv = None
     if rec_coords is not None:
         rcv = Receiver(name="rcv%s" % namef, grid=model.grid, ntime=nt,
-                       coordinates=rec_coords)
+                       coordinates=rec_coords, interpolation='sinc', r=3)
     return src, rcv
 
 
@@ -67,7 +73,11 @@ def geom_expr(model, u, src_coords=None, rec_coords=None, wavelet=None, fw=True,
         else:
             # Acoustic inject into pressure
             u_n = as_tuple(u)[0].forward if fw else as_tuple(u)[0].backward
-            geom_expr += src.inject(field=u_n, expr=src*dt**2/m)
+            src_eq = src.inject(field=u_n, expr=src*dt**2/m)
+            if model.fs:
+                # Free surface
+                src_eq = mirror_source(model, src_eq)
+            geom_expr += src_eq
     # Setup adjoint wavefield sampling at source locations
     if rcv is not None:
         if model.is_elastic:
diff --git a/src/pysource/kernels.py b/src/pysource/kernels.py
index dd471483c..dc368025c 100644
--- a/src/pysource/kernels.py
+++ b/src/pysource/kernels.py
@@ -6,6 +6,15 @@
 from fields_exprs import freesurface
 from FD_utils import laplacian, sa_tti
 
+try:
+    from recipes import recipes_registry
+    from recipes.utils import vs_mask_derivs
+except ImportError:
+    recipes_registry = {}
+
+    def vs_mask_derivs(eq, tau, vs):
+        return eq
+
 
 def wave_kernel(model, u, fw=True, q=None, f0=0.015):
     """
@@ -190,18 +199,24 @@ def elastic_kernel(model, v, tau, fw=True, q=None):
     q : TimeFunction or Expr
         Full time-space source as a tuple (one value for each component)
     """
-    if not fw:
-        raise NotImplementedError("Only forward modeling for the elastic equation")
-
     # Lame parameters
-    lam, b = model.lam, model.irho
+    lam, b, damp = model.lam, model.irho, model.damp
     try:
         mu = model.mu
     except AttributeError:
         mu = 0
 
+    # Time derivative and update
+    vnext = v.forward if fw else v.backward
+    taunext = tau.forward if fw else tau.backward
+
+    vdt = v.dt if fw else -v.dtl
+    tau_dt = tau.dt if fw else -tau.dtl
+
     # Particle velocity
-    eq_v = v.dt - b * div(tau)
+    eq_v = vdt - b * div(tau) + damp * vnext
+    eq_v = vs_mask_derivs(eq_v, tau, model.mu)
+
     # Stress
     try:
         e = (grad(v.forward) + grad(v.forward).transpose(inner=False))
@@ -209,16 +224,17 @@ def elastic_kernel(model, v, tau, fw=True, q=None):
         # Older devito version
         e = (grad(v.forward) + grad(v.forward).T)
 
-    eq_tau = tau.dt - lam * diag(div(v.forward)) - mu * e
+    eq_tau = tau_dt - lam * diag(div(v.forward)) - mu * e + damp * taunext
+
+    u_v = [Eq(vnext, solve(eq_v, vnext), subdomain=model.physical)]
+    if model.fs:
+        u_v.extend(freesurface(model, vnext[-1], r_coeff=1))
 
-    u_v = Eq(v.forward, model.damp * solve(eq_v, v.forward),
-             subdomain=model.physical)
-    u_t = Eq(tau.forward, model.damp * solve(eq_tau, tau.forward),
-             subdomain=model.physical)
+    u_t = Eq(taunext, solve(eq_tau, taunext), subdomain=model.physical)
 
     if model.fs:
-        fseq = freesurface(model, tau.forward.diagonal())
+        fseq = freesurface(model, taunext.diagonal())
     else:
         fseq = []
 
-    return [u_v, u_t], fseq
+    return [*u_v, u_t], fseq
diff --git a/src/pysource/models.py b/src/pysource/models.py
index 405fe8278..9bf958fa9 100644
--- a/src/pysource/models.py
+++ b/src/pysource/models.py
@@ -4,51 +4,61 @@
 
 from sympy import finite_diff_weights as fd_w
 from devito import (Grid, Function, SubDimension, Eq, Inc, switchconfig,
-                    Operator, mmin, mmax, initialize_function, MPI,
+                    Operator, mmin, mmax, initialize_function,
                     Abs, sqrt, sin, Constant, CustomDimension)
 
 from devito.tools import as_tuple, memoized_func
 
 try:
     from devitopro import *  # noqa
-    from devitopro.subdomains import ABox
+    from devitopro.subdomains.abox import ABox, ABoxFunction
+    from devitopro.data import Float16
     AboxBase = ABox
 except ImportError:
     ABox = None
     AboxBase = object
+    ABoxFunction = object
+    Float16 = lambda *ar, **kw: np.float32
 
 
-class ABoxSlowness(AboxBase):
+class SlownessABoxFunction(ABoxFunction):
 
-    def _1d_cmax(self, vp, eps):
-        cmaxs = []
+    _physical_params = ('m',)
 
+    def vp_max(self, rdim, **kwargs):
+        vp = kwargs.get('vp', self.vp)
+        eps = kwargs.get('eps', self.eps)
+
+        vpi = vp.data.min(axis=rdim)**(-.5)
+        # Thomsen correction
         if eps is not None:
-            assert vp.shape_allocated == eps.shape_allocated
-
-        for (di, d) in enumerate(vp.grid.dimensions):
-            rdim = tuple(i for (i, dl) in enumerate(vp.grid.dimensions) if dl is not d)
-
-            # Max over other dimensions, 1D array with the max in each plane
-            vpi = vp.data.min(axis=rdim)**(-.5)
-            # THomsen correction
-            if eps is not None:
-                epsi = eps.data.max(axis=rdim)
-                vpi._local[:] *= np.sqrt(1. + 2.*epsi._local[:])
-            # Gather on all ranks if distributed.
-            # Since we have a small-ish 1D vector we avoid the index gymnastic
-            # and create the full 1d vector on al ranks with the local values
-            # at the local indices and simply gather with Max
-            if vp.grid.distributor.is_parallel:
-                out = np.zeros(vp.grid.shape[di], dtype=vpi.dtype)
-                tmp = np.zeros(vp.grid.shape[di], dtype=vpi.dtype)
-                tmp[vp.local_indices[di]] = vpi._local
-                vp.grid.distributor.comm.Allreduce(tmp, out, op=MPI.MAX)
-                cmaxs.append(out)
-            else:
-                cmaxs.append(vpi)
+            epsi = eps.data.max(axis=rdim)
+            vpi._local[:] *= np.sqrt(1. + 2.*epsi._local[:])
+
+        return vpi
+
+
+class LameABoxFunction(ABoxFunction):
+
+    _physical_params = ('mu', 'lam', 'b')
+
+    def vp_max(self, rdim, **kwargs):
+        lam = kwargs.get('lam', self.lam)
+        mu = kwargs.get('mu', self.mu)
+        b = kwargs.get('b', self.b)
+
+        return np.sqrt(((lam.data + 2 * mu.data) * b.data).max(axis=rdim))
+
+
+class JUDIAbox(AboxBase):
+
+    def __init__(self, *args, subdomains=None, name=None, **params):
+        if 'mu' in params:
+            self._afunc = LameABoxFunction
+        else:
+            self._afunc = SlownessABoxFunction
 
-        return cmaxs
+        super().__init__(*args, subdomains=subdomains, name=name, **params)
 
 
 __all__ = ['Model']
@@ -189,6 +199,8 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         abc_type = "mask" if (qp is not None or mu is not None) else "damp"
         self.fs = fs
         self._abox = None
+        # Topology in case. Always decompose only in x or y
+        topo = tuple(['*']*(len(shape)-1) + [1])
         # Origin of the computational domain with boundary to inject/interpolate
         # at the correct index
         origin_pml = [dtype(o - s*nbl) for o, s in zip(origin, spacing)]
@@ -199,7 +211,7 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         # Physical extent is calculated per cell, so shape - 1
         extent = tuple(np.array(spacing) * (shape_pml - 1))
         self.grid = Grid(extent=extent, shape=shape_pml, origin=tuple(origin_pml),
-                         dtype=dtype)
+                         dtype=dtype, topology=topo)
 
         # Absorbing boundary layer
         if self.nbl != 0:
@@ -245,7 +257,8 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         # Additional parameter fields for elastic
         if self._is_elastic:
             self.lam = self._gen_phys_param(lam, 'lam', space_order, is_param=True)
-            self.mu = self._gen_phys_param(mu, 'mu', space_order, is_param=True)
+            self.mu = self._gen_phys_param(mu, 'mu', space_order, is_param=True,
+                                           avg_mode='harmonic')
         # User provided dt
         self._dt = kwargs.get('dt')
 
@@ -298,22 +311,26 @@ def zero_thomsen(self):
 
     @switchconfig(log_level='ERROR')
     def _gen_phys_param(self, field, name, space_order, is_param=False,
-                        default_value=0):
+                        default_value=0, avg_mode='arithmetic'):
         """
         Create symbolic object an initiliaze its data
         """
         if field is None:
             return default_value
         if isinstance(field, np.ndarray):
-            if field.shape == self.shape:
-                function = Function(name=name, grid=self.grid, space_order=space_order,
-                                    parameter=is_param)
-                initialize_function(function, field, self.padsizes)
+            if field.dtype == np.float16:
+                _min = np.amin(field)
+                _max = np.amax(field)
+                if _max == _min:
+                    _max = .125 if _min == 0 else _min * 1.125
+                dtype = Float16(_min, _max)
             else:
-                # We take advantage of the external allocator
-                function = Function(name=name, grid=self.grid, space_order=space_order,
-                                    parameter=is_param)
-                function.data[:] = field
+                dtype = self.grid.dtype
+
+            function = Function(name=name, grid=self.grid, space_order=space_order,
+                                parameter=is_param, avg_mode=avg_mode, dtype=dtype)
+            pad = self.padsizes if field.shape == self.shape else 0
+            initialize_function(function, field, pad)
         else:
             function = Constant(name=name, value=np.amin(field))
         self._physical_parameters.append(name)
@@ -326,10 +343,18 @@ def physical_parameters(self):
         """
         params = []
         for p in self._physical_parameters:
-            if getattr(self, p).is_Constant:
-                params.append('%s_const' % p)
+            param = getattr(self, p)
+            # Get dtype
+            comp = getattr(param, '_compression', param.dtype)
+            if isinstance(comp, Float16):
+                dtype = comp
+            else:
+                dtype = param.dtype
+            # Add to list
+            if param.is_Constant:
+                params.append(('%s_const' % p, dtype))
             else:
-                params.append(p)
+                params.append((p, dtype))
         return as_tuple(params)
 
     @property
@@ -550,9 +575,8 @@ def abox(self, src, rec, fw=True):
             return {}
         if not fw:
             src, rec = rec, src
-        eps = getattr(self, 'epsilon', None)
-        abox = ABoxSlowness(src, rec, self.m, self.space_order, eps=eps)
-        return {'abox': abox}
+        abox = JUDIAbox(self.space_order, src=src, rcv=rec, **self.physical_params())
+        return {'abox': abox._abox_func}
 
     def __init_abox__(self, src, rec, fw=True):
         return
@@ -572,7 +596,7 @@ def fs_dim(self):
                                symbolic_size=so)
 
 
-class EmptyModel():
+class EmptyModel(Model):
     """
     An pseudo Model structure that does not contain any physical field
     but only the necessary information to create an operator.
@@ -580,36 +604,38 @@ class EmptyModel():
     """
 
     def __init__(self, tti, visco, elastic, spacing, fs, space_order, p_params):
-        self.is_tti = tti
-        self.is_viscoacoustic = visco
-        self.is_elastic = elastic
-        self.spacing = spacing
+        self._is_tti = tti
+        self._is_viscoacoustic = visco
+        self._is_elastic = elastic
+        self._spacing = spacing
         self.fs = fs
-        self.space_order = space_order
+        self._space_order = space_order
         N = 2 * space_order + 1
 
         self.grid = Grid(tuple([N]*len(spacing)),
                          extent=[s*(N-1) for s in spacing])
         self.dimensions = self.grid.dimensions
 
+        # Make params a dict
+        p_params = {k: v for k, v in p_params if k != 'damp'}
         # Create the function for the physical parameters
         self.damp = Function(name='damp', grid=self.grid, space_order=0)
-        for p in set(p_params) - {'damp'}:
+        _physical_parameters = ['damp']
+        for p, dt in p_params.items():
             if p.endswith('_const'):
                 name = p.split('_')[0]
-                setattr(self, name, Constant(name=name, value=1))
+                setattr(self, name, Constant(name=name, value=1, dtype=dt))
             else:
-                setattr(self, p, Function(name=p, grid=self.grid,
-                                          space_order=space_order))
+                pn = '_%s' % p if p in ['m', 'dm'] else p
+                avgmode = 'harmonic' if p == 'mu' else 'arithmetic'
+                setattr(self, pn, Function(name=p, grid=self.grid, is_param=True,
+                                           space_order=space_order, dtype=dt,
+                                           avg_mode=avgmode))
+                _physical_parameters.append(p)
         if 'irho' not in p_params and 'irho_const' not in p_params:
             self.irho = 1 if 'rho' not in p_params else 1 / self.rho
 
-    @property
-    def spacing_map(self):
-        """
-        Map between spacing symbols and their values for each `SpaceDimension`.
-        """
-        return self.grid.spacing_map
+        self._physical_parameters = _physical_parameters
 
     @property
     def critical_dt(self):
@@ -618,13 +644,6 @@ def critical_dt(self):
         """
         return self.grid.time_dim.spacing
 
-    @property
-    def dim(self):
-        """
-        Spatial dimension of the problem and model domain.
-        """
-        return self.grid.dim
-
     @property
     def zero_thomsen(self):
         out = {}
@@ -641,21 +660,11 @@ def __init_abox__(self, src, rec, fw=True):
         if src is None and rec is None:
             self._abox = None
             return
-        eps = getattr(self, 'epsilon', None)
+
         if not fw:
             src, rec = rec, src
-        self._abox = ABoxSlowness(src, rec, self.m, self.space_order, eps=eps)
-
-    @cached_property
-    def physical(self):
-        if ABox is None:
-            return None
-        else:
-            return self._abox
-
-    @cached_property
-    def fs_dim(self):
-        so = self.space_order // 2
-        return CustomDimension(name="zfs", symbolic_min=1,
-                               symbolic_max=so,
-                               symbolic_size=so)
+        try:
+            self._abox = JUDIAbox(self.space_order, src=src, rcv=rec,
+                                  **self.physical_params())
+        except AttributeError:
+            return

From eec2c68e5f5b30d3043957501d39cb99bc029466 Mon Sep 17 00:00:00 2001
From: mloubout <mathias.louboutin@gmail.com>
Date: Mon, 30 Sep 2024 14:15:48 -0400
Subject: [PATCH 4/5] need vp

---
 .github/workflows/ci-judi.yml              |   2 +-
 .github/workflows/ci-op.yml                |   9 +-
 Project.toml                               |   1 +
 examples/scripts/modeling_basic_elastic.jl |   6 +-
 src/pysource/geom_utils.py                 |  27 +--
 src/pysource/models.py                     | 184 ++++++---------------
 src/pysource/operators.py                  |  18 +-
 src/pysource/propagators.py                |   4 +-
 8 files changed, 90 insertions(+), 161 deletions(-)

diff --git a/.github/workflows/ci-judi.yml b/.github/workflows/ci-judi.yml
index 5bf787016..4aac77fb7 100644
--- a/.github/workflows/ci-judi.yml
+++ b/.github/workflows/ci-judi.yml
@@ -19,7 +19,7 @@ jobs:
     name: JUDI base on Julia ${{ matrix.version }} - ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     env:
-      DEVITO_ARCH: gcc-11
+      DEVITO_ARCH: gcc-12
       DEVITO_LANGUAGE: "openmp"
       OMP_NUM_THREADS: 4
       GROUP: "JUDI"
diff --git a/.github/workflows/ci-op.yml b/.github/workflows/ci-op.yml
index d5beda4f2..ed73bb54f 100644
--- a/.github/workflows/ci-op.yml
+++ b/.github/workflows/ci-op.yml
@@ -19,7 +19,7 @@ jobs:
     name: ${{ matrix.op }} on Julia ${{ matrix.version }} - ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     env:
-      DEVITO_ARCH: gcc-11
+      DEVITO_ARCH: ${{ matrix.cc }}
       DEVITO_LANGUAGE: "openmp"
       DEVITO_LOGGING: "INFO"
       OMP_NUM_THREADS: ${{ matrix.omp }}
@@ -33,32 +33,38 @@ jobs:
         op: ["ISO_OP", "ISO_OP_FS", "TTI_OP", "TTI_OP_FS", "BASICS"]
         version: ['1']
         omp: [2]
+        cc: ['gcc-11']
   
         include:
           - os: macos-13
             version: '1.6'
             op: "ISO_OP"
             omp: 1
+            cc: gcc-13
 
           - os: macos-13
             version: '1.8'
             op: "ISO_OP"
             omp: 1
+            cc: gcc-13
 
           - os: macos-13
             version: '1.9'
             op: "ISO_OP"
             omp: 1
+            cc: gcc-13
 
           - os: ubuntu-latest
             version: '1.9'
             op: "VISCO_AC_OP"
             omp: 2
+            cc: gcc-11
 
           - os: ubuntu-latest
             version: '1.10'
             op: "ISO_OP"
             omp: 2
+            cc: gcc-11
 
     steps:
       - name: Checkout JUDI
@@ -78,6 +84,7 @@ jobs:
       - name: Set julia python
         run: |
           echo "PYTHON=$(which python3)" >> $GITHUB_ENV
+          python3 -m pip install devito[tests,extras]@git+https://github.com/devitocodes/devito.git
           PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
 
       - name: Build JUDI
diff --git a/Project.toml b/Project.toml
index 06e036797..d239b0046 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,6 +12,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 JOLI = "bb331ad6-a1cf-11e9-23da-9bcb53c69f6f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/examples/scripts/modeling_basic_elastic.jl b/examples/scripts/modeling_basic_elastic.jl
index 454e67561..c252f9606 100644
--- a/examples/scripts/modeling_basic_elastic.jl
+++ b/examples/scripts/modeling_basic_elastic.jl
@@ -29,7 +29,7 @@ nxrec = 120
 nyrec = 100
 xrec = range(50f0, stop=1150f0, length=nxrec)
 yrec = 0f0
-zrec = range(0f0, stop=0f0, length=nxrec)
+zrec = range(10f0, stop=10f0, length=nxrec)
 
 # receiver sampling and recording time
 timeR = 1500f0   # receiver recording time [ms]
@@ -41,7 +41,7 @@ recGeometry = Geometry(xrec, yrec, zrec; dt=dtR, t=timeR, nsrc=nsrc)
 # Set up source geometry (cell array with source locations for each shot)
 xsrc = 600f0
 ysrc = 0f0
-zsrc = 0f0
+zsrc = 10f0
 
 # source sampling and number of time steps
 timeS = 1500f0   # source length in [ms]
@@ -57,7 +57,7 @@ wavelet = ricker_wavelet(timeS, dtS, f0)
 ###################################################################################################
 
 # Setup operators
-F = judiModeling(model, srcGeometry, recGeometry)
+F = judiModeling(model, srcGeometry, recGeometry; options=Options(space_order=8, free_surface=true))
 q = judiVector(srcGeometry, wavelet)
 
 # Nonlinear modeling
diff --git a/src/pysource/geom_utils.py b/src/pysource/geom_utils.py
index 5f8e49c06..0dccc325b 100644
--- a/src/pysource/geom_utils.py
+++ b/src/pysource/geom_utils.py
@@ -7,7 +7,7 @@
 try:
     from recipes.utils import mirror_source
 except ImportError:
-    def mirror_source(src):
+    def mirror_source(model, src):
         return src
 
 
@@ -20,12 +20,12 @@ def src_rec(model, u, src_coords=None, rec_coords=None, wavelet=None, nt=None):
             src = wavelet
         else:
             src = PointSource(name="src%s" % namef, grid=model.grid, ntime=nt,
-                              coordinates=src_coords, interpolation='sinc', r=3)
+                              coordinates=src_coords)
             src.data[:] = wavelet.view(np.ndarray) if wavelet is not None else 0.
     rcv = None
     if rec_coords is not None:
         rcv = Receiver(name="rcv%s" % namef, grid=model.grid, ntime=nt,
-                       coordinates=rec_coords, interpolation='sinc', r=3)
+                       coordinates=rec_coords)
     return src, rcv
 
 
@@ -62,14 +62,20 @@ def geom_expr(model, u, src_coords=None, rec_coords=None, wavelet=None, fw=True,
     if not model.is_elastic:
         m = model.m * irho
     dt = model.grid.time_dim.spacing
-    geom_expr = []
     src, rcv = src_rec(model, u, src_coords, rec_coords, wavelet, nt)
     model.__init_abox__(src, rcv, fw=fw)
+
+    geom_expr = []
+    # Source
     if src is not None:
         # Elastic inject into diagonal of stress
         if model.is_elastic:
-            for ud in as_tuple(u)[1].diagonal():
-                geom_expr += src.inject(field=ud.forward, expr=src*dt/irho)
+            c = 1 / model.grid.dim
+            src_eq = src.inject(field=as_tuple(u)[1].forward.diagonal(),
+                                expr=c*src*dt/irho)
+            if model.fs:
+                # Free surface
+                src_eq = mirror_source(model, src_eq)
         else:
             # Acoustic inject into pressure
             u_n = as_tuple(u)[0].forward if fw else as_tuple(u)[0].backward
@@ -77,13 +83,14 @@ def geom_expr(model, u, src_coords=None, rec_coords=None, wavelet=None, fw=True,
             if model.fs:
                 # Free surface
                 src_eq = mirror_source(model, src_eq)
-            geom_expr += src_eq
+
+        geom_expr += [src_eq]
     # Setup adjoint wavefield sampling at source locations
     if rcv is not None:
         if model.is_elastic:
-            rec_expr = u[1].trace()
+            geom_expr = u[1].trace() / model.grid.dim
         else:
             rec_expr = u[0] if model.is_tti else u
-        adj_rcv = rcv.interpolate(expr=rec_expr)
-        geom_expr += adj_rcv
+        geom_expr += rcv.interpolate(expr=rec_expr)
+
     return geom_expr
diff --git a/src/pysource/models.py b/src/pysource/models.py
index 9bf958fa9..81f3f069b 100644
--- a/src/pysource/models.py
+++ b/src/pysource/models.py
@@ -9,56 +9,18 @@
 
 from devito.tools import as_tuple, memoized_func
 
-try:
-    from devitopro import *  # noqa
-    from devitopro.subdomains.abox import ABox, ABoxFunction
-    from devitopro.data import Float16
-    AboxBase = ABox
-except ImportError:
-    ABox = None
-    AboxBase = object
-    ABoxFunction = object
-    Float16 = lambda *ar, **kw: np.float32
+# try:
+#    from devitopro import *  # noqa
+#    from devitopro.subdomains.abox import ABox, ABoxFunction
+#    from devitopro.data import Float16
+#    AboxBase = ABox
+# except ImportError:
+ABox = None
+AboxBase = object
+ABoxFunction = object
+Float16 = lambda *ar, **kw: np.float32
 
-
-class SlownessABoxFunction(ABoxFunction):
-
-    _physical_params = ('m',)
-
-    def vp_max(self, rdim, **kwargs):
-        vp = kwargs.get('vp', self.vp)
-        eps = kwargs.get('eps', self.eps)
-
-        vpi = vp.data.min(axis=rdim)**(-.5)
-        # Thomsen correction
-        if eps is not None:
-            epsi = eps.data.max(axis=rdim)
-            vpi._local[:] *= np.sqrt(1. + 2.*epsi._local[:])
-
-        return vpi
-
-
-class LameABoxFunction(ABoxFunction):
-
-    _physical_params = ('mu', 'lam', 'b')
-
-    def vp_max(self, rdim, **kwargs):
-        lam = kwargs.get('lam', self.lam)
-        mu = kwargs.get('mu', self.mu)
-        b = kwargs.get('b', self.b)
-
-        return np.sqrt(((lam.data + 2 * mu.data) * b.data).max(axis=rdim))
-
-
-class JUDIAbox(AboxBase):
-
-    def __init__(self, *args, subdomains=None, name=None, **params):
-        if 'mu' in params:
-            self._afunc = LameABoxFunction
-        else:
-            self._afunc = SlownessABoxFunction
-
-        super().__init__(*args, subdomains=subdomains, name=name, **params)
+_dtypes = {'params': 'f32'}
 
 
 __all__ = ['Model']
@@ -196,7 +158,7 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         self.shape = tuple(shape)
         self.nbl = int(nbl)
         self.origin = tuple([dtype(o) for o in origin])
-        abc_type = "mask" if (qp is not None or mu is not None) else "damp"
+        abc_type = "mask" if qp is not None else "damp"
         self.fs = fs
         self._abox = None
         # Topology in case. Always decompose only in x or y
@@ -226,13 +188,17 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         # Seismic fields and properties
         self.scale = 1
         self._space_order = space_order
+
         # Create square slowness of the wave as symbol `m`
         if m is not None:
-            self._m = self._gen_phys_param(m, 'm', space_order)
+            vp_vals = m**(-.5)
+            self.m = self._gen_phys_param(m, 'm', space_order)
+
         # density
         self._init_density(rho, b, space_order)
+
         # Perturbation for linearized modeling
-        self._dm = self._gen_phys_param(dm, 'dm', space_order)
+        self.dm = self._gen_phys_param(dm, 'dm', space_order)
 
         # Model type
         self._is_viscoacoustic = qp is not None
@@ -257,11 +223,19 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         # Additional parameter fields for elastic
         if self._is_elastic:
             self.lam = self._gen_phys_param(lam, 'lam', space_order, is_param=True)
+            mu[np.where(mu == 0)] = 1e-12
             self.mu = self._gen_phys_param(mu, 'mu', space_order, is_param=True,
                                            avg_mode='harmonic')
+            b = b if b is not None else 1 / rho
+            vp_vals = ((lam + 2 * mu) * b)**(.5)
+
         # User provided dt
         self._dt = kwargs.get('dt')
 
+        # Need vp for Abox
+        if ABox is not None:
+            self._vp = self._gen_phys_param(vp_vals, '_vp', space_order)
+
     def _init_density(self, rho, b, so):
         """
         Initialize density parameter. Depending on variance in density
@@ -297,6 +271,9 @@ def physical_params(self, **kwargs):
         if not kwargs.get('born', False):
             params.pop('dm', None)
 
+        # Remove "build" _x params
+        params.pop('_vp', None)
+
         return params
 
     @property
@@ -313,12 +290,12 @@ def zero_thomsen(self):
     def _gen_phys_param(self, field, name, space_order, is_param=False,
                         default_value=0, avg_mode='arithmetic'):
         """
-        Create symbolic object an initiliaze its data
+        Create symbolic object an initialize its data
         """
         if field is None:
             return default_value
         if isinstance(field, np.ndarray):
-            if field.dtype == np.float16:
+            if _dtypes['params'] == 'f16' or field.dtype == np.float16:
                 _min = np.amin(field)
                 _max = np.amax(field)
                 if _max == _min:
@@ -329,7 +306,7 @@ def _gen_phys_param(self, field, name, space_order, is_param=False,
 
             function = Function(name=name, grid=self.grid, space_order=space_order,
                                 parameter=is_param, avg_mode=avg_mode, dtype=dtype)
-            pad = self.padsizes if field.shape == self.shape else 0
+            pad = 0 if field.shape == function.shape else self.padsizes
             initialize_function(function, field, pad)
         else:
             function = Constant(name=name, value=np.amin(field))
@@ -346,9 +323,12 @@ def physical_parameters(self):
             param = getattr(self, p)
             # Get dtype
             comp = getattr(param, '_compression', param.dtype)
-            if isinstance(comp, Float16):
-                dtype = comp
-            else:
+            try:
+                if isinstance(comp, Float16):
+                    dtype = comp
+                else:
+                    dtype = param.dtype
+            except TypeError:
                 dtype = param.dtype
             # Add to list
             if param.is_Constant:
@@ -457,7 +437,7 @@ def _cfl_coeff(self):
             so = max(self.space_order // 2, 2)
             coeffs = fd_w(1, range(-so, so), .5)
             c_fd = sum(np.abs(coeffs[-1][-1])) / 2
-            return .9 * np.sqrt(self.dim) / self.dim / c_fd
+            return .95 * np.sqrt(self.dim) / self.dim / c_fd
         a1 = 4  # 2nd order in time
         so = max(self.space_order // 2, 4)
         coeffs = fd_w(2, range(-so, so), 0)[-1][-1]
@@ -489,71 +469,6 @@ def critical_dt(self):
                 return self.dtype("%.3e" % self.dt)
         return dt
 
-    @property
-    def dm(self):
-        """
-        Model perturbation for linearized modeling
-        """
-        return self._dm
-
-    @dm.setter
-    def dm(self, dm):
-        """
-        Set a new model perturbation.
-
-        Parameters
-        ----------
-        dm : float or array
-            New model perturbation
-        """
-        # Update the square slowness according to new value
-        if isinstance(dm, np.ndarray):
-            if not isinstance(self._dm, Function):
-                self._dm = self._gen_phys_param(dm, 'dm', self.space_order)
-            elif dm.shape == self.shape:
-                initialize_function(self._dm, dm, self.padsizes)
-            elif dm.shape == self.dm.shape:
-                self.dm.data[:] = dm[:]
-            else:
-                raise ValueError("Incorrect input size %s for model of size" % dm.shape +
-                                 " %s without or %s with padding" % (self.shape,
-                                                                     self.dm.shape))
-        else:
-            try:
-                self._dm.data = dm
-            except AttributeError:
-                self._dm = dm
-
-    @property
-    def m(self):
-        """
-        Function holding the squared slowness in s^2/km^2.
-        """
-        return self._m
-
-    @m.setter
-    def m(self, m):
-        """
-        Set a new squared slowness model.
-
-        Parameters
-        ----------
-        m : float or array
-            New squared slowness in s^2/km^2.
-        """
-        # Update the square slowness according to new value
-        if isinstance(m, np.ndarray):
-            if m.shape == self.m.shape:
-                self.m.data[:] = m[:]
-            elif m.shape == self.shape:
-                initialize_function(self._m, m, self.padsizes)
-            else:
-                raise ValueError("Incorrect input size %s for model of size" % m.shape +
-                                 " %s without or %s with padding" % (self.shape,
-                                                                     self.m.shape))
-        else:
-            self._m.data = m
-
     @property
     def vp(self):
         """
@@ -575,8 +490,9 @@ def abox(self, src, rec, fw=True):
             return {}
         if not fw:
             src, rec = rec, src
-        abox = JUDIAbox(self.space_order, src=src, rcv=rec, **self.physical_params())
-        return {'abox': abox._abox_func}
+        eps = getattr(self, 'epsilon', None)
+        abox = ABox(src, rec, self._vp, self.space_order, eps=eps)
+        return {'abox': abox}
 
     def __init_abox__(self, src, rec, fw=True):
         return
@@ -622,15 +538,16 @@ def __init__(self, tti, visco, elastic, spacing, fs, space_order, p_params):
         self.damp = Function(name='damp', grid=self.grid, space_order=0)
         _physical_parameters = ['damp']
         for p, dt in p_params.items():
+            if _dtypes['params'] == 'f16':
+                dt = np.float16
             if p.endswith('_const'):
                 name = p.split('_')[0]
                 setattr(self, name, Constant(name=name, value=1, dtype=dt))
             else:
-                pn = '_%s' % p if p in ['m', 'dm'] else p
                 avgmode = 'harmonic' if p == 'mu' else 'arithmetic'
-                setattr(self, pn, Function(name=p, grid=self.grid, is_param=True,
-                                           space_order=space_order, dtype=dt,
-                                           avg_mode=avgmode))
+                setattr(self, p, Function(name=p, grid=self.grid, is_param=True,
+                                          space_order=space_order, dtype=dt,
+                                          avg_mode=avgmode))
                 _physical_parameters.append(p)
         if 'irho' not in p_params and 'irho_const' not in p_params:
             self.irho = 1 if 'rho' not in p_params else 1 / self.rho
@@ -663,8 +580,5 @@ def __init_abox__(self, src, rec, fw=True):
 
         if not fw:
             src, rec = rec, src
-        try:
-            self._abox = JUDIAbox(self.space_order, src=src, rcv=rec,
-                                  **self.physical_params())
-        except AttributeError:
-            return
+        eps = getattr(self, 'epsilon', None)
+        self._abox = ABox(src, rec, self._vp, self.space_order, eps=eps)
diff --git a/src/pysource/operators.py b/src/pysource/operators.py
index 16b464b43..cd99036ca 100644
--- a/src/pysource/operators.py
+++ b/src/pysource/operators.py
@@ -99,8 +99,8 @@ def forward_op(p_params, tti, visco, elas, space_order, fw, spacing, save, t_sub
     u = wavefield(model, space_order, save=save, nt=nt, t_sub=t_sub, fw=fw)
 
     # Setup source and receiver
-    g_expr = geom_expr(model, u, src_coords=scords, nt=nt,
-                       rec_coords=rcords, wavelet=wavelet, fw=fw)
+    gexpr = geom_expr(model, u, src_coords=scords, nt=nt,
+                      rec_coords=rcords, wavelet=wavelet, fw=fw)
 
     # Expression for saving wavefield if time subsampling is used
     eq_save = save_subsampled(model, u, nt, t_sub, space_order=space_order)
@@ -126,7 +126,7 @@ def forward_op(p_params, tti, visco, elas, space_order, fw, spacing, save, t_sub
     # Create operator and run
     subs = model.spacing_map
     pname = "forward" if fw else "adjoint"
-    op = Operator(pde + wrec + nv_t + dft + g_expr + extra + eq_save + nv_s + Ieq,
+    op = Operator(pde + wrec + nv_t + dft + gexpr + extra + eq_save + nv_s + Ieq,
                   subs=subs, name=pname+name(model),
                   opt=opt_op(model))
     op.cfunction
@@ -158,9 +158,9 @@ def born_op(p_params, tti, visco, elas, space_order, fw, spacing, save, pt_src,
     ul = wavefield(model, space_order, name="l", fw=fw)
 
     # Setup source and receiver
-    g_expr = geom_expr(model, u, rec_coords=rcords if nlind else None,
-                       src_coords=scords, wavelet=wavelet, fw=fw)
-    g_exprl = geom_expr(model, ul, rec_coords=rcords, nt=nt, fw=fw)
+    gexpr = geom_expr(model, u, rec_coords=rcords if nlind else None,
+                      src_coords=scords, wavelet=wavelet, fw=fw)
+    gexprl = geom_expr(model, ul, rec_coords=rcords, nt=nt, fw=fw)
 
     # Expression for saving wavefield if time subsampling is used
     eq_save = save_subsampled(model, u, nt, t_sub, space_order=space_order)
@@ -183,7 +183,7 @@ def born_op(p_params, tti, visco, elas, space_order, fw, spacing, save, pt_src,
 
     # Create operator and run
     subs = model.spacing_map
-    op = Operator(pde + g_expr + extra + g_exprl + pdel + extral + dft + eq_save + Ieq,
+    op = Operator(pde + gexpr + extra + gexprl + pdel + extral + dft + eq_save + Ieq,
                   subs=subs, name="born"+name(model),
                   opt=opt_op(model))
     op.cfunction
@@ -211,7 +211,7 @@ def adjoint_born_op(p_params, tti, visco, elas, space_order, fw, spacing, pt_rec
                           dft=nfreq > 0, t_sub=t_sub, fw=fw)
 
     # Setup source and receiver
-    r_expr = geom_expr(model, v, src_coords=rcords, wavelet=residual, fw=not fw)
+    gexpr = geom_expr(model, v, src_coords=rcords, wavelet=residual, fw=not fw)
 
     # Set up PDE expression and rearrange
     pde, extra = wave_kernel(model, v, fw=False, f0=Constant('f0'))
@@ -226,7 +226,7 @@ def adjoint_born_op(p_params, tti, visco, elas, space_order, fw, spacing, pt_rec
 
     # Create operator and run
     subs = model.spacing_map
-    op = Operator(pde + r_expr + extra + g_expr + Ieq, subs=subs,
+    op = Operator(pde + gexpr + extra + g_expr + Ieq, subs=subs,
                   name="gradient"+name(model),
                   opt=opt_op(model))
     try:
diff --git a/src/pysource/propagators.py b/src/pysource/propagators.py
index 1cdea79ff..c6ebc1978 100644
--- a/src/pysource/propagators.py
+++ b/src/pysource/propagators.py
@@ -253,7 +253,7 @@ def forward_grad(model, src_coords, rcv_coords, wavelet, v,
     pde, extra = wave_kernel(model, u, q=q, f0=f0)
 
     # Setup source and receiver
-    rexpr = geom_expr(model, u, src_coords=src_coords, nt=nt,
+    gexpr = geom_expr(model, u, src_coords=src_coords, nt=nt,
                       rec_coords=rcv_coords, wavelet=wavelet)
     _, rcv = src_rec(model, u, src_coords, rcv_coords, wavelet, nt)
 
@@ -263,7 +263,7 @@ def forward_grad(model, src_coords, rcv_coords, wavelet, v,
 
     # Create operator and run
     subs = model.spacing_map
-    op = Operator(pde + rexpr + extra + g_expr,
+    op = Operator(pde + gexpr + extra + g_expr,
                   subs=subs, name="forward_grad",
                   opt=opt_op(model))
 

From e5f497a8e4b5dc3a1f70ac18b2a87cf261c29093 Mon Sep 17 00:00:00 2001
From: mloubout <mathias.louboutin@gmail.com>
Date: Mon, 30 Sep 2024 21:09:29 -0400
Subject: [PATCH 5/5] switch to arm macos runner

---
 .github/workflows/ci-examples.yml |  2 +-
 .github/workflows/ci-judi.yml     | 31 +++++++++++++++++++------
 .github/workflows/ci-op.yml       | 38 +++++++++++++++++++++----------
 .github/workflows/deploy_doc.yaml |  2 +-
 Project.toml                      |  2 +-
 deps/build.jl                     |  6 ++---
 src/pysource/geom_utils.py        |  2 +-
 src/pysource/models.py            | 33 +++++++++++++++------------
 src/pysource/utils.py             |  2 +-
 9 files changed, 77 insertions(+), 41 deletions(-)

diff --git a/.github/workflows/ci-examples.yml b/.github/workflows/ci-examples.yml
index 13959c4bc..8be8dce20 100644
--- a/.github/workflows/ci-examples.yml
+++ b/.github/workflows/ci-examples.yml
@@ -55,7 +55,7 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Setup julia
-        uses: julia-actions/setup-julia@v1
+        uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
           arch: x64
diff --git a/.github/workflows/ci-judi.yml b/.github/workflows/ci-judi.yml
index 4aac77fb7..52c33fe94 100644
--- a/.github/workflows/ci-judi.yml
+++ b/.github/workflows/ci-judi.yml
@@ -19,7 +19,7 @@ jobs:
     name: JUDI base on Julia ${{ matrix.version }} - ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     env:
-      DEVITO_ARCH: gcc-12
+      DEVITO_ARCH: ${{ matrix.cc }}
       DEVITO_LANGUAGE: "openmp"
       OMP_NUM_THREADS: 4
       GROUP: "JUDI"
@@ -28,28 +28,44 @@ jobs:
       fail-fast: false
 
       matrix:
-        version: ['1.6', '1.7', '1.8', '1.9', '1.10']
-        os: [ubuntu-latest, macos-13]
+        version: ['lts', '1.7', '1.8', '1.9', '1.10']
+        os: [ubuntu-latest]
+        arch: ['x64']
+        cc: ['gcc-12']
+
+        include:
+          - os: macos-15
+            version: '1'
+            arch: ARM64
+            cc: clang
 
     steps:
       - name: Checkout JUDI
         uses: actions/checkout@v4
 
       - name: Setup julia
-        uses: julia-actions/setup-julia@v1
+        uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
-          arch: x64
+          arch: ${{ matrix.arch }}
+
+      - name: Setup clang for osx
+        if: runner.os == 'macOS'
+        run: |
+          brew install llvm libomp
+          echo "/opt/homebrew/bin:/opt/homebrew/opt/llvm/bin" >> $GITHUB_PATH
 
       - name: Set up Python 3.9
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: 3.9
 
       - name: Set julia python
         run: |
           echo "PYTHON=$(which python3)" >> $GITHUB_ENV
-          PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
+          echo "PYCALL_JL_RUNTIME_PYTHON=$(which python3)" >> $GITHUB_ENV
+          python3 -m pip install devito[tests,extras]@git+https://github.com/devitocodes/devito.git
+          PYCALL_JL_RUNTIME_PYTHON=$(which python3) PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
 
       - name: Build JUDI
         uses: julia-actions/julia-buildpkg@latest
@@ -60,6 +76,7 @@ jobs:
           annotate: true
 
       - uses: julia-actions/julia-processcoverage@v1
+
       - uses: codecov/codecov-action@v4
         with:
           file: lcov.info
diff --git a/.github/workflows/ci-op.yml b/.github/workflows/ci-op.yml
index ed73bb54f..d9d07729e 100644
--- a/.github/workflows/ci-op.yml
+++ b/.github/workflows/ci-op.yml
@@ -34,58 +34,71 @@ jobs:
         version: ['1']
         omp: [2]
         cc: ['gcc-11']
+        arch: ['x64']
   
         include:
-          - os: macos-13
-            version: '1.6'
+          - os: macos-15
+            version: '1'
             op: "ISO_OP"
             omp: 1
-            cc: gcc-13
+            cc: clang
+            arch: ARM64
 
-          - os: macos-13
+          - os: macos-15
             version: '1.8'
             op: "ISO_OP"
             omp: 1
-            cc: gcc-13
+            cc: clang
+            arch: ARM64
 
-          - os: macos-13
+          - os: macos-15
             version: '1.9'
             op: "ISO_OP"
             omp: 1
-            cc: gcc-13
+            cc: clang
+            arch: ARM64
 
           - os: ubuntu-latest
             version: '1.9'
             op: "VISCO_AC_OP"
             omp: 2
             cc: gcc-11
+            arch: x64
 
           - os: ubuntu-latest
             version: '1.10'
             op: "ISO_OP"
             omp: 2
             cc: gcc-11
+            arch: x64
 
     steps:
       - name: Checkout JUDI
         uses: actions/checkout@v4
 
       - name: Setup julia
-        uses: julia-actions/setup-julia@v1
+        uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
-          arch: x64
+          arch: ${{ matrix.arch }}
+
+      - name: Setup clang for osx
+        if: runner.os == 'macOS'
+        run: |
+          brew install llvm libomp
+          echo "/opt/homebrew/bin:/opt/homebrew/opt/llvm/bin" >> $GITHUB_PATH
 
       - name: Set up Python 3.9
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: 3.9
 
       - name: Set julia python
         run: |
           echo "PYTHON=$(which python3)" >> $GITHUB_ENV
+          echo "PYCALL_JL_RUNTIME_PYTHON=$(which python3)" >> $GITHUB_ENV
           python3 -m pip install devito[tests,extras]@git+https://github.com/devitocodes/devito.git
-          PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
+          PYCALL_JL_RUNTIME_PYTHON=$(which python3) PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
 
       - name: Build JUDI
         uses: julia-actions/julia-buildpkg@latest
@@ -96,7 +109,8 @@ jobs:
           annotate: true
 
       - uses: julia-actions/julia-processcoverage@v1
+
       - uses: codecov/codecov-action@v4
         with:
           file: lcov.info
-          token: ${{ secrets.CODECOV_TOKEN }}
\ No newline at end of file
+          token: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.github/workflows/deploy_doc.yaml b/.github/workflows/deploy_doc.yaml
index 6d82cb509..748b2c790 100644
--- a/.github/workflows/deploy_doc.yaml
+++ b/.github/workflows/deploy_doc.yaml
@@ -24,7 +24,7 @@ jobs:
       - uses: julia-actions/setup-julia@latest
       
       - name: Set up Python 3.9
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: 3.9
 
diff --git a/Project.toml b/Project.toml
index d239b0046..f6e01808d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "JUDI"
 uuid = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
 authors = ["Philipp Witte, Mathias Louboutin"]
-version = "3.4.6"
+version = "3.4.7"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/deps/build.jl b/deps/build.jl
index 3a357ac6c..cc89bdb97 100644
--- a/deps/build.jl
+++ b/deps/build.jl
@@ -8,14 +8,14 @@ end
 pk = try
     pyimport("pkg_resources")
 catch e
-    run(PyCall.python_cmd(`-m pip install --user setuptools`))
+    run(PyCall.python_cmd(`-m pip install -U --user --no-cache-dir setuptools`))
     pyimport("pkg_resources")
 end
 
 ################## Devito ##################
 # pip command
 dvver = "4.8.10"
-cmd = PyCall.python_cmd(`-m pip install --user devito\[extras,tests\]\>\=$(dvver)`)
+cmd = PyCall.python_cmd(`-m pip install --user --no-cache-dir devito\[extras,tests\]\>\=$(dvver)`)
 
 try
     dv_ver = VersionNumber(split(pk.get_distribution("devito").version, "+")[1])
@@ -32,5 +32,5 @@ end
 try
     mpl = pyimport("matplotlib")
 catch e
-    run(PyCall.python_cmd(`-m pip install --user matplotlib`))
+    run(PyCall.python_cmd(`-m pip install --user --no-cache-dir matplotlib`))
 end
diff --git a/src/pysource/geom_utils.py b/src/pysource/geom_utils.py
index 0dccc325b..8e60d0595 100644
--- a/src/pysource/geom_utils.py
+++ b/src/pysource/geom_utils.py
@@ -88,7 +88,7 @@ def geom_expr(model, u, src_coords=None, rec_coords=None, wavelet=None, fw=True,
     # Setup adjoint wavefield sampling at source locations
     if rcv is not None:
         if model.is_elastic:
-            geom_expr = u[1].trace() / model.grid.dim
+            rec_expr = u[1].trace() / model.grid.dim
         else:
             rec_expr = u[0] if model.is_tti else u
         geom_expr += rcv.interpolate(expr=rec_expr)
diff --git a/src/pysource/models.py b/src/pysource/models.py
index 81f3f069b..f01960fc6 100644
--- a/src/pysource/models.py
+++ b/src/pysource/models.py
@@ -9,16 +9,13 @@
 
 from devito.tools import as_tuple, memoized_func
 
-# try:
-#    from devitopro import *  # noqa
-#    from devitopro.subdomains.abox import ABox, ABoxFunction
-#    from devitopro.data import Float16
-#    AboxBase = ABox
-# except ImportError:
-ABox = None
-AboxBase = object
-ABoxFunction = object
-Float16 = lambda *ar, **kw: np.float32
+try:
+    from devitopro import *  # noqa
+    from devitopro.subdomains.abox import ABox
+    from devitopro.data import Float16
+except ImportError:
+    ABox = None
+    Float16 = lambda *ar, **kw: np.float32
 
 _dtypes = {'params': 'f32'}
 
@@ -223,10 +220,16 @@ def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float
         # Additional parameter fields for elastic
         if self._is_elastic:
             self.lam = self._gen_phys_param(lam, 'lam', space_order, is_param=True)
-            mu[np.where(mu == 0)] = 1e-12
+            try:
+                mu[np.where(mu == 0)] = 1e-12
+            except TypeError:
+                mu = 1e-12 if mu == 0 else mu
             self.mu = self._gen_phys_param(mu, 'mu', space_order, is_param=True,
                                            avg_mode='harmonic')
-            b = b if b is not None else 1 / rho
+            try:
+                b = b if b is not None else 1 / rho
+            except TypeError:
+                b = 1
             vp_vals = ((lam + 2 * mu) * b)**(.5)
 
         # User provided dt
@@ -437,7 +440,7 @@ def _cfl_coeff(self):
             so = max(self.space_order // 2, 2)
             coeffs = fd_w(1, range(-so, so), .5)
             c_fd = sum(np.abs(coeffs[-1][-1])) / 2
-            return .95 * np.sqrt(self.dim) / self.dim / c_fd
+            return .9 * np.sqrt(self.dim) / self.dim / c_fd
         a1 = 4  # 2nd order in time
         so = max(self.space_order // 2, 4)
         coeffs = fd_w(2, range(-so, so), 0)[-1][-1]
@@ -491,7 +494,9 @@ def abox(self, src, rec, fw=True):
         if not fw:
             src, rec = rec, src
         eps = getattr(self, 'epsilon', None)
-        abox = ABox(src, rec, self._vp, self.space_order, eps=eps)
+        abox = ABox(src, rec, self._vp, self.space_order, eps=eps)._abox_func
+        abox.data[:] = abox._compute(src=src, rcv=rec, vp=self._vp, eps=eps,
+                                     dt=self.critical_dt)
         return {'abox': abox}
 
     def __init_abox__(self, src, rec, fw=True):
diff --git a/src/pysource/utils.py b/src/pysource/utils.py
index 7892d0d2f..01ddf39be 100644
--- a/src/pysource/utils.py
+++ b/src/pysource/utils.py
@@ -156,7 +156,7 @@ def compression_mode():
     """
     Check compiler used to see if can use bitcomp
     """
-    if configuration['compiler'] in [NvidiaCompiler, CudaCompiler]:
+    if isinstance(configuration['compiler'], (NvidiaCompiler, CudaCompiler)):
         return 'bitcomp'
     else:
         return 'noop'