From a3106c805fe04b5b4ffb9b04b2c04e96b0b34633 Mon Sep 17 00:00:00 2001 From: Michael Fuest Date: Thu, 12 Sep 2024 10:19:08 -0400 Subject: [PATCH] added gpt support --- eval/evaluator.py | 3 +- generator/llm/gpt_prompt_template.json | 4 +- generator/llm/llm.py | 54 ++++++++++---------------- main.py | 2 +- 4 files changed, 26 insertions(+), 37 deletions(-) diff --git a/eval/evaluator.py b/eval/evaluator.py index 4f4acbf..e60fee3 100644 --- a/eval/evaluator.py +++ b/eval/evaluator.py @@ -13,7 +13,7 @@ from generator.diffcharge.diffusion import DDPM from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS from generator.gan.acgan import ACGAN -from generator.llm.llm import HF +from generator.llm.llm import HF, GPT from generator.options import Options device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -301,6 +301,7 @@ def get_trained_model_for_user(self, model_name: str, user_dataset: Any) -> Any: "diffusion_ts": Diffusion_TS, "mistral": lambda opt: HF("mistralai/Mistral-7B-Instruct-v0.2"), "llama": lambda opt: HF("meta-llama/Meta-Llama-3.1-8B"), + "gpt": lambda opt: GPT("gpt-4o") } if model_name in model_dict: diff --git a/generator/llm/gpt_prompt_template.json b/generator/llm/gpt_prompt_template.json index 2adc556..76e0b8b 100644 --- a/generator/llm/gpt_prompt_template.json +++ b/generator/llm/gpt_prompt_template.json @@ -1,4 +1,4 @@ { - "system_message": "You are a helpful assistant that performs synthetic time series generation for household electrical load daily usage profiles. The user will provide an example time series for a month and weekday combination and you will generate a synthetic time series sampled from the same distribution. The time series is represented by decimal strings separated by commas.", - "user_message": "Please generate a time series similar in shape to the following sequence without producing any additional text. Do not say anything like 'the synthetic time series is', just return the sequence of 96 values. The weekday and month are provided in the following, as well as an example time series of 96 values:\n" + "system_message": "You are a helpful assistant that generates synthetic time series data for household electrical load daily usage profiles. The user will provide an example time series for a month and weekday combination and you will generate a synthetic time series sampled from the same distribution. The time series is represented by decimal strings separated by commas.", + "user_message": "Please generate a time series similar in shape to the following sequence without producing any additional text. Do not say anything like 'the synthetic time series is', just return a sequence of EXACTLY 96 values. The example time series of 96 values looks as follows:" } diff --git a/generator/llm/llm.py b/generator/llm/llm.py index bcea42c..7d74218 100644 --- a/generator/llm/llm.py +++ b/generator/llm/llm.py @@ -14,6 +14,11 @@ HF_PROMPT_PATH = os.path.join( os.path.dirname(os.path.abspath(__file__)), "hf_prompt_template.json" ) + +GPT_PROMPT_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "gpt_prompt_template.json" +) + VALID_NUMBERS = list("0123456789") DEFAULT_BOS_TOKEN = "<|begin_of_text|>" DEFAULT_EOS_TOKEN = "<|end_of_text|>" @@ -53,9 +58,8 @@ def __init__(self, name=DEFAULT_MODEL, sep=","): self.tokenizer.pad_token = ( self.tokenizer.eos_token - ) # Indicate end of the time series + ) - # Define invalid tokens (tokens that are not digits or commas) valid_tokens = [ self.tokenizer.convert_tokens_to_ids(str(digit)) for digit in VALID_NUMBERS ] @@ -69,23 +73,16 @@ def __init__(self, name=DEFAULT_MODEL, sep=","): self.invalid_tokens = [[i] for i in range(vocab_size) if i not in valid_tokens] - # Load the model self.model = AutoModelForCausalLM.from_pretrained( self.name, device_map="auto", torch_dtype=torch.bfloat16 ) self.model.eval() - def load_prompt_template(self): - """Load the prompt template from a JSON file.""" - with open(HF_PROMPT_PATH) as f: - template = json.load(f)["prompt_template"] - return template - def generate_timeseries( self, example_ts, length=96, temp=1, top_p=1, raw=False, samples=1, padding=0 ): """Generate a time series forecast.""" - template = self.load_prompt_template() + template = load_prompt_template(HF_PROMPT_PATH) prompt = template.format(length=length, example_ts=example_ts) tokenized_input = self.tokenizer([prompt], return_tensors="pt").to("cuda") @@ -102,7 +99,7 @@ def generate_timeseries( temperature=temp, top_p=top_p, renormalize_logits=True, - bad_words_ids=self.invalid_tokens, # Ensure no invalid tokens (only digits and commas) + bad_words_ids=self.invalid_tokens, num_return_sequences=samples, ) @@ -120,7 +117,7 @@ def generate_timeseries( values = response.split(self.sep) values = [ v.strip() for v in values if v.strip().replace(".", "", 1).isdigit() - ] # Remove invalid entries + ] processed_responses.append(self.sep.join(values)) return processed_responses @@ -178,35 +175,21 @@ class GPT: sep (str): String to separate each element in values. Default to ','. """ - def __init__(self, model_name="gpt-4", sep=","): + def __init__(self, model_name="gpt-4o", sep=","): self.model_name = model_name self.sep = sep self.client = OpenAI() - - # Load prompt templates - self.zero_shot_prompt = self.load_prompt_template( - "gpt_system_prompt_zero_shot.txt" - ) - self.one_shot_prompt = self.load_prompt_template( - "gpt_system_prompt_one_shot.txt" - ) - - def load_prompt_template(self, filename: str) -> str: - """Load the prompt template from a file.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - file_path = os.path.join(current_dir, "..", "template", filename) - with open(file_path, "r") as f: - return f.read() + self.prompt = load_prompt_template(GPT_PROMPT_PATH) def generate_timeseries( self, example_ts: str, length: int = 96, temp: float = 1, top_p: float = 1 ) -> List[str]: - """Generate a time series forecast.""" + """Generate synthetic time series data.""" messages = [ - {"role": "system", "content": self.zero_shot_prompt}, + {"role": "system", "content": self.prompt["system_message"]}, { "role": "user", - "content": f"Generate a time series of exactly {length} comma-separated values similar to this example: {example_ts}", + "content": self.prompt["user_message"] + example_ts, }, ] @@ -218,7 +201,6 @@ def generate_timeseries( values = generated_ts.split(self.sep) values = [v.strip() for v in values if v.strip().replace(".", "", 1).isdigit()] - # Ensure exactly 'length' values are returned if len(values) < length: LOGGER.warning( f"Generated {len(values)} values instead of {length}. Padding with zeros." @@ -246,7 +228,6 @@ def generate( gen_ts_dataset = [] - # Create a tqdm progress bar total_iterations = len(day_labels) with tqdm(total=total_iterations, desc="Generating Time Series") as pbar: for day, month in zip(day_labels, month_labels): @@ -271,3 +252,10 @@ def generate( gen_ts_dataset = torch.tensor(np.array(gen_ts_dataset)) return gen_ts_dataset + + +def load_prompt_template(path): + """Load the prompt template from a JSON file.""" + with open(path) as f: + template = json.load(f) + return template \ No newline at end of file diff --git a/main.py b/main.py index 3c5ed78..bdf2f48 100644 --- a/main.py +++ b/main.py @@ -30,7 +30,7 @@ def evaluate_single_dataset_model( def main(): - evaluate_individual_user_models("mistral", "newyork") + evaluate_individual_user_models("gpt", "newyork") # evaluate_single_dataset_model("diffusion_ts", "austin")