Skip to content

Commit

Permalink
fix: try disabling force_preserve
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 13, 2024
1 parent 4f1d863 commit c30b4da
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ext/BoltzMetalheadExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,32 @@ Utils.is_extension_loaded(::Val{:Metalhead}) = true

function Vision.ResNetMetalhead(depth::Int; pretrained::Bool=false)
@argcheck depth in (18, 34, 50, 101, 152)
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ResNet(
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.ResNet(
depth; pretrain=pretrained).layers)
end

function Vision.ResNeXtMetalhead(
depth::Int; cardinality=32, base_width=nothing, pretrained::Bool=false)
@argcheck depth in (50, 101, 152)
base_width = base_width === nothing ? (depth == 101 ? 8 : 4) : base_width
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ResNeXt(
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.ResNeXt(
depth; pretrain=pretrained, cardinality, base_width).layers)
end

function Vision.GoogLeNetMetalhead(; pretrained::Bool=false)
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.GoogLeNet(;
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.GoogLeNet(;
pretrain=pretrained).layers)
end

function Vision.DenseNetMetalhead(depth::Int; pretrained::Bool=false)
@argcheck depth in (121, 161, 169, 201)
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.DenseNet(
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.DenseNet(
depth; pretrain=pretrained).layers)
end

function Vision.MobileNetMetalhead(name::Symbol; pretrained::Bool=false)
@argcheck name in (:v1, :v2, :v3_small, :v3_large)
adaptor = FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)
adaptor = FromFluxAdaptor(; preserve_ps_st=pretrained)
model = if name == :v1
adaptor(Metalhead.MobileNetv1(; pretrain=pretrained).layers)
elseif name == :v2
Expand All @@ -51,18 +51,18 @@ end

function Vision.ConvMixerMetalhead(name::Symbol; pretrained::Bool=false)
@argcheck name in (:base, :large, :small)
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ConvMixer(
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.ConvMixer(
name; pretrain=pretrained).layers)
end

function Vision.SqueezeNetMetalhead(; pretrained::Bool=false)
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.SqueezeNet(;
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.SqueezeNet(;
pretrain=pretrained).layers)
end

function Vision.WideResNetMetalhead(depth::Int; pretrained::Bool=false)
@argcheck depth in (18, 34, 50, 101, 152)
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.WideResNet(
return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.WideResNet(
depth; pretrain=pretrained).layers)
end

Expand Down

0 comments on commit c30b4da

Please sign in to comment.