diff --git a/src/initialize.jl b/src/initialize.jl index 8cd919c..16da6ff 100644 --- a/src/initialize.jl +++ b/src/initialize.jl @@ -7,7 +7,7 @@ using Random: Random using LuxCore: LuxCore -using ..Utils: is_extension_loaded, unwrap_val +using ..Utils: is_extension_loaded get_pretrained_weights_path(name::Symbol) = get_pretrained_weights_path(string(name)) function get_pretrained_weights_path(name::String) diff --git a/test/vision_tests.jl b/test/vision_tests.jl index f5d1595..9945050 100644 --- a/test/vision_tests.jl +++ b/test/vision_tests.jl @@ -1,9 +1,9 @@ @testitem "AlexNet" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES @testset "pretrained: $(pretrained)" for pretrained in [true, false] - model, ps, st = Vision.AlexNet(; pretrained) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.AlexNet(; pretrained) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 224, 224, 3, 2) |> aType @jet model(img, ps, st) @@ -16,9 +16,9 @@ end @testitem "ConvMixer" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, name in [:small, :base, :large] - model, ps, st = Vision.ConvMixer(name; pretrained=false) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.ConvMixer(name; pretrained=false) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 256, 256, 3, 2) |> aType @jet model(img, ps, st) @@ -30,9 +30,9 @@ end @testitem "GoogLeNet" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES - model, ps, st = Vision.GoogLeNet(; pretrained=false) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.GoogLeNet(; pretrained=false) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 224, 224, 3, 2) |> aType @jet model(img, ps, st) @@ -44,9 +44,9 @@ end @testitem "MobileNet" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, name in [:v1, :v2, :v3_small, :v3_large] - model, ps, st = Vision.MobileNet(name; pretrained=false) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.MobileNet(name; pretrained=false) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 224, 224, 3, 2) |> aType @jet model(img, ps, st) @@ -58,9 +58,9 @@ end @testitem "ResNet" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, depth in [18, 34, 50, 101, 152] - model, ps, st = Vision.ResNet(depth; pretrained=false) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.ResNet(depth; pretrained=false) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 224, 224, 3, 2) |> aType @jet model(img, ps, st) @@ -72,9 +72,9 @@ end @testitem "ResNeXt" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, depth in [50, 101, 152] - model, ps, st = Vision.ResNeXt(depth; pretrained=false) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.ResNeXt(depth; pretrained=false) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 224, 224, 3, 2) |> aType @jet model(img, ps, st) @@ -90,9 +90,9 @@ end false, true], batchnorm in [false, true] - model, ps, st = Vision.VGG(depth; batchnorm, pretrained) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.VGG(depth; batchnorm, pretrained) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 224, 224, 3, 2) |> aType @jet model(img, ps, st) @@ -106,9 +106,9 @@ end @testitem "VisionTransformer" setup=[SharedTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, name in [:tiny, :small, :base] # :large, :huge, :giant, :gigantic --> too large for CI - model, ps, st = Vision.VisionTransformer(name; pretrained=false) - ps = ps |> dev - st = Lux.testmode(st) |> dev + model = Vision.VisionTransformer(name; pretrained=false) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + st = Lux.testmode(st) img = randn(Float32, 256, 256, 3, 2) |> aType @jet model(img, ps, st)