From 66876c01ff04a7916d16a18fa8fee790f2b1fd96 Mon Sep 17 00:00:00 2001 From: lpm0073 Date: Tue, 28 Nov 2023 20:43:09 -0600 Subject: [PATCH 1/2] code langchain class validations --- grader/grader.py | 28 ++++++++++++++++-------- grader/langchain.py | 53 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/grader/grader.py b/grader/grader.py index 0f97284..98ef549 100644 --- a/grader/grader.py +++ b/grader/grader.py @@ -223,19 +223,29 @@ def grade(self): be processed with the legacy code below. """ try: - LCResponse(**self.assignment) + lc_response = LCResponse(**self.assignment) + lc_response.validate_response() return self.grade_response() except (ValidationError, TypeError): - pass + # FIX NOTE: need to map these to an applicable exception from grader.exceptions. + print("warning: assignment failed an un-mapped pydantic validation.") + except ( + ResponseFailedError, + InvalidResponseStructureError, + ResponseFailedError, + IncorrectResponseValueError, + IncorrectResponseTypeError, + ) as e: + return self.grade_response(e) try: self.validate() - except InvalidResponseStructureError as e: - return self.grade_response(e) - except ResponseFailedError as e: - return self.grade_response(e) - except IncorrectResponseValueError as e: - return self.grade_response(e) - except IncorrectResponseTypeError as e: + except ( + ResponseFailedError, + InvalidResponseStructureError, + ResponseFailedError, + IncorrectResponseValueError, + IncorrectResponseTypeError, + ) as e: return self.grade_response(e) return self.grade_response() diff --git a/grader/langchain.py b/grader/langchain.py index de173d3..60b9256 100644 --- a/grader/langchain.py +++ b/grader/langchain.py @@ -4,7 +4,9 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, model_validator + +from .exceptions import IncorrectResponseValueError, ResponseFailedError class LCRequestMetaData(BaseModel): @@ -67,10 +69,47 @@ class LCResponse(BaseModel): status_code: int = Field(..., alias="statusCode", ge=200, le=599) body: LCBody - @field_validator("status_code") - @classmethod - def status_code_is_valid(cls, status_code): + def validate_status_code(self): """Validate that the status_code == 200""" - if status_code != 200: - raise ValueError(f"status_code must be 200, got {status_code}") - return status_code + if self.status_code != 200: + raise ResponseFailedError(f"status_code must be 200. received: {self.status_code}") + + def validate_is_base64_encoded(self): + """Validate that is_base64_encoded is False""" + if self.is_base64_encoded: + raise IncorrectResponseValueError("is_base64_encoded must be False") + + def validate_prompt_sequence(self): + """Validate that the prompt sequence begins with a human message and ends with an ai message""" + messages = self.body.chat_memory.messages + prompt_1 = messages[0] + if prompt_1.type != MessageType.human: + raise IncorrectResponseValueError( + f"First message in prompt sequence must be of type {MessageType.human}. received: {prompt_1.type}" + ) + prompt_2 = messages[1] + if prompt_2.type != MessageType.ai: + raise IncorrectResponseValueError( + f"Second message in prompt sequence must be of type {MessageType.ai}. received: {prompt_2.type}" + ) + + def validate_request_meta_data(self): + """Validate the request_meta_data""" + request_meta_data = self.body.request_meta_data + if request_meta_data.lambda_name != "lambda_langchain": + raise IncorrectResponseValueError( + f"lambda_name must be langchain. received: {request_meta_data.lambda_name}" + ) + if not request_meta_data.model.startswith("gpt-3.5"): + raise IncorrectResponseValueError(f"model must be gpt-3.5. received: {request_meta_data.model}") + if not request_meta_data.end_point == "ChatCompletion": + raise IncorrectResponseValueError( + f"end_point must be ChatCompletion. received: {request_meta_data.end_point}" + ) + + def validate_response(self): + """Validate the response""" + self.validate_status_code() + self.validate_is_base64_encoded() + self.validate_prompt_sequence() + self.validate_request_meta_data() From 8ea9d3c6b4026b5967c043fe3250da25e08f0a58 Mon Sep 17 00:00:00 2001 From: lpm0073 Date: Tue, 28 Nov 2023 20:50:24 -0600 Subject: [PATCH 2/2] refactor: validate_prompt_sequence() --- grader/langchain.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/grader/langchain.py b/grader/langchain.py index 60b9256..846b557 100644 --- a/grader/langchain.py +++ b/grader/langchain.py @@ -4,9 +4,13 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field -from .exceptions import IncorrectResponseValueError, ResponseFailedError +from .exceptions import ( + IncorrectResponseValueError, + InvalidResponseStructureError, + ResponseFailedError, +) class LCRequestMetaData(BaseModel): @@ -41,13 +45,6 @@ class LCChatMemory(BaseModel): messages: List[LCMessage] - @model_validator(mode="after") - def validate_messages(self) -> "LCChatMemory": - """Validate that chat memory contains at least 2 dicts""" - if len(self.messages) < 2: - raise ValueError("messages must contain at least 2 objects") - return self - class LCBody(BaseModel): """LangChain body""" @@ -82,6 +79,9 @@ def validate_is_base64_encoded(self): def validate_prompt_sequence(self): """Validate that the prompt sequence begins with a human message and ends with an ai message""" messages = self.body.chat_memory.messages + if len(messages) < 2: + raise InvalidResponseStructureError("messages must contain at least 2 objects") + prompt_1 = messages[0] if prompt_1.type != MessageType.human: raise IncorrectResponseValueError(