Run LoRA for different out of domain datasets #3

Closed
wants to merge 11 commits
12 changes: 5 additions & 7 deletions finetuning/evaluation/evaluate_amg.py
@@ -5,9 +5,10 @@

 from util import get_paths  # comment this and create a custom function with the same name to run amg on your data
 from util import get_pred_paths, get_default_arguments, VANILLA_MODELS
+from get_loaders_for_lora import RawTrafo, LabelTrafo


-def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank=None):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_amg(
@@ -16,7 +17,8 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
         experiment_folder,
         val_image_paths,
         val_gt_paths,
-        test_image_paths
+        test_image_paths,
+        lora_rank=lora_rank,
     )
     return prediction_folder

@@ -32,12 +34,8 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder):

 def main():
     args = get_default_arguments()
-    if args.checkpoint is None:
-        ckpt = VANILLA_MODELS[args.model]
-    else:
-        ckpt = args.checkpoint

-    prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder)
+    prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank)
     eval_amg(args.dataset, prediction_folder, args.experiment_folder)
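For orientation, this is how the updated AMG entry point can be driven from Python; every concrete value below is a placeholder, not something this PR pins down:

```python
# Hypothetical usage of the updated evaluate_amg.py; dataset, paths and rank are placeholders.
from evaluate_amg import run_amg_inference, eval_amg

prediction_folder = run_amg_inference(
    dataset_name="my_dataset",                    # anything your get_paths implementation resolves
    model_type="vit_b",
    checkpoint="checkpoints/lora_vit_b/best.pt",  # LoRA-finetuned checkpoint
    experiment_folder="experiments/amg_lora",
    lora_rank=4,                                  # None evaluates a model without LoRA layers
)
eval_amg("my_dataset", prediction_folder, "experiments/amg_lora")
```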
8 changes: 5 additions & 3 deletions finetuning/evaluation/evaluate_instance_segmentation.py
@@ -5,9 +5,10 @@

 from util import get_paths  # comment this and create a custom function with the same name to run ais on your data
 from util import get_pred_paths, get_default_arguments
+from get_loaders_for_lora import RawTrafo, LabelTrafo


-def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_instance_segmentation_with_decoder(
@@ -16,7 +17,8 @@ def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
         experiment_folder,
         val_image_paths,
         val_gt_paths,
-        test_image_paths
+        test_image_paths,
+        lora_rank=lora_rank,
     )
     return prediction_folder

@@ -34,7 +36,7 @@ def main():
     args = get_default_arguments()

     prediction_folder = run_instance_segmentation_with_decoder_inference(
-        args.dataset, args.model, args.checkpoint, args.experiment_folder
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank
     )
     eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder)
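All of these scripts read args.lora_rank, so util.get_default_arguments has to expose the flag too. util.py is not part of this diff, so the sketch below is an assumption about its shape; only the --lora_rank line mirrors the flag added in submit_all_evaluation.py further down:

```python
import argparse

def get_default_arguments():
    # Only --lora_rank is grounded in this PR; the other flags are inferred
    # from how the evaluation scripts use args.model, args.checkpoint, etc.
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-c", "--checkpoint", type=str, default=None)
    parser.add_argument("-e", "--experiment_folder", type=str, required=True)
    parser.add_argument("-d", "--dataset", type=str, required=True)
    parser.add_argument("--box", action="store_true")
    parser.add_argument("--use_masks", action="store_true")
    parser.add_argument("--lora_rank", type=int, default=None)  # None = no LoRA layers
    return parser.parse_args()
```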
6 changes: 6 additions & 0 deletions finetuning/evaluation/get_loaders_for_lora.py
@@ -0,0 +1,6 @@
+
+class RawTrafo():
+    ...
+
+class LabelTrafo():
+    ...
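The new module only ships these placeholders; whoever runs the scripts on their own data is expected to fill them in. Purely as an illustration (none of this is in the PR), transforms of this kind usually normalize the raw image and tidy up instance labels:

```python
import numpy as np

class RawTrafo:
    """Illustrative only: min-max normalize the raw image to [0, 1]."""
    def __call__(self, raw):
        raw = raw.astype("float32")
        lo, hi = raw.min(), raw.max()
        return (raw - lo) / (hi - lo + 1e-7)

class LabelTrafo:
    """Illustrative only: relabel instance ids consecutively, keeping 0 as background."""
    def __call__(self, labels):
        relabeled = np.zeros_like(labels)
        ids = [i for i in np.unique(labels) if i != 0]
        for new_id, old_id in enumerate(ids, start=1):
            relabeled[labels == old_id] = new_id
        return relabeled
```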
5 changes: 3 additions & 2 deletions finetuning/evaluation/iterative_prompting.py
@@ -5,7 +5,8 @@

 from util import get_paths  # comment this and create a custom function with the same name to run int. seg. on your data
 from util import get_model, get_default_arguments
-
+from micro_sam.util import get_sam_model
+from get_loaders_for_lora import RawTrafo, LabelTrafo

 def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks):
     prediction_root = os.path.join(
@@ -42,7 +43,7 @@ def main():
     start_with_box_prompt = args.box  # overwrite to start first iters' prompt with box instead of single point

     # get the predictor to perform inference
-    predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
+    predictor = get_sam_model(model_type=args.model, checkpoint_path=args.checkpoint, lora_rank=args.lora_rank)

     prediction_root = _run_iterative_prompting(
         args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks
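The same predictor construction recurs in every script this PR touches. Spelled out with placeholder values (only the keyword names come from the diff above):

```python
from micro_sam.util import get_sam_model

# Placeholder values; with lora_rank=None the model is loaded without LoRA layers.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="checkpoints/lora_vit_b/best.pt",
    lora_rank=4,
)
```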
5 changes: 3 additions & 2 deletions finetuning/evaluation/precompute_embeddings.py
@@ -4,12 +4,13 @@

 from util import get_paths  # comment this and create a custom function with the same name to execute on your data
 from util import get_model, get_default_arguments
-
+from micro_sam.util import get_sam_model
+from get_loaders_for_lora import RawTrafo, LabelTrafo

 def main():
     args = get_default_arguments()

-    predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
+    predictor = get_sam_model(model_type=args.model, checkpoint_path=args.checkpoint, lora_rank=args.lora_rank)
     embedding_dir = os.path.join(args.experiment_folder, "embeddings")
     os.makedirs(embedding_dir, exist_ok=True)
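The loop that actually writes the embeddings sits below the shown lines. Roughly, and assuming micro_sam.util.precompute_image_embeddings plus an image reader (all concrete values here are placeholders, not the file's code):

```python
import os
import imageio.v3 as imageio
from micro_sam.util import get_sam_model, precompute_image_embeddings

# Sketch: build the predictor as above, then cache one embedding file per image.
predictor = get_sam_model(model_type="vit_b", checkpoint_path="checkpoints/lora_vit_b/best.pt", lora_rank=4)
embedding_dir = "experiments/embeddings"
os.makedirs(embedding_dir, exist_ok=True)

for image_path in ["data/images/im0.tif", "data/images/im1.tif"]:
    image = imageio.imread(image_path)
    save_path = os.path.join(embedding_dir, os.path.basename(image_path) + ".zarr")
    precompute_image_embeddings(predictor, image, save_path=save_path, ndim=2)
```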
12 changes: 9 additions & 3 deletions finetuning/evaluation/submit_all_evaluation.py
@@ -14,7 +14,7 @@

 def write_batch_script(
     env_name, out_path, inference_setup, checkpoint, model_type,
-    experiment_folder, dataset_name, delay=None, use_masks=False
+    experiment_folder, dataset_name, delay=None, use_masks=False, lora_rank=None
 ):
     "Writing scripts with different fold-trainings for micro-sam evaluation"
     batch_script = f"""#!/bin/bash
@@ -23,7 +23,7 @@
 #SBATCH -t 4-00:00:00
 #SBATCH -p grete:shared
 #SBATCH -G A100:1
-#SBATCH -A gzz0001
+#SBATCH -A nim00007
 #SBATCH --constraint=80gb
 #SBATCH --qos=96h
 #SBATCH --job-name={inference_setup}
@@ -55,9 +55,13 @@
     # use logits for iterative prompting
     if inference_setup == "iterative_prompting" and use_masks:
         python_script += "--use_masks "
+
+    if lora_rank is not None:
+        python_script += f"--lora_rank {lora_rank} "
+
     # let's add the python script to the bash script
     batch_script += python_script
     print(batch_script)

     with open(_op, "w") as f:
         f.write(batch_script)
@@ -175,7 +179,8 @@ def submit_slurm(args):
             experiment_folder=experiment_folder,
             dataset_name=dataset_name,
             delay=None if current_setup == "precompute_embeddings" else make_delay,
-            use_masks=args.use_masks
+            use_masks=args.use_masks,
+            lora_rank=args.lora_rank,
         )

     # the logic below automates the process of first running the precomputation of embeddings, and only then inference.
@@ -219,6 +224,7 @@ def main(args):

     # ask for a specific experiment
     parser.add_argument("-s", "--specific_experiment", type=str, default=None)
+    parser.add_argument("--lora_rank", type=int, default=None)

     args = parser.parse_args()
     main(args)
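For a concrete picture of what write_batch_script assembles, the python_script string for an AMG run with LoRA ends up as below; every value is a placeholder and the short flags are assumed from the evaluation scripts above:

```python
# Mirrors the concatenation in write_batch_script; all concrete values are placeholders.
inference_setup = "evaluate_amg"
lora_rank = 4

python_script = f"python {inference_setup}.py "
python_script += "-m vit_b "                             # model type
python_script += "-c checkpoints/lora_vit_b/best.pt "    # checkpoint
python_script += "-e experiments/amg_lora "              # experiment folder
python_script += "-d my_dataset "                        # dataset name
if lora_rank is not None:
    python_script += f"--lora_rank {lora_rank} "         # only added when a rank is set

print(python_script)
# -> python evaluate_amg.py -m vit_b -c checkpoints/lora_vit_b/best.pt -e experiments/amg_lora -d my_dataset --lora_rank 4
```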