Skip to content

Commit

Permalink
Fix bug with seed values in basic template and some minor other impro…
Browse files Browse the repository at this point in the history
…vements (#17)
  • Loading branch information
jteijema authored Apr 17, 2023
1 parent a85f732 commit 9b8a307
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 20 deletions.
4 changes: 2 additions & 2 deletions asreviewcontrib/makita/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def is_valid_template(fp):

def _valid_job_file(param):
ext = Path(param).suffix
if ext.lower() not in ('.sh', '.bat'):
raise argparse.ArgumentTypeError('File must have a .sh or .bat extension')
if ext.lower() not in ('.sh', '.bat', '.yaml'):
raise argparse.ArgumentTypeError('File must have a .sh, .bat, .yaml extension')
return param


Expand Down
40 changes: 23 additions & 17 deletions asreviewcontrib/makita/template_arfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def render_jobs_arfi(
check_filename_dataset(fp_dataset)

# render priors
priors = get_priors(fp_dataset, init_seed=init_seed + i, n_priors=n_priors)
priors = get_priors(fp_dataset,
init_seed=init_seed + i,
n_priors=n_priors
)

# params for single dataset
params.append(
Expand All @@ -70,22 +73,25 @@ def render_jobs_arfi(
# open template TODO@{Replace by more sustainable module}
template = ConfigTemplate(fp_template)

for s in template.scripts:
t_script = get_file(s, "script")
export_fp = Path(scripts_folder, s)
add_file(t_script, export_fp)

for s in template.docs:
t_docs = get_file(s,
"doc",
datasets=datasets,
template_name=template.name if template.name == "ARFI" else "custom", # NOQA
template_name_long=template.name_long,
template_scripts=template.scripts,
output_folder=output_folder,
job_file=job_file,
)
add_file(t_docs, s)
# check if template.script is not NoneType
if template.scripts is not None:
for s in template.scripts:
t_script = get_file(s, "script")
export_fp = Path(scripts_folder, s)
add_file(t_script, export_fp)

if template.docs is not None:
for s in template.docs:
t_docs = get_file(s,
"doc",
datasets=datasets,
template_name=template.name if template.name == "ARFI" else "custom", # NOQA
template_name_long=template.name_long,
template_scripts=template.scripts,
output_folder=output_folder,
job_file=job_file,
)
add_file(t_docs, s)

return template.render(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ asreview wordcloud {{ dataset.input_file }} -o {{ output_folder }}/simulation/{{
# Simulate runs
mkdir {{ output_folder }}/simulation/{{ dataset.input_file_stem }}/state_files
{% for run in range(dataset.n_runs) %}
asreview simulate {{ dataset.input_file }} -s {{ output_folder }}/simulation/{{ dataset.input_file_stem }}/state_files/sim_{{ dataset.input_file_stem }}_{{ run }}.asreview --init_seed {{ dataset.init_seed }} --seed {{ dataset.model_seed }}
asreview simulate {{ dataset.input_file }} -s {{ output_folder }}/simulation/{{ dataset.input_file_stem }}/state_files/sim_{{ dataset.input_file_stem }}_{{ run }}.asreview --init_seed {{ dataset.init_seed + run}} --seed {{ dataset.model_seed + run}}
asreview metrics {{ output_folder }}/simulation/{{ dataset.input_file_stem }}/state_files/sim_{{ dataset.input_file_stem }}_{{ run }}.asreview -o {{ output_folder }}/simulation/{{ dataset.input_file_stem }}/metrics/metrics_sim_{{ dataset.input_file_stem }}_{{ run }}.json
{% endfor %}

Expand Down

0 comments on commit 9b8a307

Please sign in to comment.