Skip to content

Commit

Permalink
Merge pull request #32 from lpm0073/next
Browse files Browse the repository at this point in the history
refactor: validate_prompt_sequence()
  • Loading branch information
lpm0073 authored Nov 29, 2023
2 parents 630bdd6 + 8ea9d3c commit f56b850
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 23 deletions.
28 changes: 19 additions & 9 deletions grader/grader.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,29 @@ def grade(self):
be processed with the legacy code below.
"""
try:
LCResponse(**self.assignment)
lc_response = LCResponse(**self.assignment)
lc_response.validate_response()
return self.grade_response()
except (ValidationError, TypeError):
pass
# FIX NOTE: need to map these to an applicable exception from grader.exceptions.
print("warning: assignment failed an un-mapped pydantic validation.")
except (
    ResponseFailedError,
    InvalidResponseStructureError,
    IncorrectResponseValueError,
    IncorrectResponseTypeError,
) as e:
return self.grade_response(e)

try:
self.validate()
except InvalidResponseStructureError as e:
return self.grade_response(e)
except ResponseFailedError as e:
return self.grade_response(e)
except IncorrectResponseValueError as e:
return self.grade_response(e)
except IncorrectResponseTypeError as e:
except (
    ResponseFailedError,
    InvalidResponseStructureError,
    IncorrectResponseValueError,
    IncorrectResponseTypeError,
) as e:
return self.grade_response(e)
return self.grade_response()
67 changes: 53 additions & 14 deletions grader/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field

from .exceptions import (
IncorrectResponseValueError,
InvalidResponseStructureError,
ResponseFailedError,
)


class LCRequestMetaData(BaseModel):
Expand Down Expand Up @@ -39,13 +45,6 @@ class LCChatMemory(BaseModel):

messages: List[LCMessage]

@model_validator(mode="after")
def validate_messages(self) -> "LCChatMemory":
"""Validate that chat memory contains at least 2 dicts"""
if len(self.messages) < 2:
raise ValueError("messages must contain at least 2 objects")
return self


class LCBody(BaseModel):
"""LangChain body"""
Expand All @@ -67,10 +66,50 @@ class LCResponse(BaseModel):
status_code: int = Field(..., alias="statusCode", ge=200, le=599)
body: LCBody

def validate_status_code(self):
    """Validate that the status_code == 200.

    Raises:
        ResponseFailedError: if the lambda returned any non-200 HTTP status.
    """
    # Replaces the former pydantic @field_validator: validation is now an
    # explicit method so failures raise grader exceptions, not ValidationError.
    if self.status_code != 200:
        raise ResponseFailedError(f"status_code must be 200. received: {self.status_code}")

def validate_is_base64_encoded(self):
    """Confirm the response payload is plain text, not base64-encoded."""
    # Guard-clause form: the common (valid) case exits immediately.
    if not self.is_base64_encoded:
        return
    raise IncorrectResponseValueError("is_base64_encoded must be False")

def validate_prompt_sequence(self):
    """Validate that the prompt sequence begins with a human message and ends with an ai message."""
    messages = self.body.chat_memory.messages
    if len(messages) < 2:
        raise InvalidResponseStructureError("messages must contain at least 2 objects")

    # Table-driven check: position 0 must be human, position 1 must be ai.
    expectations = (("First", MessageType.human), ("Second", MessageType.ai))
    for (ordinal, expected_type), message in zip(expectations, messages):
        if message.type != expected_type:
            raise IncorrectResponseValueError(
                f"{ordinal} message in prompt sequence must be of type {expected_type}. received: {message.type}"
            )

def validate_request_meta_data(self):
    """Validate the request_meta_data: lambda name, model family, and endpoint.

    Raises:
        IncorrectResponseValueError: if any of the three fields has an
            unexpected value.
    """
    request_meta_data = self.body.request_meta_data
    if request_meta_data.lambda_name != "lambda_langchain":
        # BUG FIX: the message previously said the expected value was
        # "langchain", but the check compares against "lambda_langchain".
        raise IncorrectResponseValueError(
            f"lambda_name must be lambda_langchain. received: {request_meta_data.lambda_name}"
        )
    if not request_meta_data.model.startswith("gpt-3.5"):
        # NOTE(review): prefix match — e.g. "gpt-3.5-turbo" is accepted.
        raise IncorrectResponseValueError(f"model must be gpt-3.5. received: {request_meta_data.model}")
    if request_meta_data.end_point != "ChatCompletion":  # idiomatic `!=` instead of `not ... ==`
        raise IncorrectResponseValueError(
            f"end_point must be ChatCompletion. received: {request_meta_data.end_point}"
        )

def validate_response(self):
    """Run every validation check over this LangChain response, in order."""
    checks = (
        self.validate_status_code,
        self.validate_is_base64_encoded,
        self.validate_prompt_sequence,
        self.validate_request_meta_data,
    )
    for check in checks:
        check()

0 comments on commit f56b850

Please sign in to comment.