Skip to content

Commit

Permalink
Merge pull request #98 from slimgroup/time-tests
Browse files Browse the repository at this point in the history
Improve memory overhead and custom types
  • Loading branch information
mloubout authored Mar 15, 2022
2 parents f65e339 + 4089fbb commit a2eb2c8
Show file tree
Hide file tree
Showing 36 changed files with 1,609 additions and 1,539 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ PyCall = "1.18, 1.90, 1.91, 1.62"
Reexport = "0.2, 1"
SegyIO = "0.7.7"
julia = "1"
TimerOutputs = "0.5"

This comment has been minimized.

Copy link
@mloubout

mloubout Mar 15, 2022

Author Member

[extras]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[targets]
test = ["Test", "ArgParse", "Printf", "JLD2"]
test = ["Test", "ArgParse", "Printf", "JLD2", "TimerOutputs"]
2 changes: 1 addition & 1 deletion src/JUDI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using JOLI, SegyIO

# Import Base functions to dispatch on JUDI types
import Base.*, Base./, Base.+, Base.-
import Base.copy!, Base.copy
import Base.copy!, Base.copy, Base.convert
import Base.sum, Base.ndims, Base.reshape, Base.fill!, Base.axes, Base.dotview
import Base.eltype, Base.length, Base.size, Base.iterate, Base.show, Base.display, Base.showarg
import Base.maximum, Base.minimum, Base.push!
Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/Modeling/distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
Filter input arguments and keyword arguments for experiment number `i`.
"""
get_exp(i, args...; kwargs...) = (_get_exp(a, i) for a in (args..., kwargs.data...))
get_exp(i, args...; kwargs...) = (_get_exp(a, i) for a in (args..., values(kwargs)...))

# Find task iterator (number of sources and indices)
"""
Expand Down
11 changes: 4 additions & 7 deletions src/TimeModeling/Modeling/extended_source_interface_serial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ function extended_source_modeling(model_full::Model, srcData, recGeometry, recDa

# Load full geometry for out-of-core geometry containers
recGeometry = Geometry(recGeometry)

model = model_full

# Set up Python model structure
Expand All @@ -18,15 +17,13 @@ function extended_source_modeling(model_full::Model, srcData, recGeometry, recDa
end
end

# Load shot record if stored on disk
typeof(recData) == SegyIO.SeisCon && (recData = convert(Array{Float32,2}, recData[1].data))

# Remove receivers outside the modeling domain (otherwise leads to segmentation faults)
recGeometry, recData = remove_out_of_bounds_receivers(recGeometry, recData, model)
recGeometry, recData = remove_out_of_bounds_receivers(recGeometry, convert(Matrix{Float32}, recData), model)
weights = isnothing(weights) ? nothing : pad_array(weights, pad_sizes(model, options; so=0); mode=:zeros)

isnothing(weights) ? nothing : weights = pad_array(weights, pad_sizes(model, options; so=0); mode=:zeros)
# Devito interface
argout = devito_interface(modelPy, model, srcData, recGeometry, recData, weights, dm, options)
argout = devito_interface(modelPy, srcData, recGeometry, recData, weights, dm, options)

# Extend gradient back to original model size
if op=='J' && mode==-1 && options.limit_m==true
argout = extend_gradient(model_full, model, argout)
Expand Down
11 changes: 5 additions & 6 deletions src/TimeModeling/Modeling/fwi_objective_serial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,30 @@ function fwi_objective(model_full::Model, source::judiVector, dObs::judiVector,

# Set up Python model
modelPy = devito_model(model, options)
dtComp = get_dt(model; dt=options.dt_comp)
dtComp = convert(Float32, modelPy."critical_dt")

# Extrapolate input data to computational grid
qIn = time_resample(source.data[1], source.geometry, dtComp)[1]
obsd = typeof(dObs.data[1]) == SegyIO.SeisCon ? convert(Array{Float32,2}, dObs.data[1][1].data) : dObs.data[1]
dObserved = time_resample(obsd, dObs.geometry, dtComp)[1]
dObserved = time_resample(convert(Matrix{Float32}, dObs.data[1]), dObs.geometry, dtComp)[1]

# Set up coordinates
src_coords = setup_grid(source.geometry, model.n) # shifts source coordinates by origin
rec_coords = setup_grid(dObs.geometry, model.n) # shifts rec coordinates by origin


if options.optimal_checkpointing == true
argout1, argout2 = pycall(ac."J_adjoint_checkpointing", Tuple{Float32, Array{Float32, modelPy.dim}},
argout1, argout2 = pycall(ac."J_adjoint_checkpointing", Tuple{Float32, PyArray},
modelPy, src_coords, qIn,
rec_coords, dObserved, is_residual=false, return_obj=true, isic=options.isic,
t_sub=options.subsampling_factor, space_order=options.space_order)
elseif ~isempty(options.frequencies)
argout1, argout2 = pycall(ac."J_adjoint_freq", Tuple{Float32, Array{Float32, modelPy.dim}},
argout1, argout2 = pycall(ac."J_adjoint_freq", Tuple{Float32, PyArray},
modelPy, src_coords, qIn,
rec_coords, dObserved, is_residual=false, return_obj=true, isic=options.isic,
freq_list=options.frequencies, t_sub=options.subsampling_factor,
space_order=options.space_order)
else
argout1, argout2 = pycall(ac."J_adjoint_standard", Tuple{Float32, Array{Float32, modelPy.dim}},
argout1, argout2 = pycall(ac."J_adjoint_standard", Tuple{Float32, PyArray},
modelPy, src_coords, qIn,
rec_coords, dObserved, is_residual=false, return_obj=true,
t_sub=options.subsampling_factor, space_order=options.space_order,
Expand Down
13 changes: 6 additions & 7 deletions src/TimeModeling/Modeling/lsrtm_objective_serial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,30 @@ function lsrtm_objective(model_full::Model, source::judiVector, dObs::judiVector

# Set up Python model structure
modelPy = devito_model(model, options; dm=dm)
dtComp = get_dt(model; dt=options.dt_comp)
dtComp = convert(Float32, modelPy."critical_dt")

# Extrapolate input data to computational grid
qIn = time_resample(source.data[1],source.geometry,dtComp)[1]
obsd = typeof(dObs.data[1]) == SegyIO.SeisCon ? convert(Array{Float32,2}, dObs.data[1][1].data) : dObs.data[1]
dObserved = time_resample(obsd, dObs.geometry, dtComp)[1]
dObserved = time_resample(convert(Matrix{Float32}, dObs.data[1]), dObs.geometry, dtComp)[1]

# Set up coordinates
src_coords = setup_grid(source.geometry, model.n) # shifts source coordinates by origin
rec_coords = setup_grid(dObs.geometry, model.n) # shifts rec coordinates by origin

if options.optimal_checkpointing == true
argout1, argout2 = pycall(ac."J_adjoint_checkpointing", Tuple{Float32, Array{Float32, modelPy.dim}},
argout1, argout2 = pycall(ac."J_adjoint_checkpointing", Tuple{Float32, PyArray},
modelPy, src_coords, qIn,
rec_coords, dObserved, is_residual=false, return_obj=true,
t_sub=options.subsampling_factor, space_order=options.space_order,
born_fwd=true, nlind=nlind, isic=options.isic)
elseif ~isempty(options.frequencies)
argout1, argout2 = pycall(ac."J_adjoint_freq", Tuple{Float32, Array{Float32, modelPy.dim}},
argout1, argout2 = pycall(ac."J_adjoint_freq", Tuple{Float32, PyArray},
modelPy, src_coords, qIn,
rec_coords, dObserved, is_residual=false, return_obj=true, nlind=nlind,
freq_list=options.frequencies, t_sub=options.subsampling_factor,
space_order=options.space_order, born_fwd=true, isic=options.isic)
else
argout1, argout2 = pycall(ac."J_adjoint_standard", Tuple{Float32, Array{Float32, modelPy.dim}},
argout1, argout2 = pycall(ac."J_adjoint_standard", Tuple{Float32, PyArray},
modelPy, src_coords, qIn,
rec_coords, dObserved, is_residual=false, return_obj=true,
t_sub=options.subsampling_factor, space_order=options.space_order,
Expand All @@ -62,5 +61,5 @@ function lsrtm_objective(model_full::Model, source::judiVector, dObs::judiVector
argout2 = extend_gradient(model_full, model, argout2)
end

return Ref{Float32}(argout1), PhysicalParameter(argout2, model_full.d, model_full.o)
return Ref{Float32}(argout1), PhysicalParameter(argout2, model_full.d, model_full.o)
end
Loading

1 comment on commit a2eb2c8

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/56662

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.6.5 -m "<description of version>" a2eb2c80aa6e0406879ce9f8fa5e683ce19bac1a
git push origin v2.6.5

Please sign in to comment.