DeepSpeedEngineWrapper.backward() does a bit too much #2951

Open
zhc7 opened this issue Jul 22, 2024 · 13 comments
Labels
enhancement New feature or request feature request Request for a new feature to be added to Accelerate

Comments

@zhc7

zhc7 commented Jul 22, 2024

The source code of DeepSpeedEngineWrapper:

class DeepSpeedEngineWrapper:
    """
    Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop.

    Args:
        engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap
    """

    def __init__(self, engine):
        self.engine = engine

    def backward(self, loss, **kwargs):
        # runs backpropagation and handles mixed precision
        self.engine.backward(loss, **kwargs)

        # Deepspeed's `engine.step` performs the following operations:
        # - gradient accumulation check
        # - gradient clipping
        # - optimizer step
        # - zero grad
        # - checking overflow
        # - lr_scheduler step (only if engine.lr_scheduler is not None)
        self.engine.step()
        # and this plugin overrides the above calls with no-ops when Accelerate runs under
        # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple
        # training loop that works transparently under many training regimes.

My question is: why do we need to call self.engine.step() here immediately? This zeros the gradients and updates the parameters without notifying the user, which may be unexpected. Because the backward pass is internally bound to zeroing the gradients and updating the parameters, users cannot check the gradients or parameters manually before stepping.

I know DeepSpeed-wrapped models can't be treated as normal models, but this behavior still eliminates a lot of flexibility.
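To make the concern concrete, here is a minimal toy sketch of the pattern described above. `ToyEngine` and `ToyWrapper` are hypothetical stand-ins, not DeepSpeed or Accelerate classes; the only thing taken from the quoted source is the structure of `backward()` calling `step()` right away.

```python
class ToyEngine:
    """Hypothetical stand-in for an engine; tracks a single scalar weight."""

    def __init__(self, weight=1.0, lr=0.25):
        self.weight = weight
        self.lr = lr
        self.grad = None  # set by backward(), cleared by step()

    def backward(self, loss_grad):
        # Pretend backpropagation produced this gradient for the weight.
        self.grad = loss_grad

    def step(self):
        # Optimizer update followed by zeroing the gradient, as listed in
        # the wrapper's comment about `engine.step`.
        self.weight -= self.lr * self.grad
        self.grad = None


class ToyWrapper:
    """Mirrors the shape of the quoted wrapper: backward, then step."""

    def __init__(self, engine):
        self.engine = engine

    def backward(self, loss_grad):
        self.engine.backward(loss_grad)
        self.engine.step()


engine = ToyEngine()
ToyWrapper(engine).backward(2.0)
print(engine.grad)    # None: the gradient is gone before the caller can look
print(engine.weight)  # 0.5: the weight has already been updated
```

By the time `backward()` returns, there is no point at which the caller could have inspected the gradient.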

@zhc7 zhc7 changed the title from "DeepSpeedEngineWrapper.backword() does a bit too much" to "DeepSpeedEngineWrapper.backward() does a bit too much" Jul 22, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@muellerzr
Collaborator

Hitting this snag right now actually, will see what we decide to do

@muellerzr
Collaborator

We're working with the DS team to try to remove the engine entirely. However, as a user you can always call model.engine.backward() etc. manually without harm in Accelerate.
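A minimal sketch of that manual route, assuming only what the wrapper source quoted in this issue shows: the engine exposes `backward(loss)` and `step()`. The helper name `backward_then_step` and the `inspect` callback are hypothetical, introduced here for illustration.

```python
def backward_then_step(engine, loss, inspect=None):
    """Run backward and step as two separate calls on a DeepSpeed-style engine.

    `engine` is assumed to expose `backward(loss)` and `step()`.
    `inspect` is a hypothetical callback that runs between the two calls,
    which is the window the wrapper's `backward()` does not offer.
    """
    engine.backward(loss)
    result = inspect(engine) if inspect is not None else None
    engine.step()
    return result
```

With Accelerate, the engine would presumably be reached through something like `accelerator.deepspeed_engine_wrapped.engine`, the attribute path that appears later in this thread.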

@nom

nom commented Sep 23, 2024

Somehow

loss = loss / accelerator.gradient_accumulation_steps
accelerator.deepspeed_engine_wrapped.engine.backward(loss)

does not give equivalent results to
accelerator.backward(loss)

with deepspeed. What gives?
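Two things may explain the difference; neither is confirmed in this thread. First, per the wrapper source quoted at the top, accelerator.backward(loss) under DeepSpeed also calls engine.step(), which the manual snippet never does. Second, if the engine already divides the loss by the number of gradient accumulation steps internally (an assumption worth checking against your DeepSpeed version), dividing by hand applies the scaling twice. The toy arithmetic below illustrates only that second point; `engine_backward` is a hypothetical stand-in, not a DeepSpeed call.

```python
GAS = 4  # gradient accumulation steps (illustrative value)


def engine_backward(loss, gas=GAS):
    """Toy model of an engine that scales the loss by 1/gas internally.

    This internal scaling is an assumption made for illustration.
    Returns the effective loss that would be backpropagated.
    """
    return loss / gas


raw_loss = 8.0

# Path A: hand the raw loss to the engine, which scales it once.
scaled_once = engine_backward(raw_loss)

# Path B: divide by hand first, then the engine scales it again.
scaled_twice = engine_backward(raw_loss / GAS)

print(scaled_once)   # 2.0
print(scaled_twice)  # 0.5
```

If this is what is happening, the manually divided path produces gradients that are smaller by a factor of the accumulation step count.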


@LinB203

LinB203 commented Oct 27, 2024

same question here.


@GreenWindow1997

same question here. It's indeed very hard to understand why it was designed this way. I believe accelerator.backward(loss) should only perform the backward operation, and other steps should be written outside this function in a more standard and understandable manner.


@LinB203

LinB203 commented Dec 16, 2024

Not stale.


@Nirmal-Adhikari-hub

self.engine.module.named_parameters()
might have the gradients you are looking for. I am not sure about this, as I don't know enough to answer these esoteric questions, but I ran into a similar problem of wanting to observe the model's gradients after the backward pass, and this solution worked for me.
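A small sketch of that idea, assuming the wrapped module behaves like a torch.nn.Module, i.e. `named_parameters()` yields `(name, parameter)` pairs and each parameter has a `.grad` attribute. The helper name `grad_summary` is hypothetical.

```python
def grad_summary(module):
    """Map each parameter name to whether a gradient is currently attached.

    `module` is assumed to expose `named_parameters()` yielding
    (name, parameter) pairs, where each parameter has a `.grad` attribute
    that is None when no gradient is available.
    """
    return {
        name: param.grad is not None
        for name, param in module.named_parameters()
    }
```

It would have to be called on `engine.module` after `engine.backward(loss)` and before `engine.step()`. Depending on the DeepSpeed configuration, `.grad` may still be None at that point, so an all-False result does not necessarily mean the backward pass failed.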

I hope it helps; if it did, I'm glad to have been of help.

Happy Coding!

@muellerzr muellerzr reopened this Feb 11, 2025
@muellerzr
Collaborator

So for a bit more context, we've been waiting on this PR to happen, so hopefully we can give more flexibility soon: deepspeedai/DeepSpeed#7018

@muellerzr muellerzr added enhancement New feature or request feature request Request for a new feature to be added to Accelerate labels Feb 11, 2025
6 participants