
Commit

added gpt support
Michael Fuest committed Sep 12, 2024
1 parent c3e7330 commit a3106c8
Showing 4 changed files with 26 additions and 37 deletions.
3 changes: 2 additions & 1 deletion eval/evaluator.py
@@ -13,7 +13,7 @@
from generator.diffcharge.diffusion import DDPM
from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS
from generator.gan.acgan import ACGAN
from generator.llm.llm import HF
from generator.llm.llm import HF, GPT
from generator.options import Options

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -301,6 +301,7 @@ def get_trained_model_for_user(self, model_name: str, user_dataset: Any) -> Any:
"diffusion_ts": Diffusion_TS,
"mistral": lambda opt: HF("mistralai/Mistral-7B-Instruct-v0.2"),
"llama": lambda opt: HF("meta-llama/Meta-Llama-3.1-8B"),
"gpt": lambda opt: GPT("gpt-4o")
}

if model_name in model_dict:
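For context, the new "gpt" entry is resolved the same way as the existing HF entries: the name is looked up in model_dict and the factory lambda is called to build the generator. A minimal sketch, assuming the same opt options object the other lambdas receive (the surrounding method body is not shown in this diff):

    # Sketch: resolving the factory dict added above.
    model_factory = model_dict["gpt"]   # -> lambda opt: GPT("gpt-4o")
    model = model_factory(opt)          # instantiates the OpenAI-backed generator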
4 changes: 2 additions & 2 deletions generator/llm/gpt_prompt_template.json
@@ -1,4 +1,4 @@
{
"system_message": "You are a helpful assistant that performs synthetic time series generation for household electrical load daily usage profiles. The user will provide an example time series for a month and weekday combination and you will generate a synthetic time series sampled from the same distribution. The time series is represented by decimal strings separated by commas.",
"user_message": "Please generate a time series similar in shape to the following sequence without producing any additional text. Do not say anything like 'the synthetic time series is', just return the sequence of 96 values. The weekday and month are provided in the following, as well as an example time series of 96 values:\n"
"system_message": "You are a helpful assistant that generates synthetic time series data for household electrical load daily usage profiles. The user will provide an example time series for a month and weekday combination and you will generate a synthetic time series sampled from the same distribution. The time series is represented by decimal strings separated by commas.",
"user_message": "Please generate a time series similar in shape to the following sequence without producing any additional text. Do not say anything like 'the synthetic time series is', just return a sequence of EXACTLY 96 values. The example time series of 96 values looks as follows:"
}
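The two fields map directly onto the chat roles used in generator/llm/llm.py below: system_message becomes the system turn, and user_message is concatenated with the example series to form the user turn. A minimal sketch of that wiring (the example series here is a placeholder):

    import json

    with open("generator/llm/gpt_prompt_template.json") as f:
        prompt = json.load(f)

    example_ts = "0.12,0.10,0.09"  # placeholder; real inputs carry 96 values
    messages = [
        {"role": "system", "content": prompt["system_message"]},
        {"role": "user", "content": prompt["user_message"] + example_ts},
    ]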
54 changes: 21 additions & 33 deletions generator/llm/llm.py
@@ -14,6 +14,11 @@
HF_PROMPT_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "hf_prompt_template.json"
)

GPT_PROMPT_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "gpt_prompt_template.json"
)

VALID_NUMBERS = list("0123456789")
DEFAULT_BOS_TOKEN = "<|begin_of_text|>"
DEFAULT_EOS_TOKEN = "<|end_of_text|>"
@@ -53,9 +58,8 @@ def __init__(self, name=DEFAULT_MODEL, sep=","):

self.tokenizer.pad_token = (
self.tokenizer.eos_token
) # Indicate end of the time series
)

# Define invalid tokens (tokens that are not digits or commas)
valid_tokens = [
self.tokenizer.convert_tokens_to_ids(str(digit)) for digit in VALID_NUMBERS
]
@@ -69,23 +73,16 @@ def __init__(self, name=DEFAULT_MODEL, sep=","):

self.invalid_tokens = [[i] for i in range(vocab_size) if i not in valid_tokens]

# Load the model
self.model = AutoModelForCausalLM.from_pretrained(
self.name, device_map="auto", torch_dtype=torch.bfloat16
)
self.model.eval()

def load_prompt_template(self):
"""Load the prompt template from a JSON file."""
with open(HF_PROMPT_PATH) as f:
template = json.load(f)["prompt_template"]
return template

def generate_timeseries(
self, example_ts, length=96, temp=1, top_p=1, raw=False, samples=1, padding=0
):
"""Generate a time series forecast."""
template = self.load_prompt_template()
template = load_prompt_template(HF_PROMPT_PATH)["prompt_template"]
prompt = template.format(length=length, example_ts=example_ts)

tokenized_input = self.tokenizer([prompt], return_tensors="pt").to("cuda")
@@ -102,7 +99,7 @@ def generate_timeseries(
temperature=temp,
top_p=top_p,
renormalize_logits=True,
bad_words_ids=self.invalid_tokens, # Ensure no invalid tokens (only digits and commas)
bad_words_ids=self.invalid_tokens,
num_return_sequences=samples,
)

@@ -120,7 +117,7 @@ def generate_timeseries(
values = response.split(self.sep)
values = [
v.strip() for v in values if v.strip().replace(".", "", 1).isdigit()
] # Remove invalid entries
]
processed_responses.append(self.sep.join(values))

return processed_responses
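The bad_words_ids setup above is what keeps HF generations numeric: every vocabulary id other than the digit and separator tokens is banned at decode time, and each banned entry is a single-token sequence, which is the shape generate() expects. A condensed, self-contained sketch of the same idea (the model name is taken from the evaluator diff; loading it requires the usual HF access):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
    valid = {tokenizer.convert_tokens_to_ids(d) for d in "0123456789"}
    valid.add(tokenizer.convert_tokens_to_ids(","))
    # One [id] list per banned token, suitable for generate(bad_words_ids=...).
    invalid_tokens = [[i] for i in range(len(tokenizer)) if i not in valid]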
@@ -178,35 +175,21 @@ class GPT:
sep (str): String to separate each element in values. Default to ','.
"""

def __init__(self, model_name="gpt-4", sep=","):
def __init__(self, model_name="gpt-4o", sep=","):
self.model_name = model_name
self.sep = sep
self.client = OpenAI()

# Load prompt templates
self.zero_shot_prompt = self.load_prompt_template(
"gpt_system_prompt_zero_shot.txt"
)
self.one_shot_prompt = self.load_prompt_template(
"gpt_system_prompt_one_shot.txt"
)

def load_prompt_template(self, filename: str) -> str:
"""Load the prompt template from a file."""
current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(current_dir, "..", "template", filename)
with open(file_path, "r") as f:
return f.read()
self.prompt = load_prompt_template(GPT_PROMPT_PATH)

def generate_timeseries(
self, example_ts: str, length: int = 96, temp: float = 1, top_p: float = 1
) -> List[str]:
"""Generate a time series forecast."""
"""Generate synthetic time series data."""
messages = [
{"role": "system", "content": self.zero_shot_prompt},
{"role": "system", "content": self.prompt["system_message"]},
{
"role": "user",
"content": f"Generate a time series of exactly {length} comma-separated values similar to this example: {example_ts}",
"content": self.prompt["user_message"] + example_ts,
},
]

@@ -218,7 +201,6 @@ def generate(
values = generated_ts.split(self.sep)
values = [v.strip() for v in values if v.strip().replace(".", "", 1).isdigit()]

# Ensure exactly 'length' values are returned
if len(values) < length:
LOGGER.warning(
f"Generated {len(values)} values instead of {length}. Padding with zeros."
@@ -246,7 +228,6 @@ def generate(

gen_ts_dataset = []

# Create a tqdm progress bar
total_iterations = len(day_labels)
with tqdm(total=total_iterations, desc="Generating Time Series") as pbar:
for day, month in zip(day_labels, month_labels):
@@ -271,3 +252,10 @@ def generate(

gen_ts_dataset = torch.tensor(np.array(gen_ts_dataset))
return gen_ts_dataset


def load_prompt_template(path):
"""Load the prompt template from a JSON file."""
with open(path) as f:
template = json.load(f)
return template
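Taken together, a hypothetical end-to-end use of the new wrapper (the example series is a placeholder, and the OpenAI() client assumes OPENAI_API_KEY is set in the environment):

    from generator.llm.llm import GPT

    gpt = GPT(model_name="gpt-4o")
    example_ts = ",".join(["0.1"] * 96)  # placeholder daily load profile
    values = gpt.generate_timeseries(example_ts, length=96)
    # Per the padding logic above, the result is zero-padded if the model
    # returns fewer than the requested 96 entries.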
2 changes: 1 addition & 1 deletion main.py
@@ -30,7 +30,7 @@ def evaluate_single_dataset_model(


def main():
evaluate_individual_user_models("mistral", "newyork")
evaluate_individual_user_models("gpt", "newyork")
# evaluate_single_dataset_model("diffusion_ts", "austin")


