
Commit

fix ruff issues + linting
sophiamaedler committed Jan 27, 2025
1 parent cde2998 commit 72f0c26
Showing 4 changed files with 19 additions and 16 deletions.
9 changes: 5 additions & 4 deletions src/scportrait/pipeline/_base.py
@@ -11,6 +11,7 @@

 from scportrait.pipeline._utils.helper import read_config

+
 class Logable:
     """Create log entries.
@@ -179,13 +180,13 @@ class ProcessingStep(Logable):
     def __init__(
         self,
         config,
-        directory = None,
-        project_location = None,
+        directory=None,
+        project_location=None,
         debug=False,
         overwrite=False,
         project=None,
         filehandler=None,
-        from_project:bool = False,
+        from_project: bool = False,
     ):
         super().__init__(directory=directory)

@@ -205,7 +206,7 @@ def __init__(
         if isinstance(config, str):
             config = read_config(config)
             if self.__class__.__name__ in config.keys():
-                self.config = config[self.__class__.__name__ ]
+                self.config = config[self.__class__.__name__]
             else:
                 self.config = config
         else:
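For readers following the config handling touched in this hunk: `ProcessingStep` loads a YAML config and keeps only the sub-section named after the concrete class, falling back to the whole dict otherwise. A minimal sketch of that lookup pattern; the config contents below are hypothetical, and the sketch parses a YAML string directly rather than going through `read_config`:

import yaml

# hypothetical config text; the top-level keys mirror class names, as in the diff above
config_text = """
MLClusterClassifier:
  batch_size: 32
CellFeaturizer:
  batch_size: 64
"""


class MLClusterClassifier:
    def __init__(self, config: dict | str):
        if isinstance(config, str):
            config = yaml.safe_load(config)
        # keep only the section named after this class, if present
        if self.__class__.__name__ in config.keys():
            self.config = config[self.__class__.__name__]
        else:
            self.config = config


step = MLClusterClassifier(config_text)
print(step.config)  # {'batch_size': 32}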
3 changes: 3 additions & 0 deletions src/scportrait/pipeline/_utils/helper.py
@@ -1,8 +1,10 @@
 from typing import TypeVar
+
 import yaml

 T = TypeVar("T")

+
 def read_config(config_path: str) -> dict:
     with open(config_path) as stream:
         try:
@@ -11,6 +13,7 @@ def read_config(config_path: str) -> dict:
             print(exc)
     return config

+
 def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
     """Flatten a list of lists into a single list.
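The hunks above show both helpers only in part: the YAML parsing inside `read_config` and the body of `flatten` fall outside the displayed lines. A rough sketch of what the two helpers plausibly do, assuming `read_config` uses `yaml.safe_load`; the empty-dict fallback and the `flatten` body are additions for the sketch, not taken from the source:

from typing import TypeVar

import yaml

T = TypeVar("T")


def read_config(config_path: str) -> dict:
    with open(config_path) as stream:
        try:
            config = yaml.safe_load(stream)  # assumption: the elided lines parse YAML
        except yaml.YAMLError as exc:
            print(exc)
            config = {}  # fallback added for the sketch only
    return config


def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
    """Flatten a list of lists into a single list."""
    return [item for sublist in nested_list for item in sublist]


print(flatten([[1, 2], [3, 4], [5]]))  # [1, 2, 3, 4, 5]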
21 changes: 10 additions & 11 deletions src/scportrait/pipeline/featurization.py
@@ -5,8 +5,8 @@
 from contextlib import redirect_stdout
 from functools import partial as func_partial

-import numpy as np
 import h5py
+import numpy as np
 import pandas as pd
 import pytorch_lightning as pl
 import torch
@@ -18,6 +18,7 @@
 from scportrait.tools.ml.datasets import HDF5SingleCellDataset
 from scportrait.tools.ml.plmodels import MultilabelSupervisedModel

+
 class _FeaturizationBase(ProcessingStep):
     PRETRAINED_MODEL_NAMES = [
         "autophagy_classifier",
@@ -175,9 +176,7 @@ def _get_nmasks(self):
         try:
             self.n_masks = h5py.File(self.extraction_file, "r")["n_masks"][()].item()
         except Exception as e:
-            raise ValueError(
-                f"Could not extract number of masks from HDF5 file. Error: {e}"
-            ) from e
+            raise ValueError(f"Could not extract number of masks from HDF5 file. Error: {e}") from e

     def _general_setup(self, extraction_dir: str, return_results: bool = False):
         """Helper function to execute all setup functions that are common to all featurization steps."""
@@ -892,7 +891,7 @@ class based on the previous single-cell extraction. Therefore, only the second a
         self.log("Started MLClusterClassifier classification.")

         # perform setup
-        self._setup(extraction_dir = extraction_dir, return_results=return_results)
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)

         self.dataloader = self.generate_dataloader(
             extraction_dir,
@@ -975,8 +974,8 @@ def _load_models(self):
         memory_usage = self._get_gpu_memory_usage()
         self.log(f"GPU memory usage after loading models: {memory_usage}")

-    def _setup(self, extraction_dir: str):
-        self._general_set(extraction_dir=extraction_dir)
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._get_model_specs()
         self._setup_transforms()

@@ -988,7 +987,7 @@ def _setup(self, extraction_dir: str):

         self._load_models()

-    def process(self, extraction_dir:str, size:int = 0, return_results:bool = False):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Function called to perform classification on the provided HDF5 dataset.
Expand Down Expand Up @@ -1335,12 +1334,12 @@ def __init__(self, *args, **kwargs):

self.channel_selection = None # ensure that all images are passed to the function

def _setup(self, extraction_dir:str, return_results:bool):
def _setup(self, extraction_dir: str, return_results: bool):
self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
self._setup_transforms()
self._get_channel_specs()

-    def process(self, extraction_dir: str, size: int =0, return_results: bool = False):
+    def process(self, extraction_dir: str, size: int = 0, return_results: bool = False):
         """
         Perform featurization on the provided HDF5 dataset.
@@ -1453,7 +1452,7 @@ def _setup_channel_selection(self):
             self.channel_selection = [0, self.channel_selection]
         return

-    def _setup(self, extraction_dir:str, return_results:bool):
+    def _setup(self, extraction_dir: str, return_results: bool):
         self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_channel_selection()
         self._setup_transforms()
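Beyond whitespace, the `_setup` hunk at `@@ -975,8 +974,8 @@` also fixes a real bug: the old code called the non-existent `self._general_set(...)` and dropped `return_results`. A toy illustration of why the old call would fail at runtime and how the corrected signature threads the flag through to the shared setup; the subclass name and attribute assignments below are stand-ins, not the actual scPortrait implementations:

class _FeaturizationBase:
    def _general_setup(self, extraction_dir: str, return_results: bool = False):
        # shared setup used by every featurization step (stand-in body)
        self.extraction_dir = extraction_dir
        self.return_results = return_results


class SomeFeaturizer(_FeaturizationBase):
    def _setup(self, extraction_dir: str, return_results: bool):
        # before the fix this read `self._general_set(extraction_dir=...)`,
        # which raises AttributeError and silently drops return_results
        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)


step = SomeFeaturizer()
step._setup("path/to/extraction", return_results=True)
print(step.return_results)  # True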
2 changes: 1 addition & 1 deletion src/scportrait/pipeline/project.py
@@ -307,7 +307,7 @@ def _setup_featurization_f(self, featurization_f):
             self.featurization_directory,
             project_location=self.project_location,
             debug=self.debug,
-            overwrite=False, #this needs to be set to false as the featurization step should not remove previously created features
+            overwrite=False,  # this needs to be set to false as the featurization step should not remove previously created features
             project=self,
             filehandler=self.filehandler,
             from_project=True,
