Skip to content

Commit

Permalink
ensemble pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
BaptisteUrgell committed Jun 24, 2024
1 parent d178b0a commit 936c852
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 9 deletions.
3 changes: 2 additions & 1 deletion configs/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ data:
mean: [ 0.4175, 0.3497, 0.2607 ]
std: [ 0.2044, 0.1909, 0.1867 ]
class_labels: { 0: 'Background', 1: 'Inertinite', 2: 'Vitrinite', 3: 'Liptinite' }
label_weights: { 'Background': 0.035, 'Inertinite': 1.373, 'Vitrinite': 0.376, 'Liptinite': 2.216 }
label_weights: { 'Background': 0.035, 'Inertinite': 1.373, 'Vitrinite': 0.376, 'Liptinite': 2.216 }
num_labels: 4
2 changes: 1 addition & 1 deletion src/models/semi_supervised/mask2former/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def load_model(config: Config, map_location=None):
lightning = Mask2FormerLightning(config)
else:
path_checkpoint = os.path.join(config.path.models, config.checkpoint)
lightning = Mask2FormerLightning.load_from_checkpoint(path_checkpoint, config=config, map_location=map_location)
lightning = Mask2FormerLightning.load_from_checkpoint(path_checkpoint, config=config, map_location=map_location, strict=False)

return lightning

Expand Down
84 changes: 84 additions & 0 deletions src/submissions/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import numpy as np
from src.utils.cls import Config
from tqdm.autonotebook import tqdm
from glob import glob
from src.utils import func

def main():
    """Build an ensemble submission from a fixed set of prediction folders.

    Loads the 'main' configuration, pairs every prediction folder with its
    voting weight, and delegates the actual merging to ``Ensembler.build``.
    """
    config = func.load_config('main', loading='object')

    # (prediction folder, voting weight) pairs; order is preserved when the
    # two lists are handed to the Ensembler.
    # NOTE(review): the origin of the weight values is not visible here --
    # confirm what they represent before changing them.
    weighted_runs = (
        ('submissions/fluent-water-971-mlmyc2ql/fluent-water-971-mlmyc2ql-ckpt-micro-tiling-384-tta-max', 0.5471),
        ('submissions/atomic-sweep-62-iecskwiv/atomic-sweep-62-iecskwiv-ckpt-spv-v1-tiling-384-tta-max', 0.5238),
        ('submissions/avid-sweep-31-kyxcee1p/avid-sweep-31-kyxcee1p-ckpt-spv-v1-tiling-384-tta-max', 0.5032),
    )
    folders = [folder for folder, _ in weighted_runs]
    # Passing None instead of this list makes the Ensembler weight every
    # folder equally.
    vote_weights = [weight for _, weight in weighted_runs]

    Ensembler(config, folders, vote_weights).build()

class Ensembler:
    """Merge class-index masks from several submissions by weighted voting.

    Every mask is one-hot encoded and scaled by its weight; the scaled masks
    are summed and the class with the highest total is kept for each element.
    """

    def __init__(self, config: "Config", submission_folders: list[str], weights: list[float] = None):
        """Store the configuration, the prediction folders and their weights.

        Args:
            config: Project configuration. This class reads
                ``config.data.num_labels``,
                ``config.path.data.raw.test.unlabeled`` and
                ``config.path.submissions``.
            submission_folders: Folders that each contain one
                ``<image name>_pred.npy`` file per test image.
            weights: One voting weight per folder. ``None`` gives every
                folder a weight of 1.0.

        Raises:
            ValueError: If ``weights`` is given but its length differs from
                the number of folders.
        """
        self.config = config
        self.submission_folders = submission_folders
        if weights is None:
            weights = [1.] * len(self.submission_folders)
        elif len(weights) != len(self.submission_folders):
            # zip() would silently drop the surplus folders/weights later on,
            # quietly removing a model from the ensemble.
            raise ValueError(
                f'Expected {len(self.submission_folders)} weights '
                f'(one per submission folder), got {len(weights)}'
            )
        self.weights = weights

    def __call__(self, pred_masks: list[np.ndarray], weights: list[float] = None) -> np.ndarray:
        """Ensemble in-memory masks.

        Note that the weights stored on the instance are not used here: when
        ``weights`` is ``None`` every mask gets a weight of 1.0.

        Args:
            pred_masks: Class-index masks to merge (must be non-empty).
            weights: One voting weight per mask, or ``None`` for equal votes.

        Returns:
            The merged class-index mask.

        Raises:
            ValueError: If ``weights`` is given but its length differs from
                the number of masks.
        """
        if weights is None:
            weights = [1.] * len(pred_masks)
        elif len(weights) != len(pred_masks):
            # Same reasoning as in __init__: never let zip() truncate.
            raise ValueError(
                f'Expected {len(pred_masks)} weights (one per mask), got {len(weights)}'
            )

        processed_masks = [self.process_mask(mask, weight) for mask, weight in zip(pred_masks, weights)]
        ensemble_mask = self.merge_masks(processed_masks)

        return ensemble_mask

    def build(self):
        """Write one ensembled ``*_pred.npy`` file per unlabeled test image.

        For every ``*.JPG`` in the unlabeled test directory, the matching
        ``<name>_pred.npy`` is loaded from each submission folder, merged,
        and saved under ``config.path.submissions``.
        """
        pathname = os.path.join(self.config.path.data.raw.test.unlabeled, '*.JPG')
        # The output folder name joins the 4th dash-separated token of every
        # input folder's basename (presumably the run id -- verify against
        # the folder naming scheme).
        submission_name = '-'.join([
            os.path.basename(submission_folder).split('-')[3]
            for submission_folder in self.submission_folders
        ])
        submission_folder = os.path.join(self.config.path.submissions, submission_name)
        os.makedirs(submission_folder, exist_ok=True)

        for image_path in tqdm(glob(pathname)):
            mask_pred_name = os.path.basename(image_path).replace('.JPG', '_pred.npy')
            submission_path = os.path.join(submission_folder, mask_pred_name)
            pred_masks = []
            for prediction_folder, weight in zip(self.submission_folders, self.weights):
                mask_path = os.path.join(prediction_folder, mask_pred_name)
                mask = np.load(mask_path)
                processed_mask = self.process_mask(mask, weight)
                pred_masks.append(processed_mask)

            ensemble_mask = self.merge_masks(pred_masks)
            np.save(submission_path, ensemble_mask)

    def process_mask(self, mask: np.ndarray, weight: float) -> np.ndarray:
        """One-hot encode ``mask`` and scale it by ``weight``.

        Args:
            mask: Array of class indices; assumes integer values in
                ``[0, config.data.num_labels)`` -- TODO confirm against the
                code that produces the saved predictions.
            weight: Voting weight applied to every one-hot entry.

        Returns:
            Array of shape ``mask.shape + (num_labels,)`` holding ``weight``
            at the predicted class and 0 elsewhere.
        """
        # Indexing the identity matrix with the mask selects one row (a
        # one-hot vector) per element.
        classes_mask = np.eye(self.config.data.num_labels)[mask]
        prob_mask = classes_mask * weight

        return prob_mask

    def merge_masks(self, pred_masks: list[np.ndarray]) -> np.ndarray:
        """Sum the weighted one-hot masks and pick the top class per element.

        Args:
            pred_masks: Non-empty list of equally shaped arrays whose last
                axis indexes the classes.

        Returns:
            Array of class indices (``argmax`` over the last axis of the sum).
        """
        # Accumulate into a fresh array so the inputs are left untouched.
        merged_mask = np.zeros_like(pred_masks[0])
        for mask in pred_masks:
            merged_mask += mask

        merged_mask = np.argmax(merged_mask, axis=-1)

        return merged_mask


# Run the ensembling pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
13 changes: 6 additions & 7 deletions src/submissions/make_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,30 @@
from tqdm import tqdm

from src.submissions.model import SS2InferenceModel
from src.submissions.tta import TestTimeAugmenter
from src.utils import func
from src.utils.cls import Config
from src.utils.cls import Config, TrainingMode


def main():
base_config = func.load_config('main')
wandb_run = func.get_run('mlmyc2ql') # jaom0qef
wandb_run = func.get_run('kyxcee1p') # jaom0qef
submission_name = f'{wandb_run.name}-{wandb_run.id}'
device = 'cuda:1'
device = 'cuda:0'
tile_sizes = [
wandb_run.config['tile_size'],
# 512
]
checkpoint_types = [
'micro',
'macro',
'spv-v1',
]
tta_ks = [
1,
# 1,
'max',
]

for checkpoint_type, tile_size, tta_k in product(checkpoint_types, tile_sizes, tta_ks):
wandb_run.config['checkpoint'] = f'{wandb_run.name}-{wandb_run.id}-{checkpoint_type}.ckpt'
wandb_run.config['mode'] = TrainingMode.SEMI_SUPERVISED
config = Config(base_config, wandb_run.config)
pathname = os.path.join(config.path.data.raw.test.unlabeled, '*.JPG')

Expand Down

0 comments on commit 936c852

Please sign in to comment.