diff --git a/main.py b/main.py index a884c0b..810140b 100644 --- a/main.py +++ b/main.py @@ -29,7 +29,7 @@ from sklearn.metrics import cohen_kappa_score from utils import * -from utils.misc import progress_bar +from utils.misc import progress_bar, save_current_code import models @@ -46,6 +46,7 @@ def main(config: DictConfig): os.makedirs(save_config_path, exist_ok=True) with open(os.path.join(save_config_path, "README.md"), 'w+') as f: f.write(OmegaConf.to_yaml(config, resolve=True)) + save_current_code(save_config_path) solver = Solver(config) return solver.run() diff --git a/utils/misc.py b/utils/misc.py index 054cef6..cca9f8f 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -1,6 +1,8 @@ import sys import time import os +import subprocess +import zipfile TOTAL_BAR_LENGTH = 80 LAST_T = time.time() @@ -90,4 +92,17 @@ def format_time(seconds): i += 1 if f == '': f = '0ms' - return f \ No newline at end of file + return f + +def save_current_code(path: str): + print(f"Saving current code to {path}") + files_in_repo = subprocess.run(['git', 'ls-tree', '--full-tree', '-r', '--name-only', 'HEAD'], + stdout=subprocess.PIPE).stdout.decode("utf-8").split("\n") + root = subprocess.run(['git', 'rev-parse', '--show-toplevel'], stdout=subprocess.PIPE).stdout.decode( + "utf-8").rstrip('\n') + with zipfile.ZipFile(os.path.join(path, "files.zip"), "w", zipfile.ZIP_DEFLATED) as z: + for file in files_in_repo: + file_path = os.path.join(root, file) + if os.path.isfile(file_path): + print(file) + z.write(file_path, file)