diff --git a/grader/langchain.py b/grader/langchain.py index 60b9256..846b557 100644 --- a/grader/langchain.py +++ b/grader/langchain.py @@ -4,9 +4,13 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field -from .exceptions import IncorrectResponseValueError, ResponseFailedError +from .exceptions import ( + IncorrectResponseValueError, + InvalidResponseStructureError, + ResponseFailedError, +) class LCRequestMetaData(BaseModel): @@ -41,13 +45,6 @@ class LCChatMemory(BaseModel): messages: List[LCMessage] - @model_validator(mode="after") - def validate_messages(self) -> "LCChatMemory": - """Validate that chat memory contains at least 2 dicts""" - if len(self.messages) < 2: - raise ValueError("messages must contain at least 2 objects") - return self - class LCBody(BaseModel): """LangChain body""" @@ -82,6 +79,9 @@ def validate_is_base64_encoded(self): def validate_prompt_sequence(self): """Validate that the prompt sequence begins with a human message and ends with an ai message""" messages = self.body.chat_memory.messages + if len(messages) < 2: + raise InvalidResponseStructureError("messages must contain at least 2 objects") + prompt_1 = messages[0] if prompt_1.type != MessageType.human: raise IncorrectResponseValueError(