Skip to content

Commit

Permalink
Merge pull request #33 from lpm0073/next
Browse files Browse the repository at this point in the history
do all grading validations with pydantic
  • Loading branch information
lpm0073 authored Nov 29, 2023
2 parents f56b850 + 10ddab5 commit 7288ac5
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 214 deletions.
2 changes: 1 addition & 1 deletion grader/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# -*- coding: utf-8 -*-
__version__ = "1.4.0"
__version__ = "1.5.0"
8 changes: 2 additions & 6 deletions grader/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,8 @@ def main(filepath: str = None, output_folder: str = "out", potential_points: int
assignments = glob.glob(os.path.join(filepath, "*.json"))
for assignment_filename in assignments:
with open(assignment_filename, "r", encoding="utf-8") as f:
try:
assignment = json.load(f)
except json.JSONDecodeError:
print(f"warning: invalid JSON in assignment_filename: {assignment_filename}")
assignment = f.read()
grader = AutomatedGrader(assignment, potential_points=potential_points)
assignment = f.read()
grader = AutomatedGrader(assignment=assignment, potential_points=potential_points)
grade = grader.grade()
with open(
os.path.join(OUTPUT_FILE_PATH, f"{os.path.basename(assignment_filename)}"), "w", encoding="utf-8"
Expand Down
9 changes: 9 additions & 0 deletions grader/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ class ResponseFailedError(AGException):
def __init__(self, message):
self.message = message
super().__init__(self.message, penalty_pct=AGRubric.RESPONSE_FAILED_PENALTY_PCT)


VALID_MESSAGE_TYPES = [
"Success",
IncorrectResponseTypeError.__name__,
IncorrectResponseValueError.__name__,
InvalidResponseStructureError.__name__,
ResponseFailedError.__name__,
]
206 changes: 27 additions & 179 deletions grader/grader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
"""Provide a class for grading a submission against an assignment.""" ""

import json
import os

from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator

from .exceptions import (
VALID_MESSAGE_TYPES,
AGException,
IncorrectResponseTypeError,
IncorrectResponseValueError,
Expand All @@ -16,22 +16,6 @@
from .langchain import LCResponse


VALID_MESSAGE_TYPES = [
"Success",
IncorrectResponseTypeError.__name__,
IncorrectResponseValueError.__name__,
InvalidResponseStructureError.__name__,
ResponseFailedError.__name__,
]

HERE = os.path.abspath(os.path.dirname(__file__))
REQUIRED_KEYS_SPEC = "required-keys.json"
REQUIRED_KEYS_PATH = os.path.join(HERE, "data", REQUIRED_KEYS_SPEC)

HUMAN_PROMPT = {"content": "a prompt from a human", "additional_kwargs": {}, "type": "human", "example": False}
AI_RESPONSE = {"content": "a response from the AI", "additional_kwargs": {}, "type": "ai", "example": False}


class Grade(BaseModel):
"""
This is the base class for all Grader types. It provides the common interface and
Expand Down Expand Up @@ -63,149 +47,11 @@ def message_type_is_valid(cls, message_type):


# flake8: noqa: E701
class AutomatedGrader:
class AutomatedGrader(BaseModel):
"""Grade a submission against an assignment."""

def __init__(self, assignment, potential_points=100):
self._assignment = assignment
self._potential_points = potential_points
with open(REQUIRED_KEYS_PATH, "r", encoding="utf-8") as f: # pylint: disable=invalid-name
self.required_keys = json.load(f)

@property
def assignment(self):
"""Return the assignment."""
return self._assignment

@property
def potential_points(self):
"""Return the potential points for the assignment."""
return self._potential_points

def validate_keys(self, subject, control):
"""Validate that the subject has all the keys in the control dict."""
assignment_keys = set(subject.keys())
required_keys = set(control.keys())

if not required_keys.issubset(assignment_keys):
missing_keys = required_keys.difference(assignment_keys)
raise InvalidResponseStructureError(
f"The assignment is missing one or more required keys. missing: {missing_keys}"
)
return True

def validate_statuscode(self):
"""Validate that the assignment's statusCode is 200."""
if "statusCode" not in self.assignment:
raise InvalidResponseStructureError(f"The assignment must have a statusCode. assignment: {self.assignment}")
if not isinstance(self.assignment.get("statusCode"), int):
status_code_type = type(self.assignment.get("statusCode"))
raise IncorrectResponseTypeError(
f"The assignment's statusCode must be an integer. received: {status_code_type}"
)
status_code = self.assignment["statusCode"]
if not status_code == 200:
raise ResponseFailedError(f"The assignment's statusCode must be 200. received: {status_code}")
return True

def validate_base64encoded(self):
"""Validate that the assignment's isBase64Encoded is False."""
if "isBase64Encoded" not in self.assignment:
raise InvalidResponseStructureError(
f"The assignment must have a isBase64Encoded. assignment: {self.assignment}"
)
is_base64_encoded = self.assignment.get("isBase64Encoded")
if not isinstance(is_base64_encoded, bool):
is_base64_encoded_type = type(is_base64_encoded)
raise IncorrectResponseTypeError(
f"The assignment's base64Encoded must be a boolean. received: {is_base64_encoded_type}"
)
if self.assignment["isBase64Encoded"]:
raise IncorrectResponseValueError("The assignment's isBase64Encoded must be False.")

def validate_body(self):
"""Validate that the assignment's body is a dict with the correct keys."""
if "body" not in self.assignment:
raise InvalidResponseStructureError(f"The assignment must have a body. assignment: {self.assignment}")

body = self.assignment.get("body")
if not isinstance(body, dict):
body_type = type(body)
raise IncorrectResponseTypeError(f"The assignment's body must be a dict. Received {body_type}.")
if not "chat_memory" in body:
raise InvalidResponseStructureError(
f"The assignment's body must have a key named chat_memory. body: {body}"
)
if not "messages" in body["chat_memory"]:
raise InvalidResponseStructureError(
f"The assignment's body.chat_memory must has a key named messages. body: {body}"
)
messages = body["chat_memory"]["messages"]
if not isinstance(messages, list):
messages_type = type(messages)
raise IncorrectResponseTypeError(
f"The assignment's body.chat_memory.messages must be a list. Received {messages_type}."
)
if len(messages) < 2:
raise InvalidResponseStructureError(
f"The messages list must contain at least two elements. messages: {messages}"
)

for message in messages:
if not isinstance(message, dict):
raise InvalidResponseStructureError(
f"All elements in the messages list must be dictionaries. messages: {messages}"
)

human_prompt = messages[0]
ai_response = messages[1]

self.validate_keys(human_prompt, HUMAN_PROMPT)
self.validate_keys(ai_response, AI_RESPONSE)

if not human_prompt["type"] == "human":
raise IncorrectResponseValueError(f"The first message must be a human prompt. first prompt: {human_prompt}")
if not ai_response["type"] == "ai":
raise IncorrectResponseValueError(f"The second message must be an AI response. response: {ai_response}")

def validate_metadata(self):
"""Validate that the assignment's metadata is a dict with the correct keys."""
body = self.assignment.get("body")
request_meta_data = body["request_meta_data"]
if not isinstance(request_meta_data, dict):
meta_data_type = type(request_meta_data)
raise InvalidResponseStructureError(
f"The assignment must has a dict named request_meta_data. received: {meta_data_type}"
)
if request_meta_data.get("lambda") is None:
raise InvalidResponseStructureError(
f"The request_meta_data key lambda_langchain must exist. request_meta_data: {request_meta_data}"
)
if request_meta_data.get("model") is None:
raise InvalidResponseStructureError(
f"The request_meta_data key model must exist. request_meta_data: {request_meta_data}"
)
if request_meta_data.get("end_point") is None:
raise InvalidResponseStructureError(
f"The request_meta_data end_point must exist. request_meta_data: {request_meta_data}"
)

if not request_meta_data.get("lambda") == "lambda_langchain":
raise IncorrectResponseValueError(f"The request_meta_data.lambda must be lambda_langchain. body: {body}")
if not request_meta_data.get("model") == "gpt-3.5-turbo":
raise IncorrectResponseValueError(f"The request_meta_data.model must be gpt-3.5-turbo. body: {body}")
if not request_meta_data.get("end_point") == "ChatCompletion":
raise IncorrectResponseValueError(f"The request_meta_data.end_point must be ChatCompletion. body: {body}")

def validate(self):
"""Validate the assignment data structure."""
if not isinstance(self.assignment, dict):
raise InvalidResponseStructureError("The assignment must be a dictionary.")
self.validate_keys(self.assignment, self.required_keys)
self.validate_statuscode()
self.validate_base64encoded()
self.validate_body()
self.validate_metadata()
assignment: str = Field(..., description="The assignment to grade.")
potential_points: float = Field(100, description="The maximum number of points that can be awarded.", ge=0)

def grade_response(self, message: AGException = None):
"""Create a grade dict from the assignment."""
Expand All @@ -217,29 +63,32 @@ def grade_response(self, message: AGException = None):
return grade.model_dump()

def grade(self):
"""Grade the assignment.
This is an experimental usage of pydantic to validate the assignment.
Only two tests should pass, the rest should raise exceptions and then
be processed with the legacy code below.
"""
"""Grade the assignment."""
assignment_json: dict
lc_response: LCResponse

# 1.) attempt to load the assignment as JSON
try:
assignment_json = json.loads(self.assignment)
except json.JSONDecodeError as e:
try:
raise InvalidResponseStructureError("The assignment is not valid JSON") from e
except InvalidResponseStructureError as reraised_e:
return self.grade_response(reraised_e)

# 2.) attempt to validate the assignment using Pydantic
try:
lc_response = LCResponse(**assignment_json)
except (ValidationError, TypeError) as e:
try:
raise InvalidResponseStructureError("The assignment failed pydantic validation.") from e
except InvalidResponseStructureError as reraised_e:
return self.grade_response(reraised_e)

# 3.) validate the assignment
try:
lc_response = LCResponse(**self.assignment)
lc_response.validate_response()
return self.grade_response()
except (ValidationError, TypeError):
# FIX NOTE: need to map these to an applicable exception from grader.exceptions.
print("warning: assignment failed an un-mapped pydantic validation.")
except (
ResponseFailedError,
InvalidResponseStructureError,
ResponseFailedError,
IncorrectResponseValueError,
IncorrectResponseTypeError,
) as e:
return self.grade_response(e)

try:
self.validate()
except (
ResponseFailedError,
InvalidResponseStructureError,
Expand All @@ -248,4 +97,3 @@ def grade(self):
IncorrectResponseTypeError,
) as e:
return self.grade_response(e)
return self.grade_response()
12 changes: 6 additions & 6 deletions grader/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class LCRequestMetaData(BaseModel):
"""LangChain request meta data"""

lambda_name: str = Field(..., alias="lambda")
model: str
end_point: str
model: str = Field(...)
end_point: str = Field(...)
temperature: float = Field(..., ge=0, le=1)
max_tokens: int = Field(..., gt=0)

Expand Down Expand Up @@ -52,10 +52,10 @@ class LCBody(BaseModel):
chat_memory: LCChatMemory
output_key: Optional[str] = Field(None)
input_key: Optional[str] = Field(None)
return_messages: bool
human_prefix: str = Field("Human")
ai_prefix: str = Field("AI")
memory_key: str = Field("chat_history")
return_messages: Optional[bool] = Field(True)
human_prefix: Optional[str] = Field("Human")
ai_prefix: Optional[str] = Field("AI")
memory_key: Optional[str] = Field("chat_history")
request_meta_data: LCRequestMetaData


Expand Down
8 changes: 1 addition & 7 deletions grader/tests/init.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
# -*- coding: utf-8 -*-
"""Provide common functions for testing."""
import json


def get_event(filespec):
"""Reads a JSON file and returns the event"""
with open(filespec, "r", encoding="utf-8") as f: # pylint: disable=invalid-name
try:
event = json.load(f)
return event
except json.JSONDecodeError:
print(f"warning: invalid JSON in file: {filespec}")
return f.read()
return f.read()
Loading

0 comments on commit 7288ac5

Please sign in to comment.