Skip to content

Commit

Permalink
fix: update tests to the new API
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 9, 2024
1 parent 6cbed33 commit be4fdb1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/initialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Random: Random

using LuxCore: LuxCore

using ..Utils: is_extension_loaded, unwrap_val
using ..Utils: is_extension_loaded

get_pretrained_weights_path(name::Symbol) = get_pretrained_weights_path(string(name))
function get_pretrained_weights_path(name::String)
Expand Down
48 changes: 24 additions & 24 deletions test/vision_tests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
@testitem "AlexNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES
@testset "pretrained: $(pretrained)" for pretrained in [true, false]
model, ps, st = Vision.AlexNet(; pretrained)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.AlexNet(; pretrained)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -16,9 +16,9 @@ end

@testitem "ConvMixer" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, name in [:small, :base, :large]
model, ps, st = Vision.ConvMixer(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.ConvMixer(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 256, 256, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -30,9 +30,9 @@ end

@testitem "GoogLeNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES
model, ps, st = Vision.GoogLeNet(; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.GoogLeNet(; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -44,9 +44,9 @@ end

@testitem "MobileNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, name in [:v1, :v2, :v3_small, :v3_large]
model, ps, st = Vision.MobileNet(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.MobileNet(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -58,9 +58,9 @@ end

@testitem "ResNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [18, 34, 50, 101, 152]
model, ps, st = Vision.ResNet(depth; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.ResNet(depth; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -72,9 +72,9 @@ end

@testitem "ResNeXt" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [50, 101, 152]
model, ps, st = Vision.ResNeXt(depth; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.ResNeXt(depth; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -90,9 +90,9 @@ end
false, true],
batchnorm in [false, true]

model, ps, st = Vision.VGG(depth; batchnorm, pretrained)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.VGG(depth; batchnorm, pretrained)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -106,9 +106,9 @@ end
@testitem "VisionTransformer" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, name in [:tiny, :small, :base]
# :large, :huge, :giant, :gigantic --> too large for CI
model, ps, st = Vision.VisionTransformer(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.VisionTransformer(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 256, 256, 3, 2) |> aType

@jet model(img, ps, st)
Expand Down

0 comments on commit be4fdb1

Please sign in to comment.