Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code Quality Checks #157

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# this target runs checks on all files
quality:
ruff check ochat
ruff format --check ochat

# this target runs checks on all files and potentially modifies some of them
style:
ruff check ochat --fix
ruff format ochat
5 changes: 2 additions & 3 deletions ochat/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import torch
import transformers

from ochat.config.model_config import ModelConfig
from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
import ochat.models

from ochat.config.conversation_template import Conversation, ConversationTemplate, Message
from ochat.config.model_config import ModelConfig

_V3_2_PREFIXES = {
# OpenAI mapping
Expand Down
2 changes: 1 addition & 1 deletion ochat/config/conversation_template.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Callable, Iterable, List, Dict
from typing import Callable, Iterable, List, Optional

from pydantic import BaseModel

Expand Down
14 changes: 6 additions & 8 deletions ochat/data/generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@
Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.jsonl --tokenizer-name HF_REPO_NAME --out-dir .
"""

from typing import Optional
import argparse
import os
import random

import ray
import orjson
import pyarrow
import ray
from pyarrow import parquet


PAD_TOKEN_ID = 0


Expand Down Expand Up @@ -106,11 +104,11 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_
pyarrow.field("total_length", pyarrow.int32()),
pyarrow.field("num_seqs", pyarrow.float32()),

pyarrow.field(f"seqlens", pyarrow.list_(pyarrow.int32())),
pyarrow.field(f"nz_input_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field(f"nz_position_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field(f"nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field(f"nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
pyarrow.field("seqlens", pyarrow.list_(pyarrow.int32())),
pyarrow.field("nz_input_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field("nz_position_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field("nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field("nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
]

schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)})
Expand Down
13 changes: 6 additions & 7 deletions ochat/evaluation/conv_eval.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import OrderedDict
import signal
import argparse
import os
import json
import re
import signal
import subprocess
import argparse
import time
import requests
import re
import coolname
from typing import OrderedDict

import coolname
import requests

MAX_CONTEXT = 4096

Expand Down
4 changes: 2 additions & 2 deletions ochat/evaluation/convert_to_evalplus.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import argparse
import os
import orjson

from glob import glob

import orjson


def convert_to_evalplus(results_path: str, output_path: str):
os.makedirs(output_path, exist_ok=True)
Expand Down
12 changes: 6 additions & 6 deletions ochat/evaluation/grading/math_grader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
Call grade_answer(given_answer: str, ground_truth: str).
"""
import re

import sympy
from pylatexenc import latex2text
from sympy.parsing import sympy_parser

from ochat.evaluation.grading import math_normalize


# sympy might hang -- we don't care about trying to be lenient in these cases
BAD_SUBSTRINGS = ["^{", "^("]
BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
TUPLE_CHARS = "()[]"


Expand Down Expand Up @@ -93,7 +93,7 @@ def _inject_implicit_mixed_number(step: str):

def _strip_properly_formatted_commas(expr: str):
# We want to be careful because we don't want to strip tuple commas
p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
while True:
next_expr = p1.sub("\\1\\3\\4", expr)
if next_expr == expr:
Expand All @@ -108,7 +108,7 @@ def _normalize(expr: str) -> str:
return None

# Remove enclosing `\text{}`.
m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
m = re.search("^\\\\text\\{(?P<text>.+?)\\}$", expr)
if m is not None:
expr = m.group("text")

Expand Down Expand Up @@ -141,8 +141,8 @@ def _normalize(expr: str) -> str:
"inch",
"yard",
]:
expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
expr = re.sub(f"\^ *\\\\circ", "", expr)
expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
expr = re.sub("\\^ *\\\\circ", "", expr)

if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
expr = expr[1:-1]
Expand Down
4 changes: 2 additions & 2 deletions ochat/evaluation/grading/math_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def normalize_answer(answer: Optional[str]) -> Optional[str]:
answer = answer.strip()
try:
# Remove enclosing `\text{}`.
m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
m = re.search("^\\\\text\\{(?P<text>.+?)\\}$", answer)
if m is not None:
answer = m.group("text").strip()
return _strip_string(answer)
Expand Down Expand Up @@ -126,7 +126,7 @@ def _strip_string(string):

# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
string = string.replace(r"\%", "")

# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
Expand Down
12 changes: 6 additions & 6 deletions ochat/evaluation/match_answer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
import ast
import re

from ochat.evaluation.grading.math_grader import grade_answer

Expand Down Expand Up @@ -50,7 +50,7 @@ def _last_boxed_only_string(string):
break

i += 1

if left_brace_idx is None or right_brace_idx is None:
return None

Expand Down Expand Up @@ -119,7 +119,7 @@ def fs_cothub_gsm8k_match_answer(task_data, response):
# CoT hub match answer for GSM8k, match last numeric value
# https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/gpt3.5turbo_gsm8k_complex.ipynb

pattern = '\d*\.?\d+'
pattern = r'\d*\.?\d+'
pred = re.findall(pattern, response)
if len(pred) >= 1:
return True, pred[-1]
Expand All @@ -135,7 +135,7 @@ def fs_cothub_mmlu_match_answer(task_data, response):
return False, "(C)"
else:
ans = ans_line[-1].strip()

options = ['(A)', '(B)', '(C)', '(D)']
for option in options:
if option in ans:
Expand Down Expand Up @@ -174,12 +174,12 @@ def _try_match(content, prefix, entrypoint):
include_prefix = humaneval_task['prompt'].split('def')[0].strip() + "\n\n"

result = _try_match(response, include_prefix, humaneval_task["entry_point"])
if result:
if result:
return True, {"task_id": humaneval_task["task_id"], "completion": result}

# If fail then match with function signature
result = _try_match(response, humaneval_task["prompt"], humaneval_task["entry_point"])
if result:
if result:
return True, {"task_id": humaneval_task["task_id"], "completion": result}

return False, {"task_id": humaneval_task["task_id"], "completion": response}
Expand Down
17 changes: 8 additions & 9 deletions ochat/evaluation/run_eval.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from typing import Optional
import argparse
import os
import asyncio
import os
from glob import glob
from typing import Optional

import orjson
import openai
from tqdm import tqdm
import orjson
from openai.error import RateLimitError, ServiceUnavailableError
from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
from vllm import LLM, SamplingParams

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from tqdm import tqdm
from transformers.utils.hub import cached_file
from vllm import LLM, SamplingParams

from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION
from ochat.config import MODEL_CONFIG_MAP
from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(20), retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError, )))
Expand Down Expand Up @@ -46,7 +45,7 @@ async def chat_completion_thread(model, progress_bar, queue):
e = e._exception

print(type(e), str(e))

# Progress
progress_bar.update()

Expand Down
2 changes: 1 addition & 1 deletion ochat/evaluation/view_results.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
import os
from glob import glob
from pathlib import Path

import orjson
import pandas as pd
from glob import glob


def view_results(result_path: str):
Expand Down
6 changes: 3 additions & 3 deletions ochat/experimental/generate_dataset_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.json --tokenizer-name HF_REPO_NAME --out-dir .
"""

from typing import Optional
from dataclasses import dataclass
import argparse
import json
import os
import random
from dataclasses import dataclass
from typing import Optional

import numpy as np
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from ray.util.multiprocessing import Pool
from transformers.trainer_pt_utils import LabelSmoother


@dataclass
Expand Down
9 changes: 4 additions & 5 deletions ochat/models/unpadded_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@
import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.utils import logging

try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.bert_padding import pad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
print ("FlashAttention not found. Install it if you need to train models.")

Expand Down Expand Up @@ -313,7 +312,7 @@ def forward(
else:
nz_hidden_states = decoder_layer(
cos_sin,

nz_hidden_states,
nz_position_ids,
cu_seqlens,
Expand Down Expand Up @@ -355,7 +354,7 @@ def set_decoder(self, decoder):

def get_decoder(self):
return self.model

def forward(
self,
# Unpadded inputs
Expand Down
9 changes: 4 additions & 5 deletions ochat/models/unpadded_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@
import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.utils import logging

try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.bert_padding import pad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
print ("FlashAttention not found. Install it if you need to train models.")

Expand Down Expand Up @@ -313,7 +312,7 @@ def forward(
else:
nz_hidden_states = decoder_layer(
cos_sin,

nz_hidden_states,
nz_position_ids,
cu_seqlens,
Expand Down Expand Up @@ -352,7 +351,7 @@ def set_decoder(self, decoder):

def get_decoder(self):
return self.model

def forward(
self,
# Unpadded inputs
Expand Down
2 changes: 1 addition & 1 deletion ochat/scripts/hf_add_tokens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse

import transformers
import torch
import transformers


def add_tokens_to_embedding(added_special_tokens, embedding):
Expand Down
4 changes: 2 additions & 2 deletions ochat/serving/async_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ray

from ochat.config import Message, Conversation
from ochat.config import Conversation, Message


@ray.remote
Expand Down Expand Up @@ -38,7 +38,7 @@ def tokenize(self, messages, condition, enable_sys_prompt=False):
tokens, _ = self.conv_template.tokenize_conversations([Conversation(items=items, system=system_message, condition=condition)],
inference=True)
return tokens[0]

def get_eot_tokens(self):
assert len(self.conv_template.eot_tokens_) == 1

Expand Down
1 change: 0 additions & 1 deletion ochat/serving/openai_api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from vllm.utils import random_uuid


Expand Down
Loading