
Commit

Add system_message for Claude2
bkiat1123 committed Nov 24, 2023
1 parent b2a856d commit 50d4fbb
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions llms/providers/anthropic.py
@@ -59,11 +59,19 @@ def _prepare_model_inputs(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Union[str, None] = None,
         stream: bool = False,
         **kwargs,
     ) -> Dict:
+        if system_message is None:
+            system_prompts = ""
+        else:
+            if self.model != "claude-2":
+                raise ValueError("System message only available for Claude-2 model")
+            system_prompts = f"{system_message.rstrip()}\n\n"
+
         formatted_prompt = (
-            f"{anthropic.HUMAN_PROMPT}{prompt}{anthropic.AI_PROMPT}{ai_prompt}"
+            f"{system_prompts}{anthropic.HUMAN_PROMPT}{prompt}{anthropic.AI_PROMPT}{ai_prompt}"
         )
 
         if history is not None:
@@ -108,7 +116,7 @@ def complete(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
-        system_message: str = None,
+        system_message: Union[str, None] = None,
         **kwargs,
     ) -> Result:
         """
@@ -125,6 +133,7 @@ def complete(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             **kwargs,
         )
 
@@ -148,6 +157,7 @@ async def acomplete(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Union[str, None] = None,
         **kwargs,
     ):
         """
@@ -163,6 +173,7 @@ async def acomplete(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             **kwargs,
         )
         with self.track_latency():
@@ -186,6 +197,7 @@ def complete_stream(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Union[str, None] = None,
         **kwargs,
     ) -> StreamResult:
         """
@@ -201,6 +213,7 @@ def complete_stream(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             stream=True,
             **kwargs,
         )
@@ -224,6 +237,7 @@ async def acomplete_stream(
         max_tokens: int = 300,
         stop_sequences: Optional[List[str]] = None,
         ai_prompt: str = "",
+        system_message: Union[str, None] = None,
         **kwargs,
     ) -> AsyncStreamResult:
         """
@@ -239,6 +253,7 @@ async def acomplete_stream(
             max_tokens=max_tokens,
             stop_sequences=stop_sequences,
             ai_prompt=ai_prompt,
+            system_message=system_message,
             stream=True,
             **kwargs,
         )
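For context, here is a minimal standalone sketch of the prompt assembly this commit introduces. It assumes the turn markers from the legacy anthropic SDK (HUMAN_PROMPT / AI_PROMPT, i.e. "\n\nHuman:" and "\n\nAssistant:"); the helper name prepare_prompt and the inlined constants are illustrative stand-ins, not part of the provider class.

# Minimal sketch of the system_message handling added in this commit.
# HUMAN_PROMPT / AI_PROMPT values are assumed to match the legacy anthropic SDK;
# prepare_prompt is a hypothetical standalone stand-in for _prepare_model_inputs.
from typing import Optional

HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"


def prepare_prompt(prompt: str, model: str,
                   system_message: Optional[str] = None,
                   ai_prompt: str = "") -> str:
    # Prepend the optional system message, then wrap the prompt in the
    # Human/Assistant turn markers, mirroring the change to _prepare_model_inputs.
    if system_message is None:
        system_prompts = ""
    else:
        if model != "claude-2":
            raise ValueError("System message only available for Claude-2 model")
        system_prompts = f"{system_message.rstrip()}\n\n"
    return f"{system_prompts}{HUMAN_PROMPT}{prompt}{AI_PROMPT}{ai_prompt}"


# Example: the system message lands before the first Human turn.
print(prepare_prompt("What is 2 + 2?", "claude-2",
                     system_message="You are a terse math tutor."))

The design point is that the Claude 2 text-completion API has no separate system field: a system prompt is plain text placed before the first "\n\nHuman:" turn, which is exactly what the new system_prompts prefix does.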
