diff --git a/predict.py b/predict.py index ead1ae0..9136879 100644 --- a/predict.py +++ b/predict.py @@ -54,9 +54,12 @@ def predict( if target_age == "default": target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] - age_transformers = [AgeTransformer(target_age=age) for age in target_ages] + else: - age_transformers = [AgeTransformer(target_age=target_age)] + target_ages = target_age.split(',') + target_ages = [int(age) for age in target_ages] + + age_transformers = [AgeTransformer(target_age=age) for age in target_ages] results = np.array(aligned_image.resize((1024, 1024))) all_imgs = [] @@ -70,7 +73,7 @@ def predict( all_imgs.append(result_image) results = np.concatenate([results, result_image], axis=1) - if target_age == "default": + if target_age == "default" or len(age_transformers) > 1: out_path = Path(tempfile.mkdtemp()) / "output.gif" imageio.mimwrite(str(out_path), all_imgs, duration=0.3) else: