diff --git a/llms/providers/anthropic.py b/llms/providers/anthropic.py
index a05c007..e2d5581 100644
--- a/llms/providers/anthropic.py
+++ b/llms/providers/anthropic.py
@@ -59,11 +59,19 @@ def _prepare_model_inputs(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Optional[str] = None,
         stream: bool = False,
         **kwargs,
     ) -> Dict:
+        # System prompts are prepended as plain text before the first
+        # HUMAN_PROMPT turn, per the legacy Text Completions convention.
+        if system_message is None:
+            system_prompts = ""
+        else:
+            if self.model != "claude-2":
+                raise ValueError("System message only available for Claude-2 model")
+            system_prompts = f"{system_message.rstrip()}\n\n"
+
         formatted_prompt = (
-            f"{anthropic.HUMAN_PROMPT}{prompt}{anthropic.AI_PROMPT}{ai_prompt}"
+            f"{system_prompts}{anthropic.HUMAN_PROMPT}{prompt}{anthropic.AI_PROMPT}{ai_prompt}"
         )
 
         if history is not None:
@@ -108,7 +116,7 @@ def complete(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
-        system_message: str = None,
+        system_message: Optional[str] = None,
         **kwargs,
     ) -> Result:
         """
@@ -125,6 +133,7 @@ def complete(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             **kwargs,
         )
 
@@ -148,6 +157,7 @@ async def acomplete(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -163,6 +173,7 @@ async def acomplete(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             **kwargs,
         )
         with self.track_latency():
@@ -186,6 +197,7 @@ def complete_stream(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Optional[str] = None,
         **kwargs,
     ) -> StreamResult:
         """
@@ -201,6 +213,7 @@ def complete_stream(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             stream=True,
             **kwargs,
         )
@@ -224,6 +237,7 @@ async def acomplete_stream(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Optional[str] = None,
         **kwargs,
     ) -> AsyncStreamResult:
         """
@@ -239,6 +253,7 @@ async def acomplete_stream(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             stream=True,
             **kwargs,
         )