
DP with NAS: ‘Model contains a trainable layer that Opacus doesn't currently support’ #576

Open
Drxdx opened this issue Mar 20, 2023 · 4 comments

@Drxdx

Drxdx commented Mar 20, 2023

📚 Documentation

I want to use DP with NAS. When I use the pre-trained DARTS model, the ModuleValidator.fix() function doesn't work! Has anyone else run into this problem?

@lucacorbucci

Hi @Drxdx could you share the code that you're trying to run?

@Drxdx
Author

Drxdx commented Mar 27, 2023

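import torchvision.models as models
from opacus.validators import ModuleValidator
# Import path assumed for NNI 2.x; newer NNI releases expose DARTS under nni.nas.hub.pytorch
from nni.retiarii.hub.pytorch import DARTS as DartsSpace

model = models.resnet18(num_classes=10)
darts_v2_model = DartsSpace.load_searched_model('darts-v2', pretrained=True, download=True)

print("============", type(darts_v2_model))
print(type(model))

# strict=False returns the list of problems instead of raising on the first one
errors = ModuleValidator.validate(model, strict=False)
print(errors)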

As a simple example, I use ResNet-18 for DP training. ResNet-18 contains BatchNorm (BN) layers, which mix statistics across the samples in a batch, so the model cannot be trained with differential privacy as is. Calling ModuleValidator.fix(model, strict=False) replaces the BN layers with GroupNorm (GN) layers, after which the model can be used. With the DARTS model, however, fix() has no effect.

In this case, the function GradSampleModule.is_supported(m) returns True for Conv layers and False for BN layers with ResNet-18. With DARTS, however, it returns False for both Conv and BN layers.

This problem has been bothering me for a long time. I hope you can help me.
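A quick, stdlib-only check that may help narrow this down: compare each layer's class with the class Python returns when importing it under the name it prints. If the two are different objects, any lookup keyed on the class object will miss that layer. The helpers canonical_class and has_stale_class are hypothetical names for this sketch, not Opacus or PyTorch APIs.

```python
import importlib


def canonical_class(cls):
    """Return the class importable under cls.__module__ and cls.__qualname__, or None."""
    if "<locals>" in cls.__qualname__:
        return None  # classes defined inside functions cannot be looked up by name
    try:
        obj = importlib.import_module(cls.__module__)
    except ImportError:
        return None
    # Follow dotted qualified names such as "Outer.Inner"
    for part in cls.__qualname__.split("."):
        obj = getattr(obj, part, None)
        if obj is None:
            return None
    return obj


def has_stale_class(instance):
    """True if the instance's class prints like an importable class but is a different object."""
    cls = type(instance)
    canonical = canonical_class(cls)
    return canonical is not None and canonical is not cls


if __name__ == "__main__":
    # A class re-created under the same printed name as the built-in int
    FakeInt = type("int", (), {"__module__": "builtins", "__qualname__": "int"})
    print(FakeInt)                     # prints <class 'int'>
    print(has_stale_class(1))          # prints False
    print(has_stale_class(FakeInt()))  # prints True
```

On a PyTorch model it could be applied per submodule, for example [name for name, m in darts_v2_model.named_modules() if has_stale_class(m)]; any layer flagged there has a class that prints like the standard one but is a different object.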

@ffuuugor
Contributor

Ok, this is actually interesting. I've investigated this a bit, and it seems like this problem could appear for any deserialized model.

The problem is that when the model is loaded by DartsSpace.load_searched_model('darts-v2', ...), the class object of its BatchNorm layers is different from the class object you get from a normal import. This is confusing, so here's an example:

> darts_v2_model = DartsSpace.load_searched_model('darts-v2', pretrained=True, download=True)
> type(darts_v2_model.stages[0][0].preprocessor.pre0[2])
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>

> bn = torch.nn.modules.batchnorm.BatchNorm2d(2)
> type(bn)
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>

> type(darts_v2_model.stages[0][0].preprocessor.pre0[2]) == type(bn)
False

> id(type(darts_v2_model.stages[0][0].preprocessor.pre0[2]))
201567168

> id(type(bn))
98172448

This then leads to ModuleValidator ignoring the BatchNorms, because it checks the class object, not its string representation.
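The effect can be reproduced without PyTorch. The sketch below builds two distinct class objects with the same name and module, plus a fixer table keyed by class objects. The table and the fix function are a simplified, hypothetical stand-in for the lookup described above, not Opacus code.

```python
# Two distinct class objects that print identically, standing in for the
# BatchNorm2d class you import normally and the one found in the loaded model.
ImportedBN = type("BatchNorm2d", (), {"__module__": "torch.nn.modules.batchnorm"})
LoadedBN = type("BatchNorm2d", (), {"__module__": "torch.nn.modules.batchnorm"})

# Hypothetical fixer table keyed by class objects
FIXERS = {ImportedBN: lambda module: "GroupNorm"}


def fix(module):
    """Apply the registered fixer for type(module), or return the module unchanged."""
    fixer = FIXERS.get(type(module))
    return fixer(module) if fixer is not None else module


if __name__ == "__main__":
    print(ImportedBN)              # <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
    print(LoadedBN)                # <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
    print(ImportedBN == LoadedBN)  # False
    print(fix(ImportedBN()))       # GroupNorm
    loaded = LoadedBN()
    print(fix(loaded) is loaded)   # True: the "loaded" layer is silently left as is
```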

I'm not exactly sure what about the serialization process makes it create new class objects, and I'm not sure how commonplace this is.
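For context, standard pickle stores a class by reference (its module and qualified name) and imports it again on load, so unpickling by itself normally returns the already-imported class. One well-known way to end up with two distinct class objects under one name is re-executing the defining module, which is what importlib.reload does. The sketch below demonstrates that mechanism with a throwaway module; the module name fake_layers is hypothetical, and whether anything similar happens inside DartsSpace.load_searched_model is not confirmed here.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def instances_across_reload():
    """Create one instance before and one after reloading a module; return both."""
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "fake_layers.py").write_text("class BatchNorm2d:\n    pass\n")
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()  # make sure the new file is visible to the import system
        try:
            import fake_layers

            before = fake_layers.BatchNorm2d()
            importlib.reload(fake_layers)  # re-executes the class statement
            after = fake_layers.BatchNorm2d()
            return before, after
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("fake_layers", None)


if __name__ == "__main__":
    before, after = instances_across_reload()
    print(type(before), type(after))        # both print <class 'fake_layers.BatchNorm2d'>
    print(type(before) is type(after))      # False
    print(isinstance(before, type(after)))  # False
```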

However, I don't see a good reason to keep class references as keys rather than strings. Any ideas why switching to strings could backfire?
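One way the string-key idea could look, as a hypothetical sketch rather than a patch against Opacus: key the registry by fully qualified class name and walk the MRO on lookup, so a re-created class with the same name, or a subclass of a registered class, still finds its entry. NameKeyedRegistry, qualified_name and to_group_norm are illustrative names only.

```python
from typing import Callable, Dict, Optional


def qualified_name(cls: type) -> str:
    """Module plus qualified name, e.g. 'torch.nn.modules.batchnorm.BatchNorm2d'."""
    return f"{cls.__module__}.{cls.__qualname__}"


class NameKeyedRegistry:
    """Hypothetical registry keyed by qualified class names instead of class objects."""

    def __init__(self) -> None:
        self._entries: Dict[str, Callable] = {}

    def register(self, *classes: type) -> Callable[[Callable], Callable]:
        def decorator(fn: Callable) -> Callable:
            for cls in classes:
                self._entries[qualified_name(cls)] = fn
            return fn
        return decorator

    def lookup(self, obj: object) -> Optional[Callable]:
        # Walk the MRO so subclasses of a registered class also match
        for cls in type(obj).__mro__:
            fn = self._entries.get(qualified_name(cls))
            if fn is not None:
                return fn
        return None


# Demo: two distinct class objects with the same printed name
ImportedBN = type("BatchNorm2d", (), {"__module__": "torch.nn.modules.batchnorm"})
LoadedBN = type("BatchNorm2d", (), {"__module__": "torch.nn.modules.batchnorm"})

FIXERS = NameKeyedRegistry()


@FIXERS.register(ImportedBN)
def to_group_norm(module):
    return "GroupNorm"  # placeholder for a real BatchNorm -> GroupNorm conversion


if __name__ == "__main__":
    print(FIXERS.lookup(LoadedBN()) is to_group_norm)  # True, although LoadedBN is not ImportedBN
    print(FIXERS.lookup(object()))                     # None
```

Some possible trade-offs, offered tentatively: a name match says nothing about whether the re-created class behaves like the original, any class that happens to share a module path and qualified name would match, and the MRO walk also accepts subclasses (for example ones overriding forward()) that an exact-type lookup would not match.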

cc @alexandresablayrolles @karthikprasad

@Drxdx
Author

Drxdx commented Mar 28, 2023

The same problem occurs in the first Conv2d layer. ResNet-18 works, but the DARTS model does not.
So I don't know how to fix this problem.
Maybe DARTS is not compatible with Opacus.
