
Commit

lint
PeterSH6 committed Dec 11, 2024
1 parent c9b5086 commit 00540ef
Showing 3 changed files with 15 additions and 11 deletions.
1 change: 1 addition & 0 deletions examples/split_placement/main_ppo_split.py
@@ -88,6 +88,7 @@ def __call__(self, data: DataProto):
 import hydra
 from split_monkey_patch import fit
 
+
 @hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None)
 def main(config):
     if not ray.is_initialized():
20 changes: 10 additions & 10 deletions examples/split_placement/split_monkey_patch.py
@@ -32,9 +32,9 @@ def fit(self):
     from omegaconf import OmegaConf
 
     logger = Tracking(project_name=self.config.trainer.project_name,
-                      experiment_name=self.config.trainer.experiment_name,
-                      default_backend=self.config.trainer.logger,
-                      config=OmegaConf.to_container(self.config, resolve=True))
+                      experiment_name=self.config.trainer.experiment_name,
+                      default_backend=self.config.trainer.logger,
+                      config=OmegaConf.to_container(self.config, resolve=True))
 
     global_steps = 0
 
@@ -89,15 +89,15 @@ def fit(self):
 
             # compute rewards. apply_kl_penalty if available
             batch, kl_metrics = apply_kl_penalty(batch,
-                                                 kl_ctrl=self.kl_ctrl,
-                                                 kl_penalty=self.config.algorithm.kl_penalty)
+                                                 kl_ctrl=self.kl_ctrl,
+                                                 kl_penalty=self.config.algorithm.kl_penalty)
             metrics.update(kl_metrics)
 
             # compute advantages, executed on the driver process
             batch = compute_advantage(batch,
-                                      self.config.algorithm.gamma,
-                                      self.config.algorithm.lam,
-                                      adv_estimator=self.config.algorithm.adv_estimator)
+                                      self.config.algorithm.gamma,
+                                      self.config.algorithm.lam,
+                                      adv_estimator=self.config.algorithm.adv_estimator)
             metrics['timing/adv'] = timer.last
 
             # update critic
@@ -112,7 +112,7 @@ def fit(self):
             with Timer(name='update_actor_call', logger=None) as timer:
                 actor_output = self.actor_rollout_wg.update_actor(batch)
             metrics['timing/update_actor_call'] = timer.last
-
+
             # NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class
             with Timer(name='update_actor_critic', logger=None) as timer:
                 # NOTE: get the DataProtoFuture
@@ -149,7 +149,7 @@ def fit(self):
 
                 if self.use_critic:
                     critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
-                                                     f'global_step_{global_steps}')
+                                                     f'global_step_{global_steps}')
                     critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
                     self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)
 
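The NOTE about blocking=False in the hunk above refers to issuing the actor and critic updates asynchronously so they overlap on their separate resource pools. A minimal sketch of that pattern (not part of this commit), assuming update_actor and update_critic are registered with blocking=False and therefore return DataProtoFuture objects, and that calling .get() on such a future blocks until the result is ready:

from codetiming import Timer

def update_actor_and_critic(actor_rollout_wg, critic_wg, batch, metrics):
    # Both calls return DataProtoFuture objects immediately because the
    # worker methods are registered with blocking=False.
    with Timer(name='update_actor_critic', logger=None) as timer:
        critic_future = critic_wg.update_critic(batch)
        actor_future = actor_rollout_wg.update_actor(batch)
        # Resolve the futures only when the outputs are needed, so both
        # updates run concurrently on their separate resource pools.
        critic_output = critic_future.get()
        actor_output = actor_future.get()
    metrics['timing/update_actor_critic'] = timer.last
    return actor_output, critic_output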
5 changes: 4 additions & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -63,7 +63,10 @@ def create_resource_pool(self):
             # Due to the Ray issue, we can only support max_colocate_count=1 for now.
             # This means that each GPU can only have one process.
             # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385
-            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name)
+            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
+                                            use_gpu=True,
+                                            max_colocate_count=1,
+                                            name_prefix=resource_pool_name)
             self.resource_pool_dict[resource_pool_name] = resource_pool
 
     def get_resource_pool(self, role: Role) -> RayResourcePool:
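For context, a minimal sketch (not part of this commit) of how the reformatted RayResourcePool construction is typically driven: one pool per named entry in a spec, looked up per role. Only the RayResourcePool keyword arguments, resource_pool_dict, and the get_resource_pool signature appear in the diff; the ResourcePoolManager name, the resource_pool_spec and mapping structure, and the import path are assumptions for illustration.

from verl.single_controller.ray import RayResourcePool  # import path is an assumption

class ResourcePoolManager:
    def __init__(self, resource_pool_spec, mapping):
        # e.g. {'actor_rollout_pool': [8]} -> a pool spanning one node with 8 processes
        self.resource_pool_spec = resource_pool_spec
        self.mapping = mapping  # Role -> pool name (assumed structure)
        self.resource_pool_dict = {}

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # One process per GPU until Ray supports max_colocate_count > 1.
            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
                                            use_gpu=True,
                                            max_colocate_count=1,
                                            name_prefix=resource_pool_name)
            self.resource_pool_dict[resource_pool_name] = resource_pool

    def get_resource_pool(self, role):
        return self.resource_pool_dict[self.mapping[role]]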
