Skip to content

Commit

Permalink
fix flux extension
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 26, 2024
1 parent 9458f55 commit 2111769
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JUDI"
uuid = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
authors = ["Philipp Witte, Mathias Louboutin"]
version = "3.4.5"
version = "3.4.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -33,6 +33,7 @@ TimerOutputs = "0.5"
julia = "1.6"

[extensions]
CUDAJUDIExt = "CUDA"
FluxJUDIExt = "Flux"
JLD2JUDIExt = "JLD2"
ZygoteJUDIExt = "Zygote"
Expand All @@ -49,6 +50,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
test = ["Aqua", "JLD2", "Printf", "Test", "TimerOutputs", "Flux"]

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
10 changes: 10 additions & 0 deletions ext/CUDAJUDIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module CUDAJUDIExt

import JUDI: LazyPropagation, judiVector, eval_prop
isdefined(Base, :get_extension) ? (using CUDA) : (using ..CUDA)

CUDA.cu(F::LazyPropagation) = CUDA.cu(eval_prop(F))
CUDA.cu(x::Vector{Matrix{T}}) where T = [CUDA.cu(x[i]) for i=1:length(x)]
CUDA.cu(x::judiVector{T, Matrix{T}}) where T = judiVector{T, CUDA.CuMatrix{T}}(x.nsrc, x.geometry, CUDA.cu(x.data))

end
5 changes: 1 addition & 4 deletions ext/FluxJUDIExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
module FluxJUDIExt

import JUDI: LazyPropagation, judiVector, eval_prop
import JUDI: LazyPropagation
isdefined(Base, :get_extension) ? (using Flux) : (using ..Flux)

Flux.cpu(x::LazyPropagation) = Flux.cpu(eval_prop(x))
Flux.gpu(x::LazyPropagation) = Flux.gpu(eval_prop(x))
Flux.CUDA.cu(F::LazyPropagation) = Flux.CUDA.cu(eval_prop(F))
Flux.CUDA.cu(x::Vector{Matrix{T}}) where T = [Flux.CUDA.cu(x[i]) for i=1:length(x)]
Flux.CUDA.cu(x::judiVector{T, Matrix{T}}) where T = judiVector{T, Flux.CUDA.CuMatrix{T}}(x.nsrc, x.geometry, Flux.CUDA.cu(x.data))

end
6 changes: 6 additions & 0 deletions src/JUDI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ function __init__()
@info "Flux compat enabled"
include("../ext/FluxJUDIExt.jl")
end

# Additional Flux compat if in use
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
@info "CUDA compat enabled"
include("../ext/CUDAJUDIExt.jl")
end
end

end
Expand Down

0 comments on commit 2111769

Please sign in to comment.