Homogeneize generation params #428
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Interestingly, using an explicit generation config seems to mess this up quite a lot.
Great! I only reviewed it partially. Some comments below.
Nice to have a coherent config, a few nits and good to go.
@@ -117,6 +122,7 @@ def __init__(
        self.precision = _get_dtype(config.dtype, config=self._config)

        self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
        self.sampling_params = SamplingParams(**config.sampling_params)
Where does the config.sampling_params come from? I do not see it in VLLMModelConfig.
After testing, it indeed does not run.
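For illustration only, a sketch of how the attribute could be made to exist (hypothetical; the field name and default are assumptions, not part of the reviewed diff):

```python
from dataclasses import dataclass, field

@dataclass
class VLLMModelConfig:
    # ... existing fields (dtype, etc.) omitted ...
    # Hypothetical: a plain mapping that SamplingParams(**config.sampling_params)
    # could unpack; not present in the PR as reviewed.
    sampling_params: dict = field(default_factory=dict)
```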
    def __post_init__(self):
        super()

        logger.warning(
-        logger.warning(
+        warnings.warn(
        super()

        logger.warning(
            "Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!"
"Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!" | |
"Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!", | |
FutureWarning, |
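Taken together, the two suggestions would make the deprecation notice look roughly like this (a sketch; it assumes the warnings module is imported in that file):

```python
import warnings

warnings.warn(
    "Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!",
    FutureWarning,
)
```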
@@ -124,6 +127,8 @@ class BaseModelConfig:
            model at a quantized precision. Needed for 4-bit and 8-bit precision.
        trust_remote_code (bool): Whether to trust remote code during model
            loading.
        generation_parameters (GenerationParameters): Range of parameters which will affect the generation.
        generation_config (GenerationConfig): GenerationConfig object (only passed during manual creation)
I don't understand why we have both generation parameters and a generation config?
@@ -1297,6 +1348,15 @@ def _loglikelihood_single_token(
        return dataset.get_original_order(res)


class BaseModel(TransformersModel):
    def __post_init__(self):
        super()
-        super()
+        super().__post_init__()
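For context on this nit, a minimal sketch of why the bare call is a no-op: super() only builds the proxy object and invokes nothing, so the parent's __post_init__ never runs unless it is called explicitly.

```python
from dataclasses import dataclass

@dataclass
class Parent:
    def __post_init__(self):
        print("parent __post_init__ ran")

@dataclass
class Child(Parent):
    def __post_init__(self):
        super()                  # no-op: builds a super proxy, calls nothing on it
        super().__post_init__()  # actually runs Parent.__post_init__

Child()  # prints once, from the second line only
```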
@@ -44,7 +44,7 @@ def accelerate(  # noqa C901
    model_args: Annotated[
        str,
        Argument(
-            help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/base_model.yaml)"
+            help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/transformers_model.yaml)"
Need to rename the file examples/model_configs/base_model.yaml to examples/model_configs/transformers_model.yaml.
if "generation" not in config_dict: | ||
return GenerationParameters() | ||
return GenerationParameters( | ||
early_stopping=config_dict["generation"].get("early_stopping", None), | ||
repetition_penalty=config_dict["generation"].get("repetition_penalty", None), | ||
frequency_penalty=config_dict["generation"].get("frequency_penalty", None), | ||
length_penalty=config_dict["generation"].get("length_penalty", None), | ||
presence_penalty=config_dict["generation"].get("presence_penalty", None), | ||
max_new_tokens=config_dict["generation"].get("max_new_tokens", None), | ||
min_new_tokens=config_dict["generation"].get("min_new_tokens", None), | ||
seed=config_dict["generation"].get("seed", None), | ||
stop_tokens=config_dict["generation"].get("stop_tokens", None), | ||
temperature=config_dict["generation"].get("temperature", None), | ||
top_k=config_dict["generation"].get("top_k", None), | ||
min_p=config_dict["generation"].get("min_p", None), | ||
top_p=config_dict["generation"].get("top_p", None), | ||
truncate_prompt=config_dict["generation"].get("truncate_prompt", None), | ||
) |
if "generation" not in config_dict: | |
return GenerationParameters() | |
return GenerationParameters( | |
early_stopping=config_dict["generation"].get("early_stopping", None), | |
repetition_penalty=config_dict["generation"].get("repetition_penalty", None), | |
frequency_penalty=config_dict["generation"].get("frequency_penalty", None), | |
length_penalty=config_dict["generation"].get("length_penalty", None), | |
presence_penalty=config_dict["generation"].get("presence_penalty", None), | |
max_new_tokens=config_dict["generation"].get("max_new_tokens", None), | |
min_new_tokens=config_dict["generation"].get("min_new_tokens", None), | |
seed=config_dict["generation"].get("seed", None), | |
stop_tokens=config_dict["generation"].get("stop_tokens", None), | |
temperature=config_dict["generation"].get("temperature", None), | |
top_k=config_dict["generation"].get("top_k", None), | |
min_p=config_dict["generation"].get("min_p", None), | |
top_p=config_dict["generation"].get("top_p", None), | |
truncate_prompt=config_dict["generation"].get("truncate_prompt", None), | |
) | |
return GenerationParameters(**config_dict.get("generation", {})) |
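A minimal sketch of how the suggested one-liner behaves, assuming GenerationParameters is imported and is a dataclass with the fields listed above:

```python
# Everything under "generation" is unpacked as keyword arguments;
# a missing "generation" key falls back to an empty dict, i.e. all defaults.
config_dict = {"generation": {"temperature": 0.7, "top_p": 0.9}}
params = GenerationParameters(**config_dict.get("generation", {}))
```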
        args = {
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "repetition_penalty": self.repetition_penalty,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "seed": self.seed,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "stop": self.stop_tokens,
            "max_tokens": self.max_new_tokens,
            "min_tokens": self.min_new_tokens,
            "truncate_prompt_tokens": self.truncate_prompt,
        }
        return {k: v for k, v in args.items() if v is not None}
Suggested change (replace the whole block above with):
+        return {k: v for k, v in asdict(self).items() if v is not None}
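Note that asdict() keeps the dataclass field names, while the explicit dict above renames several keys for the vLLM/OpenAI payload (stop_tokens to stop, max_new_tokens to max_tokens, truncate_prompt to truncate_prompt_tokens). A tiny sketch with a hypothetical two-field stand-in:

```python
from dataclasses import asdict, dataclass

# Hypothetical stand-in for GenerationParameters, only to show which keys asdict() emits.
@dataclass
class _Params:
    max_new_tokens: int | None = 100
    stop_tokens: list | None = None

print({k: v for k, v in asdict(_Params()).items() if v is not None})
# {'max_new_tokens': 100}  -> field names, not the renamed "max_tokens"/"stop" used above
```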
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
        with open(model_args, "r") as f:
            config = yaml.safe_load(f)["model"]
        generation_parameters = GenerationParameters.from_dict(config)
Maybe worth adding a from_path method?
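A sketch of what such a method could look like on GenerationParameters, assuming the same YAML layout as in the snippet above (a top-level model key whose generation section holds the parameters):

```python
import yaml

@classmethod
def from_path(cls, path: str) -> "GenerationParameters":
    """Sketch only: load generation parameters from a yaml model config file."""
    with open(path, "r") as f:
        config = yaml.safe_load(f)["model"]
    return cls.from_dict(config)
```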
        }
        return {k: v for k, v in args.items() if v is not None}

    def to_tgi_inferenceendpoint_dict(self) -> dict:
Maybe rename it to to_tgi_endpoint_dict? TGI already has the word "inference" in it.
It's because it's either TGI or inference endpoints. Maybe tgi_ie?
        args = {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "early_stopping": self.early_stopping or False,
"early_stopping": self.early_stopping or False, | |
"early_stopping": self.early_stopping, |
Why or False
? If None, it will not be passed to transformers GenerationConfig
and it defaults to False:
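A quick way to check that default (assuming a recent transformers release, where GenerationConfig exposes early_stopping):

```python
from transformers import GenerationConfig

# When the key is omitted entirely, transformers falls back to its own default.
print(GenerationConfig().early_stopping)  # False
```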
This PR does 3 things:

I would put system_prompt, fewshot_seeds, and use_chat_template in the GenerationParameters too, since logically they are generation parameters, but it can be another PR.

Closes #16
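A rough sketch of that follow-up idea, hypothetical and not part of this PR (the field names come from the comment above; the defaults are assumptions):

```python
from dataclasses import dataclass

@dataclass
class GenerationParameters:
    # ... existing sampling / penalty / length fields omitted ...
    system_prompt: str | None = None
    fewshot_seeds: list[int] | None = None
    use_chat_template: bool = False
```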