Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
gsolard committed Jul 30, 2024
1 parent 3aff1e8 commit 2a7d233
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/benchmark_llm_serving/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict
Returns:
dict : The payload
"""
raise NotImplemented("The subclass should implement this method")
raise NotImplemented("The subclass should implement this method") # type: ignore

def get_newly_generated_text(self, json_chunk: str) -> str:
def get_newly_generated_text(self, json_chunk: dict) -> str:
"""Gets the newly generated text
Args:
Expand All @@ -34,7 +34,7 @@ def get_newly_generated_text(self, json_chunk: str) -> str:
Returns:
str : The newly generated text
"""
raise NotImplemented("The subclass should implement this method")
raise NotImplemented("The subclass should implement this method") # type: ignore

def test_chunk_validity(self, chunk: str) -> bool:
"""Tests if the chunk is valid or should not be considered.
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict
"stream_options": {"include_usage": True}
}

def get_newly_generated_text(self, json_chunk: str) -> str:
def get_newly_generated_text(self, json_chunk: dict) -> str:
"""Gets the newly generated text
Args:
Expand Down Expand Up @@ -173,7 +173,7 @@ def get_completions_headers(self) -> dict:
return {"Accept": "application/json",
"Content-Type": "application/json"}

def get_newly_generated_text(self, json_chunk: str) -> str:
def get_newly_generated_text(self, json_chunk: dict) -> str:
"""Gets the newly generated text
Args:
Expand All @@ -192,8 +192,9 @@ def get_newly_generated_text(self, json_chunk: str) -> str:
def get_backend(backend_name: str) -> BackEnd:
implemented_backends = ["mistral", "happy_vllm"]
if backend_name not in implemented_backends:
raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}")
raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}")
if backend_name == "happy_vllm":
return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
if backend_name == "mistral":
return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
return BackEnd("not_implemented")

0 comments on commit 2a7d233

Please sign in to comment.