diff --git a/examples/run_svi.py b/examples/run_svi.py
index a3a905e..a55af8f 100644
--- a/examples/run_svi.py
+++ b/examples/run_svi.py
@@ -10,25 +10,40 @@
+import time  # for the timing prints added below
+
 parser = argparse.ArgumentParser(description="Train DeepSequence with SVI.")
 parser.add_argument("--dataset", type=str, default="BLAT_ECOLX",
                     help="Dataset name for fitting model.")
+parser.add_argument("--neff-override", type=float, default=None,
+                    help="Override the model Neff.")
+parser.add_argument("--theta-override", type=float, default=None,
+                    help="Override the model theta.")
+parser.add_argument("--n-latent-override", type=int, default=None,
+                    help="Override the model n_latent.")
+parser.add_argument("--weights_dir", type=str, default="",
+                    help="Location of precomputed weights, if available.")
+parser.add_argument("--alignments_dir", type=str,
+                    help="Location of alignments.")
+parser.add_argument("--seed", type=int,
+                    help="Random seed override (for model ensembling).")
 args = parser.parse_args()
 
+# Accept "NAME.a2m" as well as "NAME".
+args.dataset = args.dataset.split(".a2m")[0]
+
 data_params = {
-    "dataset" : args.dataset,
+    "dataset"     : args.dataset,
+    "weights_dir" : args.weights_dir,
 }
 
+if args.seed is not None:
+    print("Using seed: {}".format(args.seed))
+
 model_params = {
     "bs"                 : 100,
     "encode_dim_zero"    : 1500,
     "encode_dim_one"     : 1500,
     "decode_dim_zero"    : 100,
-    "decode_dim_one"     : 500,
+    "decode_dim_one"     : 2000,  # 500 in the repo
     "n_latent"           : 30,
     "logit_p"            : 0.001,
     "sparsity"           : "logit",
     "final_decode_nonlin": "sigmoid",
     "final_pwm_scale"    : True,
     "n_pat"              : 4,
-    "r_seed"             : 12345,
+    "r_seed"             : args.seed if args.seed is not None else 12345,
     "conv_pat"           : True,
     "d_c_size"           : 40
 }
@@ -37,13 +52,24 @@
     "num_updates"     : 300000,
     "save_progress"   : True,
     "verbose"         : True,
-    "save_parameters" : False,
+    "save_parameters" : 50000,
 }
 
-if __name__ == "__main__":
+if args.n_latent_override:
+    model_params['n_latent'] = args.n_latent_override
+
+if __name__ == "__main__":
+    start_time = time.time()
     data_helper = helper.DataHelper(dataset=data_params["dataset"],
-                                    calc_weights=True)
+                                    working_dir='.',
+                                    calc_weights=False,  # use precomputed weights
+                                    theta=args.theta_override,
+                                    weights_dir=data_params["weights_dir"],
+                                    alignments_dir=args.alignments_dir,
+                                    )
+    print("Data loaded.")
+
+    if args.neff_override:
+        data_helper.Neff = args.neff_override
 
     vae_model = model.VariationalAutoencoder(data_helper,
                                              batch_size = model_params["bs"],
@@ -63,10 +89,17 @@
                                              n_patterns = model_params["n_pat"],
                                              random_seed = model_params["r_seed"],
                                              )
+    print("Model loaded")
 
     job_string = helper.gen_job_string(data_params, model_params)
-    print (job_string)
+    if args.neff_override:
+        job_string += "_neff-" + str(args.neff_override)
+    if args.seed is not None:
+        job_string += "_seed-" + str(args.seed)
+
+    print("job string: " + job_string)
 
+    print("Starting training")
     train.train(data_helper, vae_model,
                 num_updates = train_params["num_updates"],
@@ -74,5 +107,8 @@
                 save_parameters = train_params["save_parameters"],
                 verbose = train_params["verbose"],
                 job_string = job_string)
+    print("Training complete")
 
     vae_model.save_parameters(file_prefix=job_string)
+
+    print("Done in " + str(time.time() - start_time) + " seconds")
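
A minimal sketch of how the patched script might be invoked; the directory
paths and the theta value below are illustrative assumptions, not part of the
patch:

    python examples/run_svi.py \
        --dataset BLAT_ECOLX.a2m \
        --alignments_dir ./alignments \
        --weights_dir ./weights \
        --theta-override 0.2 \
        --seed 1

The ".a2m" suffix is stripped by the script, so --dataset BLAT_ECOLX works
equally well. Passing --seed both overrides the model's random seed and is
appended to the job string, so saved parameters from different ensemble
members get distinct file prefixes.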