From 0054956a97ae521fb1bf91b64b9f133ba753fd7d Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 19 Dec 2023 13:02:05 +0100 Subject: [PATCH] fix --- server/lorax_server/models/causal_lm.py | 2 +- server/lorax_server/utils/sources/hub.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 5a3d226c8..59ad07f48 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -533,7 +533,7 @@ def decode(self, generated_ids: List[int]) -> str: ) def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + self, input_ids, attention_mask, position_ids, past_key_values: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index 799a2a171..156401749 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -11,6 +11,7 @@ from huggingface_hub.utils import ( LocalEntryNotFoundError, EntryNotFoundError, # Import here to ease try/except in other part of the lib + RevisionNotFoundError, ) from .source import BaseModelSource, try_to_load_from_cache