From d463649c34a68c779088aaef40d547340fb2e94b Mon Sep 17 00:00:00 2001 From: "Tianyi (Alex) Qiu" Date: Sat, 7 Dec 2024 10:51:56 -0800 Subject: [PATCH] fix(benchmarking): model initialization --- benchmark/framework.py | 6 ++++-- src/abstractions/model.py | 2 +- src/evaluation/utils.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmark/framework.py b/benchmark/framework.py index 6fc649d..3ce9e66 100644 --- a/benchmark/framework.py +++ b/benchmark/framework.py @@ -59,7 +59,7 @@ def reset(self, **kwargs) -> None: self.model_list = [] for i in range(22): try: - model_name = "%db-C%03d-instruct" % (self.model_size, i) + model_name = "%dB-C%03d-instruct" % (self.model_size, i) self.model_list.append( Model( model_name=model_name, @@ -314,7 +314,9 @@ def reset(self, **kwargs) -> None: ) for i in range(22): try: - kwargs["model_name"] = "%db-C%03d-instruct" % (self.model_size, i) + kwargs["model_name"] = "%dB-C%03d-instruct" % (self.model_size, i) + self.current_model = Model(kwargs["model_name"]) + break except: pass diff --git a/src/abstractions/model.py b/src/abstractions/model.py index dcd2359..29cc6d4 100644 --- a/src/abstractions/model.py +++ b/src/abstractions/model.py @@ -1067,7 +1067,7 @@ def __evaluate_fast(self, logprobs = True) -> np.ndarray: ) as f: json.dump(raw_stats, f) print("raw results saved") - vec = calculate_model(experiment_directory, self.model_name, logprob = logprob) + vec = calculate_model(experiment_directory, self.model_name, logprobs) return vec def __evaluate_slow_moralchoice(self) -> np.ndarray: diff --git a/src/evaluation/utils.py b/src/evaluation/utils.py index 60a0d30..a9ced9f 100644 --- a/src/evaluation/utils.py +++ b/src/evaluation/utils.py @@ -280,7 +280,8 @@ def _collect(output_data): output[s_id][q_type][-1] += np.exp(x) else: invalid += 1 - print(invalid) + + print(f"{invalid} out of {len(output_data)} entries are invalid") return output