Skip to content

Commit

Permalink
Enzyme: bump version and mark models as working [test] (#2439)
Browse files Browse the repository at this point in the history
* Enzyme: bump version and mark models as working [test]

* Update Project.toml

* Update Project.toml

* Update enzyme.jl

* Mark transpose as not supported
  • Loading branch information
wsmoses authored May 11, 2024
1 parent 26c9acf commit 11f3fca
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Adapt = "3, 4"
CUDA = "4, 5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.11"
Enzyme = "0.12.4"
FiniteDifferences = "0.12"
Functors = "0.4"
MLUtils = "0.4"
Expand Down
3 changes: 2 additions & 1 deletion test/ext_enzyme/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ end
(Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
]

for (model, x, name) in models_xs
Expand Down Expand Up @@ -164,7 +165,7 @@ end
device = Flux.get_device()

models_xs = [
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
# Pending https://github.com/FluxML/NNlib.jl/issues/565
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
]

Expand Down

0 comments on commit 11f3fca

Please sign in to comment.