The `ChatGLMForConditionalGeneration` class additionally needs to inherit from transformers' `GenerationMixin`:
```python
from typing import Any, Dict

from transformers import GenerationMixin
from transformers.utils import ModelOutput


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GenerationMixin):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length  # maximum sequence length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # use the ChatGLMModel class
        self.config = config

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        _, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs
        )
```
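As a side note, the extra base class is backward compatible: on transformers versions where `PreTrainedModel` still inherits `GenerationMixin`, Python's MRO deduplicates the repeated ancestor, so the same class definition runs on both old and new releases. A self-contained illustration with toy classes (not ChatGLM code):

```python
# Toy demonstration: re-listing an ancestor mixin in the bases is harmless.
class Mixin:
    def generate(self):
        return "ok"

class Base(Mixin):  # stands in for an old PreTrainedModel that already has the mixin
    pass

class Model(Base, Mixin):  # same shape as the ChatGLMForConditionalGeneration fix
    pass

assert Model.__mro__.count(Mixin) == 1  # the ancestor appears exactly once
print(Model().generate())  # prints "ok"
```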
The `_extract_past_from_model_output` method has changed as well: delete the `standardize_cache_format` argument from the call, since higher transformers versions no longer have this parameter, and the return value is now a `(cache_name, past_key_values)` tuple. Change

```python
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
    outputs, standardize_cache_format=standardize_cache_format
)
```

to

```python
_, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
    outputs
)
```

and it works.
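If the same remote-code file has to run on both old and new transformers, one option (my sketch, not from the original report) is to branch on the helper's signature inside `_update_model_kwargs_for_generation` instead of hard-coding either call:

```python
import inspect

# Compatibility sketch: older transformers accept standardize_cache_format and
# return past_key_values directly; newer ones drop the parameter and return a
# (cache_name, past_key_values) tuple. Branching on the signature handles both.
if "standardize_cache_format" in inspect.signature(
    self._extract_past_from_model_output
).parameters:
    model_kwargs["past_key_values"] = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format
    )
else:
    _, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
        outputs
    )
```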
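With both patches applied to modeling_chatglm.py, a quick smoke test confirms generation still works; the checkpoint name below is an assumption, so substitute whichever ChatGLM variant you are patching:

```python
# Hypothetical smoke test; "THUDM/chatglm3-6b" is an assumed checkpoint name.
from transformers import AutoModel, AutoTokenizer

path = "THUDM/chatglm3-6b"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(path, trust_remote_code=True).float().eval()

inputs = tokenizer("你好", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```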