
Allowing to use custom activation before computing the loss values #93

Open
lorenzo-consoli opened this issue Jan 24, 2024 · 3 comments

Comments

@lorenzo-consoli

Hello,

It would be good to update the LRFinder object to allow a custom final activation to be applied to the model outputs before the loss is computed.
I patched this locally in the site-packages directory of the environment I'm using to work around the problem, as I need a log_softmax activation to be applied before the loss.

@NaleRaphael
Contributor

Hi @FMGS666
This seems similar to the question in #69; maybe one of these two approaches could fit your use case?

Suppose the forward pass below is the desired format:

# Forward pass
outputs = model(inputs)
outputs = F.log_softmax(outputs, dim=-1)
loss = criterion(outputs, labels)
  1. Use a wrapper to include further operations:

    class LossFunctionWrapper(nn.Module):
        def __init__(self, loss_func):
            super().__init__()
            self.loss_func = loss_func
    
        def forward(self, outputs, labels):
            outputs = F.log_softmax(outputs, dim=-1)    # <--
            return self.loss_func(outputs, labels)

    Then pass this wrapper as the loss function to LRFinder:

    loss_func_wrapper = LossFunctionWrapper(original_loss_func)
    lr_finder = LRFinder(model, optimizer, loss_func_wrapper, device='cuda')
  2. Override LRFinder._train_batch():

    class MyLRFinder(LRFinder):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
    
        def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=True):
            # ...
    
            inputs, labels = next(train_iter)
            inputs, labels = self._move_to_device(
                inputs, labels, non_blocking=non_blocking_transfer
            )
    
            # Modified forward pass
            outputs = self.model(inputs)
            outputs = F.log_softmax(outputs, dim=-1)    # <--
            loss = self.criterion(outputs, labels)
    
            # ...

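As a standalone sanity check of the wrapper in approach 1 (this is just a sketch, not part of the library): since `nn.CrossEntropyLoss` on raw logits is equivalent to `F.log_softmax` followed by `nn.NLLLoss`, you can verify the wrapper behaves as intended like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LossFunctionWrapper(nn.Module):
    """Applies log_softmax to the model outputs before the wrapped loss."""
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, outputs, labels):
        outputs = F.log_softmax(outputs, dim=-1)
        return self.loss_func(outputs, labels)

torch.manual_seed(0)
logits = torch.randn(8, 5)            # raw model outputs for a batch of 8
labels = torch.randint(0, 5, (8,))

wrapped = LossFunctionWrapper(nn.NLLLoss())
reference = nn.CrossEntropyLoss()

# log_softmax + NLLLoss on raw logits == CrossEntropyLoss on raw logits
print(torch.allclose(wrapped(logits, labels), reference(logits, labels)))  # prints True
```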
If it does not work as expected, please consider sharing the patch you made here and I'll help you resolve it.
And if there is any further problem, please feel free to let me know.

@lorenzo-consoli
Author

My approach is closer to the second one, as I directly patched the _train_batch method. It seemed more natural to me to let users pass an optional final activation than to make them wrap the loss module itself.
Anyway, I could fork the repository and open a PR with the change (it's just a few lines of code) so you can check it out yourself. Thank you for your reply, and have a nice day!

@NaleRaphael
Contributor

Yeah, I think you can talk with David about this idea if you find that this use case has become more common recently, or that it reduces the difficulty of using this library for more complex cases.

By the way, the reasons I would recommend the two approaches above are:

  • it retains the flexibility of customization without needing to modify the library.
  • it causes fewer surprises, since LRFinder's internals closely follow the conventional training pipeline below (or the one in this official tutorial).
    for i, batch in enumerate(data_loader):
        inputs, targets = batch

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, targets)

        loss.backward()
        optimizer.step()
  • it keeps the codebase simple and minimal, as we cannot foresee the complexity of future training setups.
  • (especially for wrapper classes) it should be less error-prone with complex training setups. Because the data loader, model, and loss function are wrapped into instances, those instances work the same as usual in both LRFinder._train_batch() and LRFinder._validate() when the user adopts Leslie Smith's approach for learning-rate search.
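To illustrate that last point with a standalone sketch (independent of LRFinder): because the activation lives inside the loss wrapper, the exact same criterion instance can be reused in both a training step and a validation step without any special casing:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LossFunctionWrapper(nn.Module):
    """Applies log_softmax to the model outputs before the wrapped loss."""
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, outputs, labels):
        return self.loss_func(F.log_softmax(outputs, dim=-1), labels)

torch.manual_seed(0)
model = nn.Linear(4, 3)                      # toy classifier
criterion = LossFunctionWrapper(nn.NLLLoss())

train_x, train_y = torch.randn(8, 4), torch.randint(0, 3, (8,))
val_x, val_y = torch.randn(8, 4), torch.randint(0, 3, (8,))

# Training step: the wrapper applies the activation transparently
train_loss = criterion(model(train_x), train_y)
train_loss.backward()

# Validation step: the very same criterion instance, no extra plumbing
with torch.no_grad():
    val_loss = criterion(model(val_x), val_y)
```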

These are just my opinions, but I hope they give you some development context for this library.
Still, I appreciate your suggestion. Have a nice day too!
