diff --git a/gemma/gm/text/_chat_sampler.py b/gemma/gm/text/_chat_sampler.py index bed39b2..a658d61 100644 --- a/gemma/gm/text/_chat_sampler.py +++ b/gemma/gm/text/_chat_sampler.py @@ -119,6 +119,11 @@ def sampler(self) -> _sampler.Sampler: forbidden_tokens=self.forbidden_tokens, cache_length=self.cache_length, ) + + def resize_cache(self, new_cache_length:int): + object.__setattr__(self, 'cache_length', new_cache_length) + #reinitialization of the sampler + object.__setattr__(self, 'sampler', self.sampler) def chat( self,