Skip to content

Commit

Permalink
Merge pull request #6 from automl/cleanup
Browse files Browse the repository at this point in the history
Cleanup
rheasukthanker authored May 29, 2024
2 parents c17074e + c62974d commit a9378e2
Showing 24 changed files with 398 additions and 56 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -3,7 +3,8 @@ experiments/*
logs/
*pdf
*out

*png
*pdf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
30 changes: 20 additions & 10 deletions baselines/gpt_objective_2d.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
search_spaces,
)
from typing import Dict, Any
import torch

report = Reporter()

@@ -24,6 +25,7 @@ def objective(
device: str,
search_space: str,
surrogate_type: str,
type: str,
objective: str,
) -> Reporter:
max_layers = get_max_min_stats(search_space)["max_layers"]
@@ -34,22 +36,28 @@ def objective(
arch_feature_map_predictor = normalize_arch_feature_map(
arch_feature_map, search_space
)

device_run = "cuda" if torch.cuda.is_available() else "cpu"
ppl_predictor = get_ppl_predictor_surrogate(search_space)
perplexity = ppl_predictor(arch_feature_map_ppl_predictor.cuda().unsqueeze(0))
perplexity = ppl_predictor(
arch_feature_map_ppl_predictor.to(device_run).unsqueeze(0)
)
hw_predictor = get_hw_predictor_surrogate(
max_layers, search_space, device, surrogate_type, objective
max_layers, search_space, device, surrogate_type, type, objective
)
hw_metric = predict_hw_surrogate(
[arch_feature_map_predictor], hw_predictor, surrogate_type
)
ppl = perplexity.item()
ppl_norm = normalize_ppl(ppl, search_space)
if objective == "energy":
hw_metric_norm = normalize_energy(hw_metric, device, search_space)
elif objective == "latency":
hw_metric_norm = normalize_latency(hw_metric, device, search_space)
report(perplexity=ppl_norm, hw_metric=hw_metric_norm)
if objective == "energies":
hw_metric_norm = normalize_energy(
hw_metric, device, surrogate_type, type, search_space, objective
)
elif objective == "latencies":
hw_metric_norm = normalize_latency(
hw_metric, device, surrogate_type, type, search_space, objective
)
report(perplexity=ppl_norm, hw_metric=hw_metric_norm.item())


if __name__ == "__main__":
@@ -61,12 +69,13 @@ def objective(
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--surrogate_type", type=str, default="conformal_quantile")
parser.add_argument("--type", type=str, default="quantile")
parser.add_argument("--search_space", type=str, default="s")
parser.add_argument("--device", type=str, default="P100")
parser.add_argument("--device", type=str, default="v100")
parser.add_argument("--num_layers", type=int, default=12)
parser.add_argument("--embed_dim", type=int, default=768)
parser.add_argument("--bias", type=bool, default=True)
parser.add_argument("--objective", type=str, default="energy")
parser.add_argument("--objective", type=str, default="energies")
args = parser.parse_known_args()[0]
search_space = search_spaces[args.search_space]
max_layers = max(search_space["n_layer_choices"])
@@ -90,6 +99,7 @@ def objective(
sampled_config=sample_config,
search_space=args.search_space,
surrogate_type=args.surrogate_type,
type=args.type,
device=args.device,
objective=args.objective,
)
20 changes: 15 additions & 5 deletions baselines/run_nas_gpt_2d.py
Original file line number Diff line number Diff line change
@@ -70,15 +70,17 @@
parser.add_argument(
"--device",
type=str,
default="P100",
default="v100",
)
parser.add_argument("--search_space", type=str, default="s")
parser.add_argument("--surrogate_type", type=str, default="conformal_quantile")
parser.add_argument("--objective", type=str, default="energy")
parser.add_argument("--type", type=str, default="quantile")
parser.add_argument("--objective", type=str, default="energies")
args, _ = parser.parse_known_args()
search_space = search_spaces[args.search_space]
max_layers = max(search_spaces["n_layer_choices"])
max_layers = max(search_space["n_layer_choices"])
config_space = {
"type": args.type,
"search_space": args.search_space,
"surrogate_type": args.surrogate_type,
"objective": args.objective,
@@ -168,7 +170,7 @@

# Stopping criterion: We stop after `args.max_wallclock_time` seconds
# [5]
stop_criterion = StoppingCriterion(max_num_trials_finished=200)
stop_criterion = StoppingCriterion(max_num_trials_finished=50)

tuner = Tuner(
trial_backend=trial_backend,
@@ -211,14 +213,22 @@
"hw_metric": energy,
}

os.makedirs("results_correct", exist_ok=True)
os.makedirs("results_gpt_baselines_2d", exist_ok=True)
save_path = (
"results_gpt_baselines_2d/"
+ args.experiment_tag
+ "_"
+ args.method
+ "_"
+ args.device
+ "_"
+ args.search_space
+ "_"
+ args.objective
+ "_"
+ args.surrogate_type
+ "_"
+ str(args.type)
+ ".pickle"
)
with open(save_path, "wb") as f:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added hwgpt/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file added hwgpt/__pycache__/api.cpython-311.pyc
Binary file not shown.
Binary file added hwgpt/__pycache__/api_utils.cpython-311.pyc
Binary file not shown.
8 changes: 6 additions & 2 deletions hwgpt/api.py
Original file line number Diff line number Diff line change
@@ -134,7 +134,8 @@ def compute_predictions_hw(
device: str,
surrogate_type: str = "conformal_quantile",
data_type: str = "quantile",
return_all_quantiles: bool = True,
return_all: bool = True,
return_all_quantiles: bool = False,
) -> Any:
arch_feature = get_arch_feature_map(self.config, self.search_space_name)
arch_feature = normalize_arch_feature_map(arch_feature, self.search_space_name)
@@ -151,8 +152,11 @@ def compute_predictions_hw(
[arch_feature],
surrogate,
surrogate_type,
return_all=return_all,
return_quantiles=return_all_quantiles,
)[0]
)
if not return_all_quantiles:
return predictions_hw[0]
return predictions_hw

def eval_supernet_surrogate(self) -> Dict[str, float]:
31 changes: 22 additions & 9 deletions hwgpt/predictors/hwmetric/compute_max_min_stats_energy_gpus.py
Original file line number Diff line number Diff line change
@@ -11,16 +11,16 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="PyTorch HW Metric Predictor")
parser.add_argument("--device", type=str, default="a100", help="device name")
parser.add_argument("--device", type=str, default="v100", help="device name")
parser.add_argument(
"--metric",
type=str,
default="energies",
)
parser.add_argument("--search_space", type=str, default="s")
parser.add_argument("---model", type="str", default="conformal_quantile")
parser.add_argument("--type", type="str", default="quantile")
parser.add_argument("--num_quantiles", type="str", default=10)
parser.add_argument("--model", type=str, default="conformal_quantile")
parser.add_argument("--type", type=str, default="quantile")
parser.add_argument("--num_quantiles", type=str, default=9)
parser.add_argument(
"--batch-size",
type=int,
@@ -52,23 +52,32 @@
"data_collection/gpt_datasets/predictor_ckpts/hwmetric/" + str(args.model) + "/"
)
model_path = (
base_path + args.metric + "_" + args.search_space + "_" + args.device + ".pkl"
base_path
+ args.metric
+ "_"
+ str(args.type)
+ "_"
+ args.search_space
+ "_"
+ args.device
+ ".pkl"
)
with open(model_path, "rb") as f:
model = pickle.load(f)
max_config = sample_config_max(search_space)
min_config = sample_config_min(search_space)

max_feature = normalize_arch_feature_map(
get_arch_feature_map(max_config, args.search_space)
get_arch_feature_map(max_config, args.search_space), args.search_space
)
min_feature = normalize_arch_feature_map(
get_arch_feature_map(min_config, args.search_space)
get_arch_feature_map(min_config, args.search_space), args.search_space
)
lats_max = max(
predict_hw_surrogate(max_feature, model, args.model, return_quantile=True)
predict_hw_surrogate([max_feature], model, args.model, return_all=True)[0]
)
lats_min = min(
predict_hw_surrogate(min_feature, model, args.model, return_quantiles=True)
predict_hw_surrogate([min_feature], model, args.model, return_all=True)[0]
)
max_min_stats = {"max": lats_max, "min": lats_min}
model_stats_path = (
@@ -78,6 +87,10 @@
+ "_"
+ args.search_space
+ "_"
+ args.model
+ "_"
+ args.type
+ "_"
+ args.device
+ ".pkl"
)
Binary file added lib/__pycache__/utils.cpython-311.pyc
Binary file not shown.
68 changes: 40 additions & 28 deletions lib/utils.py
Original file line number Diff line number Diff line change
@@ -71,27 +71,15 @@ def get_max_min_true_metric(api, metric=str) -> Dict[str, float]:
return {"max": max_metric, "min": min_metric}


def convert_arch_to_str(arch: Dict[str, Any], scale):
def convert_arch_to_str(arch:Dict[str,Any],scale:str)->str:
str_mlp = ""
str_heads = ""
for i in range(arch["sample_n_layer"]):
str_mlp = str_mlp + str(arch["sample_mlp_ratio"][i])
str_heads = str_heads + str(arch["sample_n_head"][i])
name = (
"gpt-"
+ str(scale)
+ "-"
+ str(arch["sample_n_layer"])
+ "-"
+ str(arch["sample_embed_dim"])
+ "-"
+ str_mlp
+ "-"
+ str_heads
+ "-"
+ str(arch["sample_bias"])
)
return name
str_mlp = str_mlp+str(arch["sample_mlp_ratio"][i])+"-"
str_heads = str_heads+str(arch["sample_n_head"][i])+"-"
name = "gpt-"+str(scale)+"-"+str(arch["sample_n_layer"])+'-'+str(arch["sample_embed_dim"])+'-'+str_mlp+str_heads+str(arch["sample_bias"])
print(name)
return name


def convert_str_to_arch(arch_str: str) -> Dict[str, Any]:
@@ -194,13 +182,26 @@ def normalize_ppl(ppl: float, scale: str) -> float:
return ppl


def normalize_energy(energy: float, device: str, scale: str) -> float:
with open(
"hwmetric_predictor_ckpts/max_min_stats_energy_"
+ device
def normalize_energy(energy: float, device: str, surrogate:str, data_type:str, scale: str, metric:str) -> float:
base_path = (
"data_collection/gpt_datasets/predictor_ckpts/hwmetric/" + str(surrogate) + "/"
)
model_path = (
base_path
+ "stats_max_min_"
+ str(metric)
+"_"
+ scale
+ "_"
+ str(scale)
+ surrogate
+ "_"
+ data_type
+ "_"
+ device
+ ".pkl"
)
with open(
model_path,"rb"
) as f:
max_min_stats = pickle.load(f)
max_energy = max_min_stats["max"]
@@ -209,13 +210,24 @@ def normalize_energy(energy: float, device: str, scale: str) -> float:
return energy


def normalize_latency(latency: float, device: str, scale: str) -> float:
with open(
"hwmetric_predictor_ckpts/max_min_stats_latency_"
+ device
def normalize_latency(latency: float, device: str, surrogate:str, data_type:str, scale: str, metric:str) -> float:
base_path = "data_collection/gpt_datasets/predictor_ckpts/hwmetric/"
model_path = (
base_path
+ "stats_max_min_"
+ str(metric)
+"_"
+ scale
+ "_"
+ str(scale)
+ surrogate
+ "_"
+ data_type
+ "_"
+ device
+ ".pkl"
)
with open(
model_path,"rb"
) as f:
max_min_stats = pickle.load(f)
max_latency = max_min_stats["max"]
217 changes: 217 additions & 0 deletions plotting/plot_corr_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import pickle
import numpy as np
import matplotlib.pyplot as plt
from hwgpt.model.gpt.utils import sample_config
from lib.utils import search_spaces
from hwgpt.api import HWGPTBenchAPI
import torch

plt.rcParams["axes.grid"] = True
plt.rcParams["grid.linestyle"] = "dotted"
plt.rcParams["font.size"] = 16
# plt tight layout
plt.rcParams["figure.autolayout"] = True


class CorrScatter:
def __init__(self, search_space=str, num_archs=int):
self.search_space = search_space
self.num_archs = num_archs
self.sampled_archs = self.sample_archs()
print(len(self.sampled_archs))
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.api = HWGPTBenchAPI(search_space=self.search_space)
print("API initialized")

def sample_archs(self):
sampled_archs = []
for i in range(self.num_archs):
arch = sample_config(search_spaces[self.search_space], seed=i)
sampled_archs.append(arch)
return sampled_archs

def get_perplexity_list(self):
ppl = []
for arch in self.sampled_archs:
self.api.set_arch(arch)
ppl.append(self.api.compute_predictions_ppl())
return ppl

def get_flops_list(self):
flops = []
for arch in self.sampled_archs:
self.api.set_arch(arch)
flops.append(self.api.get_flops())

return flops

def get_params_list(self):
params = []
for arch in self.sampled_archs:
self.api.set_arch(arch)
params.append(self.api.get_params())

return params

def get_float16_memory(self):
float16_memory = []
for arch in self.sampled_archs:
self.api.set_arch(arch)
float16_memory.append(
self.api.compute_predictions_hw(
hw_metric="float16_memory",
device="rtx2080",
surrogate_type="mlp",
data_type="median",
).item()
)
return float16_memory

def get_bfloat16_memory(self):
bfloat16_memory = []
for arch in self.sampled_archs:
self.api.set_arch(arch)
bfloat16_memory.append(
self.api.compute_predictions_hw(
hw_metric="bfloat16_memory",
device="a100",
surrogate_type="mlp",
data_type="median",
).item()
)
return bfloat16_memory

def get_lat_en(
self, hw_metric: str, device: str, surrogate_type: str, data_type: str
):
lat_en = []
for arch in self.sampled_archs:
self.api.set_arch(arch)

predictions = self.api.compute_predictions_hw(
hw_metric, device, surrogate_type, data_type
)
if "quantile" in surrogate_type:
for prediction in predictions:
if hw_metric == "energies":
prediction = prediction * 1000
lat_en.append(prediction.item())
else:
if hw_metric == "energies":
predictions = predictions * 1000
lat_en.append(predictions.item())
return lat_en

def plot_corr_scatter(
self, metrics_list: list, device: str, surrogate_type: str, data_type: str
):
metrics_dict = {}
repeat_for_quantiles = "quantile" in surrogate_type
assert len(metrics_list) == 3
plt.figure(figsize=(15, 5))
for i in range(len(metrics_list)):
metric = metrics_list[i]
if metric == "perplexity":
y = self.get_perplexity_list()
# repeat each entry 9 quantiles times
if repeat_for_quantiles:
y = np.repeat(y, 9)
elif metric == "flops":
y = self.get_flops_list()
if repeat_for_quantiles:
y = np.repeat(y, 9)
elif metric == "params":
y = self.get_params_list()
if repeat_for_quantiles:
y = np.repeat(y, 9)
elif metric == "float16_memory":
y = self.get_float16_memory()
if repeat_for_quantiles:
y = np.repeat(y, 9)
elif metric == "bfloat16_memory":
y = self.get_bfloat16_memory()
if repeat_for_quantiles:
y = np.repeat(y, 9)
else:
y = self.get_lat_en(metric, device, surrogate_type, data_type)
metrics_dict[metric] = y
print("Processed", metric)
plt.subplot(1, 3, 1)
sc = plt.scatter(
metrics_dict[metrics_list[0]],
metrics_dict[metrics_list[1]],
c=metrics_dict[metrics_list[2]],
cmap="viridis",
s=4,
)
plt.colorbar(sc, label=metrics_list[2])
plt.xlabel(metrics_list[0])
plt.ylabel(metrics_list[1])
plt.subplot(1, 3, 2)
sc = plt.scatter(
metrics_dict[metrics_list[0]],
metrics_dict[metrics_list[2]],
c=metrics_dict[metrics_list[1]],
cmap="viridis",
s=4,
)
plt.colorbar(sc, label=metrics_list[1])
plt.xlabel(metrics_list[0])
plt.ylabel(metrics_list[2])
plt.subplot(1, 3, 3)
sc = plt.scatter(
metrics_dict[metrics_list[1]],
metrics_dict[metrics_list[2]],
c=metrics_dict[metrics_list[0]],
cmap="viridis",
s=4,
)
plt.colorbar(sc, label=metrics_list[0])
plt.xlabel(metrics_list[1])
plt.ylabel(metrics_list[2])
plt.suptitle(
"Trade-offs between "
+ metrics_list[0]
+ ", "
+ metrics_list[1]
+ " and "
+ metrics_list[2]
)
plt.savefig(
"corr_scatter_plots/"
+ metrics_list[0]
+ "_"
+ metrics_list[1]
+ "_"
+ metrics_list[2]
+ "_"
+ device
+ "_"
+ surrogate_type
+ "_"
+ data_type
+ ".pdf"
)
# clear subplots

plt.clf()
plt.close()


if __name__ == "__main__":
corr_scatter = CorrScatter("s", 10000)
corr_scatter.plot_corr_scatter(
["perplexity", "energies", "float16_memory"], "rtx2080", "mlp", "quantile"
)
corr_scatter.plot_corr_scatter(
["perplexity", "latencies", "bfloat16_memory"], "a100", "mlp", "quantile"
)
corr_scatter.plot_corr_scatter(
["perplexity", "energies", "latencies"], "a100", "mlp", "quantile"
)
corr_scatter.plot_corr_scatter(
["perplexity", "latencies", "energies"],
"rtx2080",
"conformal_quantile",
"quantile",
)
76 changes: 76 additions & 0 deletions plotting/plot_q_q.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
from hwgpt.api import HWGPTBenchAPI
import pickle
import numpy as np
import matplotlib.pyplot as plt
from lib.utils import convert_str_to_arch
import random

plt.rcParams["axes.grid"] = True
plt.rcParams["grid.linestyle"] = "dotted"
plt.rcParams["font.size"] = 16
# plt tight layout
plt.rcParams["figure.autolayout"] = True


class HWQQPlot:
def __init__(self, search_space: str):
self.search_space = search_space
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.api = HWGPTBenchAPI(search_space=self.search_space)
stats_path = (
"data_collection/gpt_datasets/gpt_" + self.search_space + "/stats.pkl"
)
with open(stats_path, "rb") as f:
self.stats = pickle.load(f)
print("API initialized")
self.sample_arch()

def sample_arch(self):
self.sampled_arch = random.choice(list(self.stats.keys()))
self.sampled_config = convert_str_to_arch(self.sampled_arch)

def get_hw_metrics(
self, hw_metric: str, device: str, surrogate_type: str, data_type: str
):
self.api.set_arch(self.sampled_config)
return self.api.compute_predictions_hw(
hw_metric=hw_metric,
device=device,
surrogate_type=surrogate_type,
data_type=data_type,
return_all_quantiles=True,
)

def get_actual_quantiles(self, hw_metric: str, device: str, quantiles: list):
# print(list(self.stats.keys())[0])
return np.quantile(self.stats[self.sampled_arch][device][hw_metric], quantiles)

def plot_q_q(
self, hw_metric: str, device: str, surrogate_type: str, data_type: str
):
hw_quantiles = self.get_hw_metrics(hw_metric, device, surrogate_type, data_type)
actual_quantiles = self.get_actual_quantiles(
hw_metric, device, hw_quantiles.quantiles
)
hw_quantiles = hw_quantiles.results_stacked[0]
print(hw_quantiles, actual_quantiles)
plt.plot(hw_quantiles, actual_quantiles, marker="o", ls="")
x = np.linspace(
np.min((hw_quantiles.min(), actual_quantiles.min())),
np.max((hw_quantiles.max(), actual_quantiles.max())),
)
plt.plot(x, x, linestyle="--", color="black")
plt.xlabel("HW Quantiles")
plt.ylabel("Actual Quantiles")
plt.title(f"{hw_metric} Q-Q Plot")
plt.savefig(f"{hw_metric}_{device}_{surrogate_type}_{data_type}_qq_plot.pdf")
return hw_quantiles, actual_quantiles


if __name__ == "__main__":
plot = HWQQPlot("s")
hw_quantiles, actual_quantiles = plot.plot_q_q(
"energies", "a100", "conformal_quantile", "quantile"
)
print(hw_quantiles, actual_quantiles)
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
numpy
torch

0 comments on commit a9378e2

Please sign in to comment.