feature(pu): add seller env, self-judge pipeline and mcts/alphazero config #276

Open: wants to merge 27 commits into base: main

Commits (27)
a64f1f7
feature(pu): add seller_env and its mcts bot and alphazero pipeline
PaParaZz1 Aug 20, 2024
0bbb54d
polish(pu): delete demo_bkp
PaParaZz1 Aug 20, 2024
68bb482
polish(pu): polish seller_env
dyyoungg Aug 21, 2024
d296a94
polish(pu): polish replay csv operations
dyyoungg Aug 21, 2024
5ca61f9
polish(pu): polish seller_env and mcts bot
dyyoungg Aug 23, 2024
98164ed
fix(pu): fix mcts bot
dyyoungg Aug 25, 2024
d31c09c
fix(pu): fix seller env
dyyoungg Aug 26, 2024
5ee961d
polish(pu): polish seller env and baselines
dyyoungg Aug 26, 2024
9c02e78
polish(pu): fix order of seed() and reset() in seller env
dyyoungg Aug 27, 2024
353eede
fix(pu): fix next_obs str bug
dyyoungg Aug 28, 2024
77e7d5b
fix(pu): fix alphazero eval_return bug and polish seller env
dyyoungg Sep 4, 2024
19ed1d3
polish(pu): polish seller_env
dyyoungg Sep 4, 2024
7908517
polish(pu): use cache simulate_env in az tree
dyyoungg Sep 4, 2024
7e8e1be
polish(pu): use api_pool
dyyoungg Sep 5, 2024
912663b
polish(pu): polish prompt_template
dyyoungg Sep 6, 2024
e184956
fix(pu): fix seller_env seed config
dyyoungg Sep 9, 2024
2c77c31
fix(pu): fix seller_env seed config
dyyoungg Sep 9, 2024
0c9b051
polish(pu): use BAAI/bge-m3 as upper lm
dyyoungg Sep 9, 2024
452d87d
feature(pu): add dynamic_action_space option
dyyoungg Sep 10, 2024
d903cda
polish(pu): add retry in Commander
dyyoungg Sep 10, 2024
7c031d2
fix(pu): fix seed() in seller_env
dyyoungg Sep 10, 2024
76efee1
polish(pu): add lmdeploy agent option
dyyoungg Sep 11, 2024
3c89206
polish(pu): polish configs
dyyoungg Sep 12, 2024
4fb10d4
refactor(pu): polish seller env related mcts/az/configs
dyyoungg Sep 19, 2024
57c78bb
polish(pu): polish seller env configs
dyyoungg Sep 19, 2024
a46b153
Merge branch 'main' of https://github.com/opendilab/LightZero into po…
dyyoungg Sep 19, 2024
b18f755
feature(whl): add seller_env readme
puyuan1996 Oct 18, 2024
15 changes: 15 additions & 0 deletions lzero/entry/train_alphazero.py
@@ -63,6 +63,7 @@ def train_alphazero(
# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
print('load model from: %s' % model_path)

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
@@ -87,6 +88,14 @@
exp_name=cfg.exp_name,
)

# TODO: for debug
# stop, reward = evaluator.eval(
# learner.save_checkpoint,
# learner.train_iter,
# collector.envstep,
# )
# import sys; sys.exit(0)

# ==============================================================
# Main loop
# ==============================================================
@@ -118,6 +127,12 @@ def train_alphazero(
# Collect data by default config n_sample/n_episode
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
new_data = sum(new_data, [])

if policy_config.simulation_env_id == 'seller':
for i in range(len(new_data)):
new_data[i]['obs']['observation'] = str(new_data[i]['obs']['observation'])
new_data[i]['next_obs']['observation'] = str(new_data[i]['next_obs']['observation'])

if cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
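The seller-specific branch above converts each transition's observation dict into a plain string before it reaches the replay buffer. Below is a minimal sketch of that conversion; the transition layout is hypothetical, inferred from the diff rather than taken from the env code:

    # Hypothetical seller-env transition: the observation is a dict (e.g. dialogue
    # state), which is serialized to a string so the language-model encoder and the
    # replay buffer can handle it uniformly.
    transition = {
        'obs': {'observation': {'round': 1, 'history': ['buyer: how much is it?']}},
        'next_obs': {'observation': {'round': 2, 'history': ['buyer: how much is it?', 'seller: 99 dollars']}},
    }
    transition['obs']['observation'] = str(transition['obs']['observation'])
    transition['next_obs']['observation'] = str(transition['next_obs']['observation'])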
392 changes: 392 additions & 0 deletions lzero/mcts/ptree/ptree_az_seller.py

Large diffs are not rendered by default.
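Since the 392-line python tree is not rendered here, the following is only a generic sketch of the PUCT child-selection rule that AlphaZero-style search trees implement; the function and node attribute names are illustrative and are not taken from ptree_az_seller.py:

    import math

    def select_child(node, c_puct: float = 1.25):
        # Assumed node interface (illustrative only): ``node.children`` maps action -> child,
        # and each child exposes ``visit_count``, ``value_sum`` and ``prior``.
        total_visits = sum(child.visit_count for child in node.children.values())
        best_score, best_action, best_child = -float('inf'), None, None
        for action, child in node.children.items():
            # Mean value of the child so far (0 if unvisited).
            q = child.value_sum / child.visit_count if child.visit_count > 0 else 0.0
            # Exploration bonus weighted by the policy prior.
            u = c_puct * child.prior * math.sqrt(total_visits) / (1 + child.visit_count)
            if q + u > best_score:
                best_score, best_action, best_child = q + u, action, child
        return best_action, best_child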

150 changes: 150 additions & 0 deletions lzero/model/alphazero_model_language.py
@@ -0,0 +1,150 @@
"""
Overview:
    Model templates for language-based AlphaZero: ``AlphaZeroModel`` encodes string observations with BGE-M3, while ``AlphaZeroModelBert`` uses a BERT encoder.
    Users can refer to the unittest of these model templates to learn how to use them.
"""
from typing import Optional, Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import MODEL_REGISTRY
from FlagEmbedding import BGEM3FlagModel

try:
from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
import sys
from ditk import logging
    logging.warning("transformers is not installed, please install it via: pip install transformers")
sys.exit(1)


@MODEL_REGISTRY.register('AlphaZeroModel')
class AlphaZeroModel(nn.Module):

def __init__(
self,
add_linear: bool = True,
embedding_size: int = 1024,
action_space_size: int = 9
) -> None:
super().__init__()
self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

if add_linear:
            # Add a small, trainable linear projection on top of the BGE-M3 encoder, tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(1024, embedding_size)
        else:
            self.linear = None
            self.embedding_size = 1024  # fall back to the raw BGE-M3 dense embedding dimension

        value_support_size = 1  # TODO
self.value_head = nn.Linear(self.embedding_size, value_support_size)
self.policy_head = nn.Linear(self.embedding_size, action_space_size)

def _calc_embedding(self, x: list) -> torch.Tensor:
sentence_embedding = self.model.encode(x, batch_size=32, max_length=8192, )['dense_vecs'] # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
sentence_embedding = torch.from_numpy(sentence_embedding).to(self.value_head.weight.device).float()

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

    def forward(self, train_samples: list, candidate_samples: list = None) -> Tuple[torch.Tensor, torch.Tensor]:
state_embedding = self._calc_embedding(train_samples)

        policy_logits = self.policy_head(state_embedding)  # len(input_list) x action_space_size

        if self.value_head:
            value = self.value_head(state_embedding)  # len(input_list) x value_support_size

return policy_logits, value

def compute_policy_value(self, train_samples: list, candidate_samples: list = None) -> Tuple[torch.Tensor, torch.Tensor]:
logit, value = self.forward(train_samples, candidate_samples)
prob = torch.nn.functional.softmax(logit, dim=-1)
return prob, value

def compute_logp_value(self, train_samples: list, candidate_samples: list = None) -> Tuple[torch.Tensor, torch.Tensor]:
logit, value = self.forward(train_samples, candidate_samples)
# use log_softmax to calculate log probability
log_prob = F.log_softmax(logit, dim=-1)
return log_prob, value



@MODEL_REGISTRY.register('AlphaZeroModelBert')
class AlphaZeroModelBert(nn.Module):

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = True,
embedding_size: int = 128,
freeze_encoder: bool = True,
action_space_size: int = 9
) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
for param in self.model.parameters():
param.requires_grad = False

if add_linear:
# Add an additional small, adjustable linear layer on top of BERT tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
        else:
            self.linear = None
            self.embedding_size = self.model.config.hidden_size  # fall back to the encoder hidden size (768 for bert-base-uncased)

        value_support_size = 1  # TODO
        self.value_head = nn.Linear(self.embedding_size, value_support_size)
        self.policy_head = nn.Linear(self.embedding_size, action_space_size)

def _calc_embedding(self, x: list) -> torch.Tensor:
# ``truncation=True`` means that if the length of the prompt exceed the ``max_length`` of the tokenizer,
# the exceeded part will be truncated. ``padding=True`` means that if the length of the prompt does not reach
# the ``max_length``, the latter part will be padded. These settings ensure the length of encoded tokens is
# exactly ``max_length``, which can enable batch-wise computing.
input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
output = self.model(**input, output_hidden_states=True)
# Get last layer hidden states
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

    def forward(self, train_samples: list, candidate_samples: list = None) -> Tuple[torch.Tensor, torch.Tensor]:
state_embedding = self._calc_embedding(train_samples)
        policy_logits = self.policy_head(state_embedding)  # len(input_list) x action_space_size

        if self.value_head:
            value = self.value_head(state_embedding)  # len(input_list) x value_support_size

return policy_logits, value

def compute_policy_value(self, train_samples: list, candidate_samples: list = None) -> Tuple[torch.Tensor, torch.Tensor]:
logit, value = self.forward(train_samples, candidate_samples)
prob = torch.nn.functional.softmax(logit, dim=-1)
return prob, value

def compute_logp_value(self, train_samples: list, candidate_samples: list = None) -> Tuple[torch.Tensor, torch.Tensor]:
logit, value = self.forward(train_samples, candidate_samples)
# use log_softmax to calculate log probability
log_prob = F.log_softmax(logit, dim=-1)
return log_prob, value
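
As a usage sketch (not part of this PR's diff), the language-based models above can be queried directly with a batch of observation strings; this assumes the BAAI/bge-m3 weights can be downloaded and keeps the constructor's default 9-dimensional action space:

    # Hypothetical usage of the BGE-M3 based model defined in alphazero_model_language.py.
    from lzero.model.alphazero_model_language import AlphaZeroModel

    model = AlphaZeroModel(add_linear=True, embedding_size=1024, action_space_size=9)
    obs_batch = [
        "round 1: the buyer asks about the price",
        "round 2: the seller offers a small discount",
    ]
    prob, value = model.compute_policy_value(obs_batch)
    print(prob.shape)   # torch.Size([2, 9]): one action distribution per observation
    print(value.shape)  # torch.Size([2, 1]): one scalar value estimate per observation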



