Skip to content

Commit

Permalink
refactor: validate_prompt_sequence()
Browse files Browse the repository at this point in the history
  • Loading branch information
lpm0073 committed Nov 29, 2023
1 parent 66876c0 commit 8ea9d3c
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions grader/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field

from .exceptions import IncorrectResponseValueError, ResponseFailedError
from .exceptions import (
IncorrectResponseValueError,
InvalidResponseStructureError,
ResponseFailedError,
)


class LCRequestMetaData(BaseModel):
Expand Down Expand Up @@ -41,13 +45,6 @@ class LCChatMemory(BaseModel):

messages: List[LCMessage]

@model_validator(mode="after")
def validate_messages(self) -> "LCChatMemory":
"""Validate that chat memory contains at least 2 dicts"""
if len(self.messages) < 2:
raise ValueError("messages must contain at least 2 objects")
return self


class LCBody(BaseModel):
"""LangChain body"""
Expand Down Expand Up @@ -82,6 +79,9 @@ def validate_is_base64_encoded(self):
def validate_prompt_sequence(self):
"""Validate that the prompt sequence begins with a human message and ends with an ai message"""
messages = self.body.chat_memory.messages
if len(messages) < 2:
raise InvalidResponseStructureError("messages must contain at least 2 objects")

prompt_1 = messages[0]
if prompt_1.type != MessageType.human:
raise IncorrectResponseValueError(
Expand Down

0 comments on commit 8ea9d3c

Please sign in to comment.