Skip to content

Commit

Permalink
Fix computation of when learning should be retried with lower learning rate. (#115)
Browse files Browse the repository at this point in the history

* Fix elbo_fail_fraction testing.
* Force newer pytorch version.
  • Loading branch information
alecw authored Nov 29, 2022
1 parent 7a66919 commit d82893c
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion REQUIREMENTS-DOCKER.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ scipy
tables
pandas
pyro-ppl>=0.3.2
torch
torch>=1.9.0
scikit-learn
matplotlib
4 changes: 2 additions & 2 deletions cellbender/remove_background/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ def add_subparser_args(subparsers: argparse) -> argparse:
"do not exceed 1e-3). (default: %(default)s)")
subparser.add_argument("--final-elbo-fail-fraction", type=float,
help="Training is considered to have failed if "
"final_training_ELBO >= best_training_ELBO*(1+FINAL_ELBO_FAIL_FRACTION). "
"(best_test_ELBO - final_test_ELBO)/(best_test_ELBO - initial_train_ELBO) > FINAL_ELBO_FAIL_FRACTION. "
"(default: do not fail training based on final_training_ELBO)")
subparser.add_argument("--epoch-elbo-fail-fraction", type=float,
help="Training is considered to have failed if "
"current_epoch_training_ELBO >= previous_epoch_training_ELBO*(1+EPOCH_ELBO_FAIL_FRACTION). "
"(previous_epoch_test_ELBO - current_epoch_test_ELBO)/(previous_epoch_test_ELBO - initial_train_ELBO) > EPOCH_ELBO_FAIL_FRACTION. "
"(default: do not fail training based on epoch_training_ELBO)")
subparser.add_argument("--num-training-tries", type=int, default=1,
help="Number of times to attempt to train the model. Each subsequent "
Expand Down
15 changes: 9 additions & 6 deletions cellbender/remove_background/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,25 @@ def run_training(model: RemoveBackgroundPyroModel,
logging.info("[epoch %03d] average test loss: %.4f"
% (epoch, total_epoch_loss_test))
if epoch_elbo_fail_fraction is not None and len(test_elbo) > 1 and \
-test_elbo[-1] >= -test_elbo[-2] * (1 + epoch_elbo_fail_fraction):
test_elbo[-1] < test_elbo[-2] and \
(test_elbo[-2] - test_elbo[-1])/(test_elbo[-2] - train_elbo[0]) > epoch_elbo_fail_fraction:
logging.info(
"Training failed because this test loss (%.4f) exceeds previous test loss(%.4f) by >= %.2f%%" %
(test_elbo[-1], test_elbo[-2], 100*epoch_elbo_fail_fraction))
"Training failed because this test loss (%.4f) exceeds previous test loss(%.4f) by >= %.2f%%, "
"relative to initial train loss %.4f" ,
test_elbo[-1], test_elbo[-2], 100*epoch_elbo_fail_fraction, train_elbo[0])
succeeded = False
break

logging.info("Inference procedure complete.")

if succeeded and final_elbo_fail_fraction is not None and len(test_elbo) > 1:
best_test_elbo = max(test_elbo)
if -test_elbo[-1] >= -best_test_elbo * (1 + final_elbo_fail_fraction):
if test_elbo[-1] < best_test_elbo and \
(best_test_elbo - test_elbo[-1])/(best_test_elbo - train_elbo[0]) > final_elbo_fail_fraction:
logging.info(
"Training failed because final test loss (%.4f) exceeds "
"best test loss(%.4f) by >= %.2f%%" %
(test_elbo[-1], best_test_elbo, 100*final_elbo_fail_fraction))
"best test loss(%.4f) by >= %.2f%%, relative to initial train loss %.4f",
test_elbo[-1], best_test_elbo, 100*final_elbo_fail_fraction, train_elbo[0])
succeeded = False

# Exception allows program to continue after ending inference prematurely.
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ENV PATH=/home/user/miniconda/bin:$PATH
ENV CONDA_AUTO_UPDATE_CONDA=false

RUN conda install -y pytorch torchvision cudatoolkit -c pytorch \
RUN conda install -y "pytorch>=1.9.0" torchvision cudatoolkit -c pytorch \
&& conda install -y -c anaconda pytables \
&& conda clean -ya

Expand Down

0 comments on commit d82893c

Please sign in to comment.