Skip to content

Commit

Permalink
Merge pull request #423 from VondrakMar/preprocess_parse
Browse files Browse the repository at this point in the history
removed wront argument parsing in preprocess_data
  • Loading branch information
ilyes319 authored May 15, 2024
2 parents 6d7b5ed + 3cc7c3d commit 81f4f8c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def mace_mp(
mace_calc = MACECalculator(
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
)
d3_calc = None
if dispersion:
gh_url = "https://github.com/pfnet-research/torch-dftd"
try:
Expand Down
5 changes: 2 additions & 3 deletions mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,15 @@ def main() -> None:
This script loads an xyz dataset and prepares
new hdf5 file that is ready for training with on-the-fly dataloading
"""
args = tools.build_default_arg_parser().parse_args()
args = tools.build_preprocess_arg_parser().parse_args()
run(args)


def run(args: argparse.Namespace) -> None:
def run(args: argparse.Namespace):
"""
This script loads an xyz dataset and prepares
new hdf5 file that is ready for training with on-the-fly dataloading
"""
args = tools.build_preprocess_arg_parser().parse_args()

# Setup
tools.set_seeds(args.seed)
Expand Down
5 changes: 3 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def run(args: argparse.Namespace) -> None:
)
model_config_foundation["atomic_energies"] = atomic_energies
args.model = "FoundationMACE"
model_config = model_config_foundation # pylint
else:
logging.info("Building model")
if args.num_channels is not None and args.max_L is not None:
Expand Down Expand Up @@ -584,8 +585,8 @@ def run(args: argparse.Namespace) -> None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
logging.info(f"Setting start swa to {args.start_swa}")
if args.loss == "forces_only":
logging.info("Can not select swa with forces only loss.")
elif args.loss == "virials":
raise ValueError("Can not select swa with forces only loss.")
if args.loss == "virials":
loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
Expand Down

0 comments on commit 81f4f8c

Please sign in to comment.