diff --git a/sotopia_generate.py b/sotopia_generate.py index 2eabb9c..fd7e8c1 100644 --- a/sotopia_generate.py +++ b/sotopia_generate.py @@ -82,7 +82,6 @@ def generate_action( # return AgentAction(action_type="none", argument="") @cache -@spaces.GPU(600) def prepare_model(model_name): compute_type = torch.float16 @@ -150,7 +149,7 @@ def obtain_chain_hf( chain = LLMChain(llm=hf, prompt=chat_prompt_template) return chain - +@spaces.GPU def generate( model_name: str, template: str,