
fix(pu): fix noise layer's usage based on the original paper #866

Merged
PaParaZz1 merged 11 commits into main from fix-noise
Jun 3, 2025

Conversation

puyuan1996
Collaborator

@puyuan1996 puyuan1996 commented Apr 17, 2025

Description

This pull request fixes how Noisy Net is used so that it matches the original Noisy Net paper.

The key modifications are as follows:

  • Add set_noise_mode Function:
    A new helper function, set_noise_mode, controls whether noise is enabled (enable_noise) and is used to update the noise settings of the network.

  • Add _reset_noise Method in DQN:
    A new _reset_noise method is added to the DQN implementation. During each training step, the noise is reset and the newly sampled noise is applied.

  • Model Weight and Noise Settings:

    • The training model, collection model, and evaluation model share the same weights.
    • During each training step, the model resets its noise and applies the newly sampled noise. Once training ends and the collection step begins, the noise applied is the same as in the last training step.
    • For the evaluation model, no noise is applied.
  • Experimental Result:
    With the Noisy Net implementation made consistent with the paper's description, experiments show no significant performance difference between runs with and without Noisy Net.
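The noise handling described in the list above can be sketched in a few lines. This is a minimal, framework-free illustration and not the code from this PR: NoisyLinear, Net, reset_noise, and the learn_step / collect_step / eval_step helpers are hypothetical stand-ins (one scalar weight per layer instead of a torch module). Only the names set_noise_mode and enable_noise come from the description, and their behavior here is an assumption.

```python
import random


class NoisyLinear:
    """Hypothetical toy noisy layer with a single scalar weight."""

    def __init__(self, mu: float = 1.0, sigma: float = 0.5) -> None:
        self.mu = mu                # mean of the weight
        self.sigma = sigma          # scale of the noise
        self.eps = 0.0              # currently sampled noise
        self.enable_noise = False   # toggled by set_noise_mode

    def reset_noise(self) -> None:
        # Draw a fresh noise sample; it is kept until the next reset.
        self.eps = random.gauss(0.0, 1.0)

    def forward(self, x: float) -> float:
        if self.enable_noise:
            weight = self.mu + self.sigma * self.eps
        else:
            weight = self.mu
        return weight * x


class Net:
    """Hypothetical network; learn, collect and eval all share this object."""

    def __init__(self) -> None:
        self.layers = [NoisyLinear(), NoisyLinear()]

    def forward(self, x: float) -> float:
        for layer in self.layers:
            x = layer.forward(x)
        return x


def set_noise_mode(net: Net, enable_noise: bool) -> None:
    # Assumed behavior: switch noise on or off for every noisy layer.
    for layer in net.layers:
        layer.enable_noise = enable_noise


def reset_noise(net: Net) -> None:
    for layer in net.layers:
        layer.reset_noise()


def learn_step(net: Net, x: float) -> float:
    set_noise_mode(net, True)
    reset_noise(net)            # new noise at every training step
    return net.forward(x)


def collect_step(net: Net, x: float) -> float:
    set_noise_mode(net, True)   # reuse the noise left by the last learn step
    return net.forward(x)


def eval_step(net: Net, x: float) -> float:
    set_noise_mode(net, False)  # evaluation uses the mean weights only
    return net.forward(x)
```

Because all three modes share one set of weights and one noise sample in this sketch, a collect step issued right after a learn step sees exactly the perturbation used in that learn step, while an eval step falls back to the mean weights.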

Related Issue

Check List

  • Merge the latest version of the source branch/repo and resolve all conflicts
  • Pass style check
  • Pass all tests

@puyuan1996 puyuan1996 added the bug label Apr 17, 2025
@@ -248,6 +248,8 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
.. note::
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
set_noise_mode(self._learn_model, True)
Member

Use noisy_net to control this line.

Another question: how should target_model be handled in Noisy Net?

@puyuan1996 puyuan1996 changed the title from "fix(pu): fix noise layer's usage" to "fix(pu): fix noise layer's usage based on the original paper" Jun 3, 2025
@puyuan1996 puyuan1996 mentioned this pull request Jun 3, 2025
@@ -201,6 +202,11 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
# ====================
self._learn_model.train()
self._target_model.train()

# Set noise mode for NoisyNet for exploration in learning if enabled in config
set_noise_mode(self._learn_model, True)
Member

Why not use self._cfg.noisy_net to control this logic?
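One possible reading of this suggestion, sketched under assumptions: the call is guarded by the config flag so that models without Noisy Net are left untouched. Config, Layer, and maybe_enable_noise are hypothetical names for illustration. Only noisy_net, set_noise_mode, and enable_noise appear in this PR, and the body of set_noise_mode below is a stand-in, not the merged implementation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Config:
    # Hypothetical stand-in for the policy config (self._cfg in the diff).
    noisy_net: bool = False


@dataclass
class Layer:
    # Hypothetical stand-in for a noisy layer carrying an enable_noise flag.
    enable_noise: bool = False


def set_noise_mode(layers: List[Layer], enable_noise: bool) -> None:
    # Stand-in for the helper added in this PR (assumed behavior).
    for layer in layers:
        layer.enable_noise = enable_noise


def maybe_enable_noise(cfg: Config, layers: List[Layer]) -> None:
    # Only touch the noise flags when Noisy Net is enabled in the config.
    if cfg.noisy_net:
        set_noise_mode(layers, True)
```

With noisy_net set to False the layers keep whatever noise setting they already had, so the guard adds no behavior for the plain DQN case.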

@PaParaZz1 PaParaZz1 merged commit cf72cc0 into main Jun 3, 2025
19 of 33 checks passed
@PaParaZz1 PaParaZz1 deleted the fix-noise branch June 3, 2025 14:36