From c6411e351be66c51a1f0c8806404fdba8aa80fd6 Mon Sep 17 00:00:00 2001 From: jfnavarro Date: Wed, 28 Apr 2021 13:31:58 +0200 Subject: [PATCH 1/2] Added top_criteria option to select top genes. Fixed a bug where the top n was being applied to the ST data. Fixed formatting issues and typos --- setup.py | 2 +- stsc/datasets.py | 255 ++++++++++++++------------- stsc/fit.py | 211 +++++++++++----------- stsc/models.py | 148 ++++++++-------- stsc/parser.py | 446 +++++++++++++++++++++-------------------------- stsc/progress.py | 101 ++++++----- stsc/run.py | 156 ++++++++--------- stsc/utils.py | 281 +++++++++++++---------------- 8 files changed, 763 insertions(+), 837 deletions(-) diff --git a/setup.py b/setup.py index 436e88e..f7cfc92 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ download_url='https://github.com/almaan/stereoscope/archive/v_03.tar.gz', license='MIT', packages=['stsc'], - python_requires='>3.5.0', + python_requires='>3.5.6', install_requires=[ 'torch>=1.1.0', 'numba>=0.49.0', diff --git a/stsc/datasets.py b/stsc/datasets.py index f9fec30..0dfa00c 100755 --- a/stsc/datasets.py +++ b/stsc/datasets.py @@ -2,31 +2,29 @@ import re import sys -from typing import List,Dict - +from typing import List, Dict import numpy as np import pandas as pd - - import torch as t from torch.utils.data import Dataset - import stsc.utils as utils + class CountDataHelper(object): """ Helper class for CountData class """ @classmethod - def update(self,func): - def wrapper(self,*args,**kwargs): - tmp = func(self,*args,**kwargs) + def update(self, func): + def wrapper(self, *args, **kwargs): + tmp = func(self, *args, **kwargs) self.G = int(self.cnt.shape[1]) self.M = int(self.cnt.shape[0]) self.Z = np.unique(self.lbl).shape[0] - self.libsize = self.cnt.sum(dim = 1) + self.libsize = self.cnt.sum(dim=1) return tmp + return wrapper @@ -48,10 +46,9 @@ class CountData(Dataset): @CountDataHelper.update def __init__(self, - cnt : pd.DataFrame, - lbl : pd.DataFrame = None, - )-> None: - + cnt: pd.DataFrame, + lbl: pd.DataFrame = None + ) -> None: self.cnt = cnt self.lbl = np.ones(self.cnt.shape[0]) * np.nan self.zidx = np.ones(self.cnt.shape[0]) * np.nan @@ -64,30 +61,30 @@ def __init__(self, self.lbl = lbl self.index = self.cnt.index.intersection(self.lbl.index) - self.cnt = self.cnt.loc[self.index,:] - self.lbl = self.lbl.loc[self.index].values.reshape(-1,) + self.cnt = self.cnt.loc[self.index, :] + self.lbl = self.lbl.loc[self.index].values.reshape(-1, ) # convert labels to numeric indices - tonumeric = { v:k for k,v in enumerate(np.unique(self.lbl)) } + tonumeric = {v: k for k, v in enumerate(np.unique(self.lbl))} self.zidx = np.array([tonumeric[l] for l in self.lbl]) # Sort data according to label enumeration - # to speed up element acession + # to speed up element accession srt = np.argsort(self.zidx) self.zidx = self.zidx[srt] self.lbl = self.lbl[srt] - self.cnt = self.cnt.iloc[srt,:] + self.cnt = self.cnt.iloc[srt, :] self.zidx = t.LongTensor(self.zidx.flatten().astype(np.int32)) # Convert to tensor self.cnt = t.tensor(self.cnt.values.astype(np.float32)) - self.libsize = self.cnt.sum(dim = 1) + self.libsize = self.cnt.sum(dim=1) @CountDataHelper.update def filter_genes(self, - pattern : str = None, - )-> None: + pattern: str = None + ) -> None: """ Filter genes based on regex-pattern Parameter: @@ -99,21 +96,16 @@ def filter_genes(self, and MALAT1 """ - if pattern is None: - pattern = '^RP|MALAT1' - - keep = [ re.search(pattern,x.upper()) is \ - None for x in self.genes] - - self.cnt = 
self.cnt[:,keep] + pattern = '^RP|MALAT1' if pattern is None else pattern + keep = [re.search(pattern, x.upper()) is None for x in self.genes] + self.cnt = self.cnt[:, keep] self.genes = self.genes[keep] @CountDataHelper.update def filter_bad(self, - min_counts : int = 0, - min_occurance : int = 0, - )-> None: - + min_counts: int = 0, + min_occurance: int = 0 + ) -> None: """Filter bad data points Parameter: @@ -130,18 +122,18 @@ def filter_bad(self, """ - row_thrs, col_thrs = min_counts,min_occurance - ridx = np.where(self.cnt.sum(dim = 1) > row_thrs)[0] - cidx = np.where((self.cnt != 0).type(t.float32).sum(dim = 0) > col_thrs)[0] + row_thrs, col_thrs = min_counts, min_occurance + ridx = np.where(self.cnt.sum(dim=1) > row_thrs)[0] + cidx = np.where((self.cnt != 0).type(t.float32).sum(dim=0) > col_thrs)[0] - self.cnt = self.cnt[ridx,:][:,cidx] + self.cnt = self.cnt[ridx, :][:, cidx] self.lbl = self.lbl[ridx] self.zidx = self.zidx[ridx].type(t.LongTensor) @CountDataHelper.update def intersect(self, - exog_genes : pd.Index, - ) -> pd.Index: + exog_genes: pd.Index + ) -> pd.Index: """Intersect genes of CountData object with external set Parameter: @@ -159,14 +151,14 @@ def intersect(self, inter = exog_genes.intersection(self.genes) inter = np.unique(inter) - keep = np.array([ self.genes.get_loc(x) for x in inter]) + keep = np.array([self.genes.get_loc(x) for x in inter]) self.genes = inter - self.cnt = self.cnt[:,keep] + self.cnt = self.cnt[:, keep] return self.genes - def unique_labels(self, - )->np.ndarray: + def unique_labels(self + ) -> np.ndarray: """Get unique labels Returns: @@ -174,16 +166,14 @@ def unique_labels(self, Array of unique cell type labels """ - _,upos = np.unique(self.zidx, return_index = True) + _, upos = np.unique(self.zidx, return_index=True) typenames = self.lbl[upos] return typenames - - def __getitem__(self, - idx: List[int], - )-> Dict: + idx: List[int] + ) -> Dict: """Get sample with specified index Parameter: @@ -199,33 +189,34 @@ def __getitem__(self, indices (gidx) """ - sample = {'x' : self.cnt[idx,:], - 'meta' : self.zidx[idx], - 'sf' : self.libsize[idx], - 'gidx' : t.tensor(idx), - } + sample = {'x': self.cnt[idx, :], + 'meta': self.zidx[idx], + 'sf': self.libsize[idx], + 'gidx': t.tensor(idx) + } return sample def __len__(self, - )-> int: + ) -> int: """Length of CountData object""" return self.M -def make_sc_dataset(cnt_pth : str, - lbl_pth : str, - topn_genes : int = None, - gene_list_pth : str = None, - filter_genes : bool = False, - lbl_colname : str = 'bio_celltype', - min_counts : int = 300, - min_cells : int = 0, - transpose : bool = False, - upper_bound : int = None, - lower_bound : int = None, - ): +def make_sc_dataset(cnt_pth: str, + lbl_pth: str, + topn_genes: int = None, + top_criteria: str = 'hgv', + gene_list_pth: str = None, + filter_genes: bool = False, + lbl_colname: str = 'bio_celltype', + min_counts: int = 300, + min_cells: int = 0, + transpose: bool = False, + upper_bound: int = None, + lower_bound: int = None + ): """ Generate CountData object for SC-data @@ -238,9 +229,12 @@ def make_sc_dataset(cnt_pth : str, lbl_pth : str path to SC label data - topn_genes : bool - number of top expressed genes to - include + topn_genes : int + number of top genes to include + + top_criteria : str + criteria to select topn_genes. 
Can be either + expression 'expr' or highly variable 'hgv' gene_list_pth : str gene list @@ -254,7 +248,7 @@ def make_sc_dataset(cnt_pth : str, spot/cell for it to be included min_cells : int - minimal number of occurances + minimal number of occurrences of a gene among all cells for it to be included @@ -278,85 +272,89 @@ def make_sc_dataset(cnt_pth : str, sc_ext = utils.get_extenstion(cnt_pth) - if sc_ext == 'h5ad' : - cnt,lbl = utils.read_h5ad_sc(cnt_pth, - lbl_colname, - lbl_pth, - ) + if sc_ext is 'h5ad': + cnt, lbl = utils.read_h5ad_sc(cnt_pth, + lbl_colname, + lbl_pth + ) else: - cnt = utils.read_file(cnt_pth,sc_ext) + cnt = utils.read_file(cnt_pth, sc_ext) if transpose: cnt = cnt.T lbl = utils.read_file(lbl_pth) # get labels if lbl_colname is None: - lbl = lbl.iloc[:,0] + lbl = lbl.iloc[:, 0] else: - lbl = lbl.loc[:,lbl_colname] + lbl = lbl.loc[:, lbl_colname] # match count and label data inter = cnt.index.intersection(lbl.index) if inter.shape[0] < 1: - print("[ERROR] : single cell count and annotation"\ + print("[ERROR] : single cell count and annotation" " data did not match. Exiting.", - file = sys.stderr, + file=sys.stderr ) - cnt = cnt.loc[inter,:] - lbl = lbl.loc[inter] + sys.exit(-1) + cnt = cnt.loc[inter, :] + lbl = lbl.loc[inter] - if upper_bound is not None or\ - lower_bound is not None: - cnt,lbl = utils.subsample_data(cnt, - lbl, - lower_bound, - upper_bound, - ) - - # select top N expressed genes - if topn_genes is not None: - genesize = cnt.values.sum(axis = 0) - topn_genes = np.min((topn_genes,genesize.shape[0])) + if upper_bound is not None or \ + lower_bound is not None: + cnt, lbl = utils.subsample_data(cnt, + lbl, + lower_bound, + upper_bound + ) + + # select top N genes + if topn_genes is not None and top_criteria in ['expr', 'hgv']: + # NOTE Pandas supports var and sum, this could be simplified + genesize = cnt.values.sum(axis=0) if top_criteria is 'expr' else cnt.values.var(axis=0) + topn_genes = np.min((topn_genes, genesize.shape[0])) sel = np.argsort(genesize)[::-1] sel = sel[0:topn_genes] - cnt = cnt.iloc[:,sel] + cnt = cnt.iloc[:, sel] + elif topn_genes is not None: + # NOTE top_criteria is neither hgv or expr + # we should throw an exception here or a warning + pass # only use genes in specific genes list # if specified if gene_list_pth is not None: - with open(gene_list_pth,'r+') as fopen: + with open(gene_list_pth, 'r+') as fopen: gene_list = fopen.readlines() - - gene_list = pd.Index([ x.replace('\n','') for x in gene_list ]) + gene_list = pd.Index([x.replace('\n', '') for x in gene_list]) sel = cnt.columns.intersection(gene_list) - cnt = cnt.loc[:,sel] + cnt = cnt.loc[:, sel] # create sc data set - dataset = CountData(cnt = cnt, - lbl = lbl) + dataset = CountData(cnt=cnt, lbl=lbl) # filter genes based on names if filter_genes: dataset.filter_genes() # filter data based on quality - if any([min_counts > 0,min_cells > 0]): - dataset.filter_bad(min_counts = min_counts, - min_occurance = min_cells, - ) + if any([min_counts > 0, min_cells > 0]): + dataset.filter_bad(min_counts=min_counts, + min_occurance=min_cells + ) return dataset -def make_st_dataset(cnt_pths : List[str], - topn_genes : bool = None, - min_counts : int = 0, - min_spots : int = 0, - filter_genes : bool = False, - transpose : bool = False, - )-> CountData : - +def make_st_dataset(cnt_pths: List[str], + topn_genes: int = None, + top_criteria: str = 'hgv', + min_counts: int = 0, + min_spots: int = 0, + filter_genes: bool = False, + transpose: bool = False + ) -> CountData: """ Generate 
CountData object for ST-data @@ -366,10 +364,14 @@ def make_st_dataset(cnt_pths : List[str], cnt_pths : List[str] list of paths to ST-data - topn_genes : bool + topn_genes : int number of top expressed genes to include in analysis + top_criteria : str + criteria to select topn_genes. Can be either + expression 'expr' or highly variable 'hgv' + min_counts : int minimal number of observed counts assigned to a specific @@ -402,16 +404,20 @@ def make_st_dataset(cnt_pths : List[str], if st_ext == "h5ad": cnt = utils.read_h5ad_st(cnt_pths) else: - cnt = utils.make_joint_matrix(cnt_pths, - transpose) + cnt = utils.make_joint_matrix(cnt_pths, transpose) - # select top N genes if specified - if topn_genes is not None: - genesize = cnt.values.sum(axis = 0) - topn_genes = np.min((topn_genes,genesize.shape[0])) + # select top N genes + if topn_genes is not None and top_criteria in ['expr', 'hgv']: + # NOTE Pandas supports var and sum, this could be simplified + genesize = cnt.values.sum(axis=0) if top_criteria is 'expr' else cnt.values.var(axis=0) + topn_genes = np.min((topn_genes, genesize.shape[0])) sel = np.argsort(genesize)[::-1] sel = sel[0:topn_genes] - cnt = cnt.iloc[:,sel] + cnt = cnt.iloc[:, sel] + elif topn_genes is not None: + # NOTE top_criteria is neither hgv or expr + # we should throw an exception here or a warning + pass dataset = CountData(cnt) @@ -420,12 +426,9 @@ def make_st_dataset(cnt_pths : List[str], dataset.filter_genes() # filter data based on quality - if any([min_counts > 0,min_spots > 0]): - dataset.filter_bad(min_counts = min_counts, - min_occurance = min_spots, + if any([min_counts > 0, min_spots > 0]): + dataset.filter_bad(min_counts=min_counts, + min_occurance=min_spots ) - return dataset - - diff --git a/stsc/fit.py b/stsc/fit.py index ec09973..8c8b0a4 100755 --- a/stsc/fit.py +++ b/stsc/fit.py @@ -4,29 +4,25 @@ from os import mkdir import os.path as osp from typing import NoReturn, Union, Dict - import torch as t from torch.utils.data import DataLoader - import numpy as np import pandas as pd - - import stsc.models as M import stsc.datasets as D import stsc.utils as utils -def fit(model : Union[M.ScModel,M.STModel], - dataset : D.CountData, - loss_tracker : utils.LossTracker, - device : t.device, - epochs : int, - learning_rate : float, - batch_size : int = None, - silent_mode : bool = False, + +def fit(model: Union[M.ScModel, M.STModel], + dataset: D.CountData, + loss_tracker: utils.LossTracker, + device: t.device, + epochs: int, + learning_rate: float, + batch_size: int = None, + silent_mode: bool = False, **kwargs ) -> None: - """Fit Model Generic function to fit models @@ -61,34 +57,36 @@ def fit(model : Union[M.ScModel,M.STModel], # move model to device model.to(device) + # define optimizer - optim = t.optim.Adam(model.parameters(), - lr = learning_rate) - # instatiate progressbar + optim = t.optim.Adam(model.parameters(), lr=learning_rate) + + # instantiate progressbar progressBar = utils.SimpleProgressBar(epochs, - silent_mode = silent_mode, - length = 20) + silent_mode=silent_mode, + length=20 + ) # use full dataset if no batch size is specified if batch_size is None: batch_size = dataset.M else: - batch_size = int(np.min((batch_size,dataset.M))) + batch_size = int(np.min((batch_size, dataset.M))) dataloader = DataLoader(dataset, - batch_size = batch_size, - shuffle = False, + batch_size=batch_size, + shuffle=False, ) # Use try/except to catch SIGINT - # for early interuption + # for early interruption try: for epoch in range(epochs): epoch_loss = 0.0 for 
batch in dataloader: # move batch items to device - for k,v in batch.items(): + for k, v in batch.items(): batch[k] = v.to(device) batch['x'].requires_grad = True @@ -105,32 +103,35 @@ def fit(model : Union[M.ScModel,M.STModel], # update progress bar progressBar(epoch, epoch_loss) # record loss progression - loss_tracker(epoch_loss,epoch) + loss_tracker(epoch_loss, epoch) - # newline after complettion + # newline after completion print('\n') + # write final loss loss_tracker.write_history() except KeyboardInterrupt: print(' '.join(["\n\nPress Ctrl+C again", "to interrupt whole process", - ] - ) - ) - -def fit_st_data(st_data : D.CountData, - R : pd.DataFrame, - logits : pd.DataFrame, - loss_tracker : utils.LossTracker, - device : t.device, - st_epochs : int, - learning_rate : float, - st_batch_size : int, - silent_mode : bool = False, - st_from_model : str = None, - keep_noise : bool = False, - **kwargs)->Dict[str,Union[pd.DataFrame,M.STModel]]: + ] + ) + ) + + +def fit_st_data(st_data: D.CountData, + R: pd.DataFrame, + logits: pd.DataFrame, + loss_tracker: utils.LossTracker, + device: t.device, + st_epochs: int, + learning_rate: float, + st_batch_size: int, + silent_mode: bool = False, + st_from_model: str = None, + keep_noise: bool = False, + **kwargs + ) -> Dict[str, Union[pd.DataFrame, M.STModel]]: """Fit ST Data model Estimate proportion values for @@ -176,23 +177,23 @@ def fit_st_data(st_data : D.CountData, inter = st_data.intersect(R.index) if inter.shape[0] < 1: - print("[ERROR] : No genes overlap in SC and"\ + print("[ERROR] : No genes overlap in SC and" " ST data. Exiting.", - file = sys.stderr, + file=sys.stderr ) sys.exit(-1) - R = R.loc[inter,:] - logits = logits.loc[inter,:] - + R = R.loc[inter, :] + logits = logits.loc[inter, :] t.manual_seed(1337) + # generate ST model st_model = M.STModel(st_data.M, - R = R.values, - logits = logits.values, - device = device, - freeze_beta = kwargs.get("freeze_beta",False), + R=R.values, + logits=logits.values, + device=device, + freeze_beta=kwargs.get("freeze_beta", False) ) # load st model from path if provided if st_from_model is not None: @@ -201,52 +202,55 @@ def fit_st_data(st_data : D.CountData, except: print(' '.join(["Could not load state", "dict from >> {st_from_model}"], - ), - file = sys.stderr, - ) + ), + file=sys.stderr + ) + # estimate proportion values - fit(dataset = st_data, - model = st_model, - loss_tracker = loss_tracker, - device = device, - epochs = st_epochs, - learning_rate = learning_rate, - batch_size = st_batch_size, - silent_mode = silent_mode, + fit(dataset=st_data, + model=st_model, + loss_tracker=loss_tracker, + device=device, + epochs=st_epochs, + learning_rate=learning_rate, + batch_size=st_batch_size, + silent_mode=silent_mode ) # get estimated unadjusted proportions - W = st_model.v.data.cpu().numpy().T + W = st_model.v.data.cpu().numpy().T + # remove dummy cell type proportion values if not keep_noise: - W = W[:,0:st_model.K] + W = W[:, 0:st_model.K] w_columns = R.columns else: w_columns = R.columns.append(pd.Index(["noise"])) + # normalize to obtain adjusted proportions - W = W / W.sum(axis = 1).reshape(-1,1) + W = W / W.sum(axis=1).reshape(-1, 1) + # generate pandas DataFrame from proportions W = pd.DataFrame(W, - index = st_data.index, - columns = w_columns) + index=st_data.index, + columns=w_columns + ) + return {'proportions': W, + 'model': st_model + } - return {'proportions':W, - 'model':st_model, - } -def fit_sc_data(sc_data : D.CountData, - loss_tracker : utils.LossTracker, - device : t.device, 
- sc_epochs : int, - sc_batch_size : int, +def fit_sc_data(sc_data: D.CountData, + loss_tracker: utils.LossTracker, + device: t.device, + sc_epochs: int, + sc_batch_size: int, learning_rate: float, - silent_mode : bool = False, - sc_from_model : str = None, - **kwargs, - )->Dict[str,Union[pd.DataFrame, - M.ScModel]]: - + silent_mode: bool = False, + sc_from_model: str = None, + **kwargs + ) -> Dict[str, Union[pd.DataFrame, M.ScModel]]: """Fit single cell data sc_data : D.CountData @@ -276,46 +280,47 @@ def fit_sc_data(sc_data : D.CountData, """ - t.manual_seed(1337) + # define single cell model - sc_model = M.ScModel(n_genes = sc_data.G, - n_celltypes = sc_data.Z, - device = device) + sc_model = M.ScModel(n_genes=sc_data.G, + n_celltypes=sc_data.Z, + device=device + ) # load sc-model if provided if sc_from_model is not None and osp.exists(sc_from_model): sc_model.load_state_dict(t.load(sc_from_model)) # fit single cell parameters - fit(dataset = sc_data, - model = sc_model, - loss_tracker = loss_tracker, - device = device, - epochs = sc_epochs, - learning_rate = learning_rate, - batch_size = sc_batch_size, - silent_mode = silent_mode + fit(dataset=sc_data, + model=sc_model, + loss_tracker=loss_tracker, + device=device, + epochs=sc_epochs, + learning_rate=learning_rate, + batch_size=sc_batch_size, + silent_mode=silent_mode ) - # retreive estimated parameter values + # retrieve estimated parameter values logits = sc_model.o.data.cpu().numpy() R = sc_model.R.data.cpu().numpy() + # get cell type names typenames = sc_data.unique_labels() - # generate dataframes for parameters + # generate data frames for parameters R = pd.DataFrame(R, - index = sc_data.genes, - columns = typenames, + index=sc_data.genes, + columns=typenames ) logits = pd.DataFrame(logits, - index = sc_data.genes, - columns = pd.Index(['logits'])) - - return {'rates':R, - 'logits':logits, - 'model':sc_model, - } + index=sc_data.genes, + columns=pd.Index(['logits'])) + return {'rates': R, + 'logits': logits, + 'model': sc_model + } diff --git a/stsc/models.py b/stsc/models.py index 9fe6888..9a3d2e5 100755 --- a/stsc/models.py +++ b/stsc/models.py @@ -4,26 +4,21 @@ import torch as t import torch.nn as nn from torch.nn.parameter import Parameter - - import numpy as np import pandas as pd - from typing import NoReturn, List, Tuple, Union, Collection import logging - import os.path as osp class ScModel(nn.Module): - """ Model for singel cell data """ + """ Model for single cell data """ def __init__(self, - n_genes : int, - n_celltypes : int, - device : t.device, - )->None: - + n_genes: int, + n_celltypes: int, + device: t.device + ) -> None: super().__init__() # Get dimensions from data @@ -31,20 +26,18 @@ def __init__(self, self.G = n_genes # Define parameters to be estimated - self.theta = Parameter(t.Tensor(self.G,self.K).to(device)) - self.R = t.Tensor(self.G,self.K).to(device) - self.o = Parameter(t.Tensor(self.G,1).to(device)) + self.theta = Parameter(t.Tensor(self.G, self.K).to(device)) + self.R = t.Tensor(self.G, self.K).to(device) + self.o = Parameter(t.Tensor(self.G, 1).to(device)) # Initialize parameters nn.init.normal_(self.o, - mean = 0.0, - std = 1.0) + mean=0.0, + std=1.0) nn.init.normal_(self.theta, - mean = 0.0, - std = 1.0) - - + mean=0.0, + std=1.0) # Functions to be used self.nb = t.distributions.NegativeBinomial @@ -52,11 +45,10 @@ def __init__(self, self.logsig = nn.functional.logsigmoid def _llnb(self, - x : t.Tensor, - meta : t.LongTensor, - sf : t.Tensor, - ) -> t.Tensor : - + x: t.Tensor, + meta: 
t.LongTensor, + sf: t.Tensor + ) -> t.Tensor: """Log Likelihood for NB-model Returns the log likelihood for rates and logodds @@ -70,48 +62,48 @@ def _llnb(self, """ - - log_unnormalized_prob = (sf*self.R[:,meta] * self.logsig(-self.o) + + log_unnormalized_prob = (sf * self.R[:, meta] * self.logsig(-self.o) + x * self.logsig(self.o)) - log_normalization = -t.lgamma(sf*self.R[:,meta] + x) + \ - t.lgamma(1. + x) + \ - t.lgamma(sf*self.R[:,meta]) + log_normalization = -t.lgamma(sf * self.R[:, meta] + x) + \ + t.lgamma(1. + x) + \ + t.lgamma(sf * self.R[:, meta]) ll = t.sum(log_unnormalized_prob - log_normalization) return ll - def forward(self, - x : t.Tensor, - meta : t.LongTensor, - sf : t.Tensor, - **kwargs, - ) -> t.Tensor : + x: t.Tensor, + meta: t.LongTensor, + sf: t.Tensor, + **kwargs + ) -> t.Tensor: """Forward pass during optimization""" # rates for each cell type self.R = self.softpl(self.theta) + # get loss for current parameters - self.loss = -self._llnb(x.transpose(1,0), + self.loss = -self._llnb(x.transpose(1, 0), meta, sf) return self.loss - def __str__(self,): + def __str__(self, ): return f"sc_model" + class STModel(nn.Module): def __init__(self, n_spots: int, - R : np.ndarray, - logits : np.ndarray, - device : t.device, - **kwargs, - )->None: + R: np.ndarray, + logits: np.ndarray, + device: t.device, + **kwargs + ) -> None: super().__init__() @@ -122,7 +114,7 @@ def __init__(self, # Data from single cell estimates; Rates (R) and logits (o) self.R = t.tensor(R.astype(np.float32)).to(device) - self.o = t.tensor(logits.astype(np.float32).reshape(-1,1)).to(device) + self.o = t.tensor(logits.astype(np.float32).reshape(-1, 1)).to(device) # model specific parameters self.softpl = nn.functional.softplus @@ -130,46 +122,45 @@ def __init__(self, self.sig = t.sigmoid # Learn noise from data - self.eta = Parameter(t.tensor(np.zeros((self.G,1)).astype(np.float32)).to(device)) - nn.init.normal_(self.eta, mean = 0.0, std = 1.0) - + self.eta = Parameter(t.tensor(np.zeros((self.G, 1)).astype(np.float32)).to(device)) + nn.init.normal_(self.eta, mean=0.0, std=1.0) # un-normalized proportion in log space - self.theta = Parameter(t.tensor(np.zeros((self.Z,self.S)).astype(np.float32)).to(device)) - nn.init.normal_(self.theta, mean = 0.0,std = 1.0) + self.theta = Parameter(t.tensor(np.zeros((self.Z, self.S)).astype(np.float32)).to(device)) + nn.init.normal_(self.theta, mean=0.0, std=1.0) + # gene bias in log space - if not kwargs.get("freeze_beta",False): - self.beta = Parameter(t.tensor(np.zeros((self.G,1)).astype(np.float32)).to(device)) + if not kwargs.get("freeze_beta", False): + self.beta = Parameter(t.tensor(np.zeros((self.G, 1)).astype(np.float32)).to(device)) self.beta_trans = self.softpl - nn.init.normal_(self.beta, mean = 0.0, std = 0.1) + nn.init.normal_(self.beta, mean=0.0, std=0.1) else: print("Using static beta_g") - self.beta = t.tensor(np.ones((self.G,1)).astype(np.float32)).to(device) - self.beta_trans = lambda x : x + self.beta = t.tensor(np.ones((self.G, 1)).astype(np.float32)).to(device) + self.beta_trans = lambda x: x + # un-normalized proportions - self.v = t.tensor(np.zeros((self.Z,self.S)).astype(np.float32)).to(device) + self.v = t.tensor(np.zeros((self.Z, self.S)).astype(np.float32)).to(device) self.loss = t.tensor(0.0) self.model_ll = 0.0 - - def noise_loss(self, - )-> t.Tensor: + def noise_loss(self + ) -> t.Tensor: """Regularizing term for noise""" - return -0.5*t.sum(t.pow(self.eta,2)) + return -0.5 * t.sum(t.pow(self.eta, 2)) def _llnb(self, - x : t.Tensor, - 
)->t.Tensor: + x: t.Tensor + ) -> t.Tensor: """Log Likelihood function for standard model""" log_unnormalized_prob = self.r * self.lsig(-self.o) + \ x * self.lsig(self.o) log_normalization = -t.lgamma(self.r + x) + \ - t.lgamma(1. + x) + \ - t.lgamma(self.r) - + t.lgamma(1. + x) + \ + t.lgamma(self.r) ll = t.sum(log_unnormalized_prob - log_normalization) @@ -178,8 +169,8 @@ def _llnb(self, return ll def _lfun(self, - x : t.Tensor, - )-> t.Tensor: + x: t.Tensor + ) -> t.Tensor: """Loss Function Composed of the likelihood and prior of @@ -197,36 +188,39 @@ def _lfun(self, # log likelihood of observed count given model data_loss = self._llnb(x) + # log of prior on noise elements noise_loss = self.noise_loss() - return - data_loss - noise_loss + return - data_loss - noise_loss - def __str__(self, - )-> str: + def __str__(self + ) -> str: return f"st_model" def forward(self, - x : t.tensor, - gidx : t.tensor, - **kwargs, + x: t.tensor, + gidx: t.tensor, + **kwargs ) -> t.tensor: """Forward pass""" self.gidx = gidx + # proportion values self.v = self.softpl(self.theta) + # noise values self.eps = self.softpl(self.eta) + # account for gene specific bias and add noise - self.Rhat = t.cat((t.mul(self.beta_trans(self.beta), self.R),self.eps),dim = 1) + self.Rhat = t.cat((t.mul(self.beta_trans(self.beta), self.R), self.eps), dim=1) + # combinde rates for all cell types - self.r = t.einsum('gz,zs->gs',[self.Rhat,self.v[:,self.gidx]]) + self.r = t.einsum('gz,zs->gs', [self.Rhat, self.v[:, self.gidx]]) + # get loss for current parameters - self.loss = self._lfun(x.transpose(1,0)) + self.loss = self._lfun(x.transpose(1, 0)) return self.loss - - - diff --git a/stsc/parser.py b/stsc/parser.py index 2a6063c..ee8245f 100755 --- a/stsc/parser.py +++ b/stsc/parser.py @@ -2,9 +2,8 @@ import argparse as arp -def make_parser(): - prs = arp.ArgumentParser() +def make_parser(): parser = arp.ArgumentParser() @@ -22,247 +21,208 @@ def make_parser(): # Run Parser Arguments --------------------------------------------- - run_parser.add_argument('-scc','--sc_cnt', - required = False, - type = str, - help = ''.join(["path to single cell", - " count file. Should be", - " on format n_cells x n_genes", - " use flag sct to transpose if", - " if necessary"])) - - run_parser.add_argument('-scl','--sc_labels', - required = False, - type = str, - help = ''.join(["path to single cell", - " labels file. 
Should be on", - ])) - - run_parser.add_argument('-lcn','--label_colname', - required = False, - default = 'bio_celltype', - type = str, - help = ''.join(["name of columns that", - " cell type labels are", - " listed", - ])) - - - run_parser.add_argument('-scb','--sc_batch_size', - required = False, - default = None, - type = int, - help = ''.join(["batch size for", - " single cell data set", - ])) - - run_parser.add_argument('-stc','--st_cnt', - required = False, - default = None, - nargs = '+', - help = ''.join(["path to spatial", - " transcriptomics count file.", - " Shoul be on form", - " n_spots x n_genes"])) - - run_parser.add_argument('-stm','--st_model', - default = None, - required = False, - help = ''.join(["path to already fitted", - " st model"])) - - run_parser.add_argument('-scm','--sc_model', - required = False, - default = None, - help = ''.join(["path to already fitted", - " sc model"])) - - - run_parser.add_argument('-sce','--sc_epochs', - required = False, - default = 20000, - type = int, - help = ''.join(["number of epochs", - " to be used in fitting", - " of single cell data.", - ])) - - - run_parser.add_argument('-stb','--st_batch_size', - required = False, - default = None, - type = int, - help = ''.join(["batch size for", - " st data set", - ])) - - run_parser.add_argument('-scf','--sc_fit', - required = False, - default = [None,None], - nargs = 2, - help =''.join(["parameters fitted", - " from single cell", - " data. First argument", - " should be path to", - " R-matrix and second", - " to logit vector"]) - ) - - - run_parser.add_argument('-ste','--st_epochs', - default = 20000, - type = int, - help = ''.join(["number of epochs", - " to be used in fitting", - " of spatial transcriptomics", - " data.", - ])) - - run_parser.add_argument('-stt','--st_transpose', - required = False, - default = False, - action = 'store_true', - help = "transpose spatial data") - - - run_parser.add_argument('-sct','--sc_transpose', - required = False, - default = False, - action = 'store_true', - help = "transpose sc data") - - run_parser.add_argument('-kn','--keep_noise', - required = False, - default = False, - action = 'store_true', - help = "keep noise") - - - run_parser.add_argument('-o','--out_dir', - required = False, - default = '', - type = str, - help = ''.join([" full path to output", - " directory. Files will", - " be saved with standard ", - " name and timestamp", - ])) - - run_parser.add_argument('-shh','--silent_mode', - required = False, - default = False, - action = 'store_true', - help = ''.join(["include to silence", - "output throughout", - "fitting", - ])) - - run_parser.add_argument('-n','--topn_genes', - required = False, - default = None, - type = int, - help = ''.join(["only use top n", - " mose highly expressed", - " genes" - ])) - - - run_parser.add_argument('-fg','--filter_genes', - required = False, - default = False, - action = 'store_true', - help = ''.join([f"Filter Ribosomal Genes", - f" and MALAT1", - ])) - - - run_parser.add_argument("-lr","--learning_rate", - required = False, - default = 0.01, - type = float, - help = ''.join([f"learning rate to be", - f" used." 
- ])) - - - run_parser.add_argument("-mscc","--min_sc_counts", - required = False, - default = 0, - type = float, - help = ''.join([f"minimum number of ", - f" counts for single cells", - f" to be included in", - f" the analysis", - ])) - - run_parser.add_argument("-mstc","--min_st_counts", - required = False, - default = 0, - type = float, - help = ''.join([f"minimum number of ", - f" counts for spots", - f" to be included in", - f" the analysis", - ])) - - - run_parser.add_argument("-mc","--min_cells", - required = False, - default = 0, - type = float, - help = ''.join([f"minimum number of ", - f" cells for genes", - f" to be observed in", - f" the analysis", - ])) - - run_parser.add_argument("-ms","--min_spots", - required = False, - default = 0, - type = float, - help = ''.join([f"minimum number of ", - f" spots for genes", - f" to be observed in", - f" the analysis", - ])) - - - run_parser.add_argument('-gp','--gpu', - required = False, - default = False, - action = 'store_true', - help = ''.join(["use gpu", - ])) - - run_parser.add_argument('-gl','--gene_list', - required = False, - default = None, - type = str, - help = ''.join(["path to list of genes", - " to use", - ])) - - run_parser.add_argument('-sub','--sc_upper_bound', - required = False, - default = None, - type = int, - help = ''.join(["upper bound for single cell", - " subsampling." - ])) - - run_parser.add_argument('-slb','--sc_lower_bound', - required = False, - default = None, - type = int, - help = ''.join(["lower bound for single cell", - " subsampling." - ])) - - run_parser.add_argument('-fb','--freeze_beta', - default = False, - action = "store_true", - help = ''.join(["freeze beta parameter", - ])) + run_parser.add_argument('-scc', '--sc_cnt', + required=False, + type=str, + help='path to single cell ' + 'count file. Should be ' + 'in format n_cells x n_genes ' + 'use flag sct to transpose ' + 'if necessary') + + run_parser.add_argument('-scl', '--sc_labels', + required=False, + type=str, + help='path to single cell labels file. Should be on') + + run_parser.add_argument('-lcn', '--label_colname', + required=False, + default='bio_celltype', + type=str, + help='name of columns that ' + 'cell type labels are listed') + + run_parser.add_argument('-scb', '--sc_batch_size', + required=False, + default=None, + type=int, + help='batch size for single cell data set') + + run_parser.add_argument('-stc', '--st_cnt', + required=False, + default=None, + nargs='+', + help='path to spatial ' + 'transcriptomics count file.' + 'Should be in format n_spots x n_genes') + + run_parser.add_argument('-stm', '--st_model', + default=None, + required=False, + help='path to already fitted st model') + + run_parser.add_argument('-scm', '--sc_model', + required=False, + default=None, + help='path to already fitted sc model') + + run_parser.add_argument('-sce', '--sc_epochs', + required=False, + default=20000, + type=int, + help='number of epochs ' + 'to be used in fitting ' + 'the single cell data') + + run_parser.add_argument('-stb', '--st_batch_size', + required=False, + default=None, + type=int, + help='batch size for the spatial transcriptomics data') + + run_parser.add_argument('-scf', '--sc_fit', + required=False, + default=[None, None], + nargs=2, + help='parameters fitted ' + 'from single cell ' + 'data. 
First argument ' + 'should be path to ' + 'R-matrix and second ' + 'to logit vector') + + run_parser.add_argument('-ste', '--st_epochs', + default=20000, + type=int, + help='number of epochs ' + 'to be used in fitting ' + 'the spatial transcriptomics data') + + run_parser.add_argument('-stt', '--st_transpose', + required=False, + default=False, + action='store_true', + help='transpose spatial transcriptomics data') + + run_parser.add_argument('-sct', '--sc_transpose', + required=False, + default=False, + action='store_true', + help='transpose single cell data') + + # TODO better description + run_parser.add_argument('-kn', '--keep_noise', + required=False, + default=False, + action='store_true', + help='keep noise') + + run_parser.add_argument('-o', '--out_dir', + required=False, + default='', + type=str, + help='full path to output ' + 'directory. Files will ' + 'be saved with standard ' + 'name and timestamp') + + run_parser.add_argument('-shh', '--silent_mode', + required=False, + default=False, + action='store_true', + help='include to silence ' + 'output throughout fitting') + + run_parser.add_argument('-n', '--topn_genes', + required=False, + default=None, + type=int, + help='only use top n genes based on the ' + 'criteria --top_criteria') + + run_parser.add_argument('-ncr', '--top_criteria', + required=False, + default='hgv', + choices=['hgv', 'expr'], + type=str, + help='criteria to select top genes when --topn_genes is used. ' + 'Options are expr for expression or hgv por variance') + + run_parser.add_argument('-fg', '--filter_genes', + required=False, + default=False, + action='store_true', + help='filter ribosomal genes and Malat1') + + run_parser.add_argument('-lr', '--learning_rate', + required=False, + default=0.01, + type=float, + help='learning rate to be used in the fitting process') + + run_parser.add_argument('-mscc', '--min_sc_counts', + required=False, + default=0, + type=float, + help='minimum number of ' + 'counts for single cells ' + 'to be included in the analysis') + + run_parser.add_argument('-mstc', '--min_st_counts', + required=False, + default=0, + type=float, + help='minimum number of ' + 'counts for spots ' + 'to be included in ' + 'the analysis') + + run_parser.add_argument('-mc', '--min_cells', + required=False, + default=0, + type=float, + help='minimum number of ' + 'cells for genes ' + 'to be observed in the analysis') + + run_parser.add_argument('-ms', '--min_spots', + required=False, + default=0, + type=float, + help='minimum number of ' + 'spots for genes ' + 'to be observed in ' + 'the analysis') + + run_parser.add_argument('-gp', '--gpu', + required=False, + default=False, + action='store_true', + help='use gpu accelerated computation') + + run_parser.add_argument('-gl', '--gene_list', + required=False, + default=None, + type=str, + help= 'path to list of genes to use in the analysis') + + run_parser.add_argument('-sub', '--sc_upper_bound', + required=False, + default=None, + type=int, + help='upper bound limit for single cell subsampling') + + run_parser.add_argument('-slb', '--sc_lower_bound', + required=False, + default=None, + type=int, + help='lower bound limit for single cell subsampling') + + # TODO better description + run_parser.add_argument('-fb', '--freeze_beta', + default=False, + action='store_true', + help= 'freeze beta parameter') # Look Parser Arguments ----------------------------------------------- diff --git a/stsc/progress.py b/stsc/progress.py index b839e6f..d701553 100755 --- a/stsc/progress.py +++ b/stsc/progress.py @@ -3,15 
+3,14 @@ import matplotlib.pyplot as plt import os.path as osp import numpy as np - - import sys import time from typing import Tuple -def rolling_average(data : np.ndarray, - windowsize : int, - )-> np.ndarray: + +def rolling_average(data: np.ndarray, + windowsize: int + ) -> np.ndarray: """Compute rolling average Parameter: @@ -32,13 +31,13 @@ def rolling_average(data : np.ndarray, tile = np.ones(windowsize) / windowsize smooth = np.convolve(data, tile, - mode = 'valid', - ) - + mode='valid' + ) return smooth -def get_loss_data(loss_file : str, - )-> Tuple[np.ndarray]: + +def get_loss_data(loss_file: str + ) -> Tuple[np.ndarray]: """Read loss values from file Parameter: @@ -55,30 +54,29 @@ def get_loss_data(loss_file : str, # exit if loss file does not exist if not osp.exists(loss_file): print(' '.join([f"ERROR : the file {loss_file}", - "does not exist", - ] - ), - ) + "does not exist", + ] + ) + ) sys.exit(-1) # read loss files - with open(loss_file,"r+") as fopen: + with open(loss_file, "r+") as fopen: # remove initial and trailing commas loss_history = fopen.read().lstrip(',').rstrip(',') # convert loss history to array - loss_history = np.array([float(x) for \ - x in loss_history.split(',')]) + loss_history = np.array([float(x) for x in loss_history.split(',')]) # generate epoch values - epoch = np.arange(1,loss_history.shape[0]+1) + epoch = np.arange(1, loss_history.shape[0] + 1) + + return (epoch, loss_history) - return (epoch, - loss_history) -def progress(loss_file : str, - windowsize : int, - )-> None: +def progress(loss_file: str, + windowsize: int + ) -> None: """Dynamic plot of loss history Parameter: @@ -86,18 +84,19 @@ def progress(loss_file : str, loss_file : str path to loss history file windowsize : int - size of window to use upon compuation + size of window to use upon computation of rolling average. Should be odd number. 
""" - # make sure windowsize is int - if not isinstance(windowsize,int): + # make sure windows size is int + if not isinstance(windowsize, int): windowsize = int(windowsize) - # if even windowsize value add one + # if even windows size value add one if windowsize % 2 == 0: windowsize += 1 + # length of array that is lost # in rolling average computation side = int((windowsize - 1) / 2) @@ -105,32 +104,31 @@ def progress(loss_file : str, # create figure and axes fig, ax = plt.subplots(1, 1, - figsize = (8,5), - num = 13, - ) + figsize=(8, 5), + num=13 + ) + # line to represent loss values line1, = ax.plot([], - [], - linestyle = 'dashed', - color = 'black', - ) + [], + linestyle='dashed', + color='black' + ) # line to represent rolling average # values line2, = ax.plot([], [], - color = 'blue', - alpha = 0.2, - linewidth = 5, - ) + color='blue', + alpha=0.2, + linewidth=5 + ) # customize plot - ax.set_ylabel('Loss', - fontsize = 25) - ax.set_xlabel('Epoch', - fontsize = 25) + ax.set_ylabel('Loss', fontsize=25) + ax.set_xlabel('Epoch', fontsize=25) # remove spines - for pos in ['top','right']: + for pos in ['top', 'right']: ax.spines[pos].set_visible(False) # update loss plot every 10th second @@ -138,14 +136,15 @@ def progress(loss_file : str, while keepOn and plt.fignum_exists(13): try: # get loss data from file - xdata,ydata = get_loss_data(loss_file) + xdata, ydata = get_loss_data(loss_file) + # compute rolling average ydata_smooth = rolling_average(ydata, - windowsize = windowsize) + windowsize=windowsize) # get limits for axes - xmin,xmax = xdata.min() - 1, xdata.max() + 1 - ymin,ymax = ydata.min() - 1, ydata.max() + 1 + xmin, xmax = xdata.min() - 1, xdata.max() + 1 + ymin, ymax = ydata.min() - 1, ydata.max() + 1 # update axes line1.set_xdata(xdata) @@ -154,8 +153,8 @@ def progress(loss_file : str, line2.set_xdata(xdata[side:-side]) line2.set_ydata(ydata_smooth) - ax.set_xlim([xmin,xmax]) - ax.set_ylim([ymin,ymax]) + ax.set_xlim([xmin, xmax]) + ax.set_ylim([ymin, ymax]) # try except to catch interactive # closure (CTRL+W) of plot @@ -170,6 +169,6 @@ def progress(loss_file : str, plt.close() keepOn = False + if __name__ == '__main__': - progress(sys.argv[1], - sys.argv[2]) + progress(sys.argv[1], sys.argv[2]) diff --git a/stsc/run.py b/stsc/run.py index 13aa87b..069230f 100755 --- a/stsc/run.py +++ b/stsc/run.py @@ -4,16 +4,11 @@ from os import mkdir, getcwd import os.path as osp import argparse as arp - - import torch as t from torch.cuda import is_available from torch.utils.data import Dataset - import numpy as np import pandas as pd - - import stsc.fit as fit import stsc.datasets as D import stsc.models as M @@ -21,10 +16,9 @@ import stsc.parser as parser -def run(prs : arp.ArgumentParser, - args : arp.Namespace, - )-> None: - +def run(prs: arp.ArgumentParser, + args: arp.Namespace + ) -> None: """Run analysis Depending on specified arguments performs @@ -53,17 +47,17 @@ def run(prs : arp.ArgumentParser, elif not osp.exists(args.out_dir): mkdir(args.out_dir) - # instatiate logger + # instantiate logger log = utils.Logger(osp.join(args.out_dir, '.'.join(['stsc', timestamp, 'log']) - ) - ) + ) + ) # convert args to list if not - args.st_cnt = (args.st_cnt if \ - isinstance(args.st_cnt,list) else \ + args.st_cnt = (args.st_cnt if + isinstance(args.st_cnt, list) else [args.st_cnt]) # set device @@ -80,12 +74,11 @@ def run(prs : arp.ArgumentParser, log.info(' | '.join(["fitting sc data", "count file : {}".format(args.sc_cnt), "labels file : {}".format(args.sc_labels), - ]) - ) + ]) + ) - 
# control thafbt paths to sc data exists + # control that paths to sc data exists if not all([osp.exists(args.sc_cnt)]): - log.error(' '.join(["One or more of the specified paths to", "the sc data does not exist"])) sys.exit(-1) @@ -97,72 +90,73 @@ def run(prs : arp.ArgumentParser, # Create data set for single cell data sc_data = D.make_sc_dataset(args.sc_cnt, args.sc_labels, - topn_genes = args.topn_genes, - gene_list_pth = args.gene_list, - lbl_colname = args.label_colname, - filter_genes = args.filter_genes, - min_counts = args.min_sc_counts, - min_cells = args.min_cells, - transpose = args.sc_transpose, - lower_bound = args.sc_lower_bound, - upper_bound = args.sc_upper_bound, + topn_genes=args.topn_genes, + top_criteria=args.top_criteria, + gene_list_pth=args.gene_list, + lbl_colname=args.label_colname, + filter_genes=args.filter_genes, + min_counts=args.min_sc_counts, + min_cells=args.min_cells, + transpose=args.sc_transpose, + lower_bound=args.sc_lower_bound, + upper_bound=args.sc_upper_bound ) log.info(' '.join(["SC data GENES : {} ".format(sc_data.G), "SC data CELLS : {} ".format(sc_data.M), "SC data TYPES : {} ".format(sc_data.Z), ]) - ) + ) # generate LossTracker object oname_loss_track = osp.join(args.out_dir, - '.'.join(["sc_loss",timestamp,"txt"]) - ) + '.'.join(["sc_loss", timestamp, "txt"]) + ) + + sc_loss_tracker = utils.LossTracker(oname_loss_track, interval=100) - sc_loss_tracker = utils.LossTracker(oname_loss_track, - interval = 100, - ) # estimate parameters from single cell data - sc_res = fit.fit_sc_data(sc_data, - loss_tracker = sc_loss_tracker, - sc_epochs = args.sc_epochs, - sc_batch_size = args.sc_batch_size, - learning_rate = args.learning_rate, - sc_from_model = args.sc_model, - device = device, + sc_res = fit.fit_sc_data(sc_data, + loss_tracker=sc_loss_tracker, + sc_epochs=args.sc_epochs, + sc_batch_size=args.sc_batch_size, + learning_rate=args.learning_rate, + sc_from_model=args.sc_model, + device=device ) - R,logits,sc_model = sc_res['rates'],sc_res['logits'],sc_res['model'] + R, logits, sc_model = sc_res['rates'], sc_res['logits'], sc_res['model'] + # save sc model oname_sc_model = osp.join(args.out_dir, - '.'.join(['sc_model',timestamp,'pt'])) + '.'.join(['sc_model', timestamp, 'pt'])) - t.save(sc_model.state_dict(),oname_sc_model) + t.save(sc_model.state_dict(), oname_sc_model) # save estimated parameters oname_R = osp.join(args.out_dir, - '.'.join(['R',timestamp,'tsv'])) + '.'.join(['R', timestamp, 'tsv'])) oname_logits = osp.join(args.out_dir, - '.'.join(['logits',timestamp,'tsv'])) + '.'.join(['logits', timestamp, 'tsv'])) - utils.write_file(R,oname_R) - utils.write_file(logits,oname_logits) + utils.write_file(R, oname_R) + utils.write_file(logits, oname_logits) # Load already estimated single cell parameters elif args.st_cnt is not None: log.info(' | '.join(["load sc parameter", "rates (R) : {}".format(args.sc_fit[0]), "logodds (logits) : {}".format(args.sc_fit[1]), - ]) - ) + ]) + ) R = utils.read_file(args.sc_fit[0]) logits = utils.read_file(args.sc_fit[1]) # If ST data is provided estiamte proportions if args.st_cnt[0] is not None: - # generate identifiying tag for each section - sectiontag = list(map(lambda x: '.'.join(osp.basename(x).split('.')[0:-1]),args.st_cnt)) + # generate identifying tag for each section + sectiontag = list(map(lambda x: '.'.join(osp.basename(x).split('.')[0:-1]), args.st_cnt)) log.info("fit st data section(s) : {}".format(args.st_cnt)) # check that provided files exist @@ -174,61 +168,61 @@ def run(prs : arp.ArgumentParser, 
log.info("loading state from provided st_model") # create data set for st data - st_data = D.make_st_dataset(args.st_cnt, - topn_genes = args.topn_genes, - min_counts = args.min_st_counts, - min_spots = args.min_spots, - filter_genes = args.filter_genes, - transpose = args.st_transpose, + # NOTE we do not want to filter by top n genes in the ST dataset since we + # perform the intersection after fitting the SC models + st_data = D.make_st_dataset(args.st_cnt, + topn_genes=None, + top_criteria=args.top_criteria, + min_counts=args.min_st_counts, + min_spots=args.min_spots, + filter_genes=args.filter_genes, + transpose=args.st_transpose ) log.info(' '.join(["ST data GENES : {} ".format(st_data.G), "ST data SPOTS : {} ".format(st_data.M), ]) - ) + ) # generate LossTracker object oname_loss_track = osp.join(args.out_dir, - '.'.join(["st_loss",timestamp,"txt"]) - ) + '.'.join(["st_loss", timestamp, "txt"]) + ) - st_loss_tracker = utils.LossTracker(oname_loss_track, - interval = 100, - ) + st_loss_tracker = utils.LossTracker(oname_loss_track, interval=100) # estimate proportions of cell types within st data st_res = fit.fit_st_data(st_data, - R = R, - logits = logits, - loss_tracker = st_loss_tracker, - st_epochs = args.st_epochs, - st_batch_size = args.st_batch_size, - learning_rate = args.learning_rate, - silent_mode = args.silent_mode, - st_from_model = args.st_model, - device = device, - keep_noise = args.keep_noise, - freeze_beta = args.freeze_beta, + R=R, + logits=logits, + loss_tracker=st_loss_tracker, + st_epochs=args.st_epochs, + st_batch_size=args.st_batch_size, + learning_rate=args.learning_rate, + silent_mode=args.silent_mode, + st_from_model=args.st_model, + device=device, + keep_noise=args.keep_noise, + freeze_beta=args.freeze_beta ) - - W,st_model = st_res['proportions'],st_res['model'] + W, st_model = st_res['proportions'], st_res['model'] # split joint matrix into multiple wlist = utils.split_joint_matrix(W) # save st model oname_st_model = osp.join(args.out_dir, - '.'.join(['st_model',timestamp,'pt'])) + '.'.join(['st_model', timestamp, 'pt'])) - t.save(st_model.state_dict(),oname_st_model) + t.save(st_model.state_dict(), oname_st_model) # save st data proportion estimates results for s in range(len(wlist)): - section_dir = osp.join(args.out_dir,sectiontag[s]) + section_dir = osp.join(args.out_dir, sectiontag[s]) if not osp.exists(section_dir): mkdir(section_dir) - oname_W = osp.join(section_dir,'.'.join(['W',timestamp,'tsv'])) + oname_W = osp.join(section_dir, '.'.join(['W', timestamp, 'tsv'])) log.info("saving proportions for section {} to {}".format(sectiontag[s], oname_W)) - utils.write_file(wlist[s],oname_W) + utils.write_file(wlist[s], oname_W) diff --git a/stsc/utils.py b/stsc/utils.py index b90d117..5a15b6f 100755 --- a/stsc/utils.py +++ b/stsc/utils.py @@ -6,29 +6,25 @@ import datetime import os.path as osp from typing import NoReturn, List, Tuple, Union, Collection - import numpy as np import pandas as pd import matplotlib.pyplot as plt - import torch as t from torch.utils.data import DataLoader - import stsc.datasets as D import stsc.models as M - import anndata as ad import scipy.sparse as sparse def generate_identifier(): """Generate unique date and time based identifier""" - return re.sub(' |:','',str(datetime.datetime.today())) + return re.sub(' |:', '', str(datetime.datetime.today())) -def make_joint_matrix(pths : List[str], - transpose : bool = False, - )->pd.DataFrame: +def make_joint_matrix(pths: List[str], + transpose: bool = False + ) -> pd.DataFrame: 
"""Generate joint count matrix Generates a joint count matrix from multiple @@ -61,16 +57,14 @@ def make_joint_matrix(pths : List[str], genes = pd.Index([]) # Iterate over all provided paths - for k,pth in enumerate(pths): + for k, pth in enumerate(pths): # read file cnt = read_file(pth) if transpose: cnt = cnt.T mlist.append(cnt) # add file identifier k&- to rownames - index = index.append(pd.Index([str(k) + '&-' + str(x) for \ - x in cnt.index ] )) - + index = index.append(pd.Index([str(k) + '&-' + str(x) for x in cnt.index])) # get union of all observed genes genes = genes.union(cnt.columns) # add length of matrix @@ -83,25 +77,25 @@ def make_joint_matrix(pths : List[str], # prepare joint matrix, rownames are numbers jmat = pd.DataFrame(np.zeros((start_pos[-1], genes.shape[0]) - ), - columns = genes, - ) + ), + columns=genes, + ) # construct joint matrix for k in range(len(start_pos) - 1): # set start and end pos based on # numeric rownames start = start_pos[k] - end = start_pos[k+1] - 1 - jmat.loc[start:end,mlist[k].columns] = mlist[k].values + end = start_pos[k + 1] - 1 + jmat.loc[start:end, mlist[k].columns] = mlist[k].values # set new indices jmat.index = index return jmat -def split_joint_matrix(jmat : pd.DataFrame, - ) -> List[pd.DataFrame]: +def split_joint_matrix(jmat: pd.DataFrame + ) -> List[pd.DataFrame]: """Split joint matrix Splits a joint matrix generated by @@ -120,15 +114,14 @@ def split_joint_matrix(jmat : pd.DataFrame, """ try: - idx, name = zip(*[ idx.split('&-') for \ - idx in jmat.index ]) + idx, name = zip(*[idx.split('&-') for idx in jmat.index]) except: print("_".join([f"Matrix provided is not", f"a joint matrix generated", f"by make_joint_matrix", - ] - ) - ) + ] + ) + ) # convert names to pandas index name = pd.Index(name) @@ -143,18 +136,17 @@ def split_joint_matrix(jmat : pd.DataFrame, # get indices with same identifiers sel = (idx == k) # select indices with same identifiers - tm = jmat.iloc[sel,:] + tm = jmat.iloc[sel, :] # set index to original indices - tm.index = pd.Index([x.replace('&-','_') for \ - x in name[sel].values]) + tm.index = pd.Index([x.replace('&-', '_') for x in name[sel].values]) # append single matrix to list matlist.append(tm) return matlist -def Logger(logname : str , - )->logging.Logger: +def Logger(logname: str + ) -> logging.Logger: """Logger for steroscope run Parameter: @@ -183,6 +175,7 @@ def Logger(logname : str , # streamhandler for stdout output ch = logging.StreamHandler() ch.setLevel(log_level) + # format logger display message formatstr = '[%(asctime)s - %(name)s - %(levelname)s ] >> %(message)s' formatter = logging.Formatter(formatstr) @@ -194,6 +187,7 @@ def Logger(logname : str , return log + class SimpleProgressBar: """ Progress bar to display progress during estimation @@ -213,11 +207,11 @@ class SimpleProgressBar: """ def __init__(self, - max_value : int, - length : int = 20, - symbol : str = "=", - silent_mode : bool = False, - )->None: + max_value: int, + length: int = 20, + symbol: str = "=", + silent_mode: bool = False + ) -> None: self.symbol = symbol self.mx = max_value @@ -233,9 +227,9 @@ def __init__(self, self.call_func = self._verbose def _verbose(self, - epoch : int, - value : float, - ) -> NoReturn: + epoch: int, + value: float + ) -> NoReturn: """Updates progressbar @@ -248,9 +242,9 @@ def _verbose(self, """ - progress = self.symbol*int((epoch / self.delta)) + progress = self.symbol * int((epoch / self.delta)) print(f"\r" - f"Epoch : {epoch +1:<{self.ndigits}}/{self.mx:<{self.ndigits}}" + f"Epoch : {epoch 
+ 1:<{self.ndigits}}/{self.mx:<{self.ndigits}}" f" | Loss : {value:9E}" f" | \x1b[1;37m[" f" \x1b[0;36m{progress:<{self.len}}" @@ -261,18 +255,17 @@ def _verbose(self, def _silent(self, *args, **kwargs, - ) -> NoReturn: + ) -> NoReturn: pass - - def __call__(self, - epoch : int, - value : float, + epoch: int, + value: float, ) -> NoReturn: self.call_func(epoch, value) + class LossTracker: """Keep track of loss @@ -293,16 +286,14 @@ class LossTracker: """ def __init__(self, - opth : str, - interval : int = 100, - )->None: - + opth: str, + interval: int = 100 + ) -> None: self.history = [] self.interval = interval + 1 self.opth = opth - def write_history(self, - )->None: + def write_history(self, ) -> None: """Write loss history to file Will generate a file where each @@ -313,19 +304,16 @@ def write_history(self, """ # use comma separation - with open(self.opth,"a") as fopen: - fopen.writelines(',' + ','.join([str(x) for \ - x in self.history] - ) - ) + with open(self.opth, "a") as fopen: + fopen.writelines(',' + ','.join([str(x) for x in self.history])) # erase loss history once written self.history = [] def __call__(self, - loss : float, - epoch : int, - )-> None: + loss: float, + epoch: int + ) -> None: """Store and write loss history Paramers: @@ -339,46 +327,46 @@ def __call__(self, self.history.append(loss) - if (epoch % self.interval == 0 and \ - epoch >= self.interval): + if (epoch % self.interval == 0 and + epoch >= self.interval): self.write_history() - def __len__(self,): + def __len__(self, ): """length of loss""" return len(self.history) - def current(self,): + + def current(self, ): """current loss value""" return self.history[-1] -def get_extenstion( pth : str) -> str: +def get_extenstion(pth: str + ) -> str: """ get filetype extension""" return osp.splitext(pth)[1][1::] -def grab_anndata_counts(x : ad.AnnData, - ) -> np.ndarray: +def grab_anndata_counts(x: ad.AnnData + ) -> np.ndarray: """Get dense count data from anndata""" - if isinstance(x.X,np.ndarray): + if isinstance(x.X, np.ndarray): return x.X - elif isinstance(x.X,sparse.spmatrix): + elif isinstance(x.X, sparse.spmatrix): return x.X.toarray() elif isinstance(x.X, pd.DataFrame): return x.X.values - else: - print("[ERROR] : unsupported"\ + print("[ERROR] : unsupported" \ " data format : {}. 
Exiting.".format(type(x.X)), ) sys.exit(-1) -def read_h5ad_sc(cnt_pth : str, - lbl_colname : str = None, - lbl_pth : str = None, - )-> Tuple[pd.DataFrame, - pd.Series]: - + +def read_h5ad_sc(cnt_pth: str, + lbl_colname: str = None, + lbl_pth: str = None + ) -> Tuple[pd.DataFrame, pd.Series]: """read single cell data from h5ad Parameters: @@ -402,33 +390,30 @@ def read_h5ad_sc(cnt_pth : str, _data = ad.read_h5ad(cnt_pth) - _,uni_idx = np.unique(_data.var.index, - return_index = True) - _data = _data[:,uni_idx] + _, uni_idx = np.unique(_data.var.index, return_index=True) + _data = _data[:, uni_idx] if lbl_colname is None: lbl_colname = 0 if lbl_pth is None: - _lbl = _data.obs[[lbl_colname]]\ - .astype(str) + _lbl = _data.obs[[lbl_colname]].astype(str) else: _lbl = read_file(lbl_pth) _lbl = _lbl[lbl_colname] - _data = pd.DataFrame(grab_anndata_counts(_data), - index = _data.obs_names, - columns = _data.var_names, + index=_data.obs_names, + columns=_data.var_names ) _lbl.index = _data.index - return _data,_lbl + return _data, _lbl -def read_h5ad_st(cnt_pth : List[str], - )-> pd.DataFrame: +def read_h5ad_st(cnt_pth: List[str] + ) -> pd.DataFrame: """read spatial data from h5ad Parameters: @@ -450,38 +435,29 @@ def read_h5ad_st(cnt_pth : List[str], """ _cnts = list() - for k,p in enumerate( cnt_pth ): + for k, p in enumerate(cnt_pth): _data = ad.read_h5ad(p) - _,uni_idx = np.unique(_data.var.index, - return_index = True) - _data = _data[:,uni_idx] - - + _, uni_idx = np.unique(_data.var.index, return_index=True) + _data = _data[:, uni_idx] if "x" in _data.obs.keys(): - new_idx = [str(k) + "&-" + str(x)+"x"+str(y) for\ - x,y in zip(_data.obs["x"].values, - _data.obs['y'].values, - )] - + new_idx = [str(k) + "&-" + str(x) + "x" + str(y) for \ + x, y in zip(_data.obs["x"].values, + _data.obs['y'].values, + )] elif "spatial" in _data.obsm.keys(): - - new_idx = [str(k) + "&-" + str(x)+"x"+str(y) for\ - x,y in _data.obsm["spatial"]] + new_idx = [str(k) + "&-" + str(x) + "x" + str(y) for x, y in _data.obsm["spatial"]] else: - new_idx = [str(k) + "&-" + str( x ) for\ - x in _data.obs_names ] + new_idx = [str(k) + "&-" + str(x) for x in _data.obs_names] new_idx = pd.Index(new_idx) _data = pd.DataFrame(grab_anndata_counts(_data), - index = new_idx, - columns = _data.var_names, - ) + index=new_idx, + columns=_data.var_names + ) _cnts.append(_data) - cnts = pd.concat(_cnts, - join = "outer") - - del _cnts,_data + cnts = pd.concat(_cnts, join="outer") + del _cnts, _data cnts[pd.isna(cnts)] = 0.0 cnts = cnts.astype(float) @@ -489,9 +465,9 @@ def read_h5ad_st(cnt_pth : List[str], return cnts -def read_file(file_name : str, - extension : str = None, - )-> pd.DataFrame : +def read_file(file_name: str, + extension: str = None + ) -> pd.DataFrame: """Read file Control if file extension is supported @@ -511,16 +487,16 @@ def read_file(file_name : str, if extension is None: extension = get_extenstion(file_name) - supported = ['tsv','gz',] + supported = ['tsv', 'gz'] if extension not in supported: print(' '.join([f"ERROR: File format {extension}", f"is not yet supported. 
Please", f"use any of {' '.join(supported)}", - f"formats instead", - ] - ) - ) + f"formats instead" + ] + ) + ) sys.exit(-1) @@ -528,24 +504,24 @@ def read_file(file_name : str, try: compression = ('infer' if extension == 'tsv' else 'gzip') file = pd.read_csv(file_name, - header = 0, - index_col = 0, - compression = compression, - sep = '\t') + header=0, + index_col=0, + compression=compression, + sep='\t') return file except: print(' '.join([f"Something went wrong", f"when trying to read", - f"file >> {file_name}", - ], - ) - ) + f"file >> {file_name}" + ] + ) + ) -def write_file(file : pd.DataFrame, - opth : str, - )-> None: +def write_file(file: pd.DataFrame, + opth: str, + ) -> None: """Write file Parameter: @@ -559,24 +535,23 @@ def write_file(file : pd.DataFrame, try: file.to_csv(opth, - index = True, - header = True, - sep = '\t') + index=True, + header=True, + sep='\t') except: print(' '.join([f"An error occured", f"while trying to write", f"file >> {opth}", - ], - ) - ) - + ] + ) + ) -def subsample_data(cnt : pd.DataFrame, - lbl : pd.DataFrame, - lower_bound : int, - upper_bound : int, - )-> Tuple[pd.DataFrame,pd.DataFrame]: +def subsample_data(cnt: pd.DataFrame, + lbl: pd.DataFrame, + lower_bound: int, + upper_bound: int + ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Subsample single cell data Subsamples the single cell data w.r.t. @@ -613,17 +588,13 @@ def subsample_data(cnt : pd.DataFrame, np.random.seed(1337) - - upper_bound = (upper_bound if upper_bound is\ - not None else np.inf) - lower_bound = (lower_bound if lower_bound is\ - not None else -np.inf) - - assert upper_bound > lower_bound,\ - "upper bound must be larger than"\ + upper_bound = (upper_bound if upper_bound is not None else np.inf) + lower_bound = (lower_bound if lower_bound is not None else -np.inf) + assert upper_bound > lower_bound, \ + "upper bound must be larger than" \ "lower bound" - uni_types = np.unique(lbl) + uni_types = np.unique(lbl) idxs = np.array([]) for ct in uni_types: @@ -634,14 +605,14 @@ def subsample_data(cnt : pd.DataFrame, continue elif n_members > upper_bound: member_idx = np.random.choice(np.where(is_member)[0], - size = upper_bound, - replace = False, + size=upper_bound, + replace=False ) - idxs = np.append(idxs,member_idx) + idxs = np.append(idxs, member_idx) else: - idxs = np.append(idxs,np.where(is_member)[0]) + idxs = np.append(idxs, np.where(is_member)[0]) - cnt = cnt.iloc[idxs,:] + cnt = cnt.iloc[idxs, :] lbl = lbl.iloc[idxs] - return (cnt,lbl) + return cnt, lbl From 0a64db45291c3a9b72abdf13183614a10f3dac40 Mon Sep 17 00:00:00 2001 From: jfnavarro Date: Thu, 29 Apr 2021 09:36:53 +0200 Subject: [PATCH 2/2] Some more reformatting --- setup.py | 50 +++--- stsc/fit.py | 10 +- stsc/look.py | 413 +++++++++++++++++++++++------------------------ stsc/models.py | 2 +- stsc/parser.py | 377 +++++++++++++++++++++--------------------- stsc/progress.py | 8 +- stsc/run.py | 6 +- stsc/test.py | 3 +- stsc/utils.py | 36 ++--- 9 files changed, 427 insertions(+), 478 deletions(-) diff --git a/setup.py b/setup.py index f7cfc92..29cd063 100755 --- a/setup.py +++ b/setup.py @@ -5,29 +5,27 @@ import sys setup(name='stereoscope', - version='0.3.1', - description='Integration of ST and SC data', - author='Alma Andersson', - author_email='alma.andersson@scilifelab.se', - url='http://github.com/almaan/stereoscope', - download_url='https://github.com/almaan/stereoscope/archive/v_03.tar.gz', - license='MIT', - packages=['stsc'], - python_requires='>3.5.6', - install_requires=[ - 'torch>=1.1.0', - 
'numba>=0.49.0', - 'numpy>=1.14.0', - 'pandas>=0.25.0', - 'matplotlib>=3.1.0', - 'scikit-learn>=0.20.0', - 'umap-learn>=0.4.1', - 'anndata', - 'scipy', - 'Pillow', - ], - entry_points={'console_scripts': ['stereoscope = stsc.__main__:main', - ] - }, - zip_safe=False) - + version='0.3.2', + description='Integration of ST and SC data', + author='Alma Andersson', + author_email='alma.andersson@scilifelab.se', + url='http://github.com/almaan/stereoscope', + download_url='https://github.com/almaan/stereoscope/archive/v_03.tar.gz', + license='MIT', + packages=['stsc'], + python_requires='>3.5.6', + install_requires=[ + 'torch>=1.1.0', + 'numba>=0.49.0', + 'numpy>=1.14.0', + 'pandas>=0.25.0', + 'matplotlib>=3.1.0', + 'scikit-learn>=0.20.0', + 'umap-learn>=0.4.1', + 'anndata', + 'scipy', + 'Pillow', + ], + entry_points={'console_scripts': ['stereoscope = stsc.__main__:main']}, + zip_safe=False + ) diff --git a/stsc/fit.py b/stsc/fit.py index 8c8b0a4..2a2a6a9 100755 --- a/stsc/fit.py +++ b/stsc/fit.py @@ -112,11 +112,7 @@ def fit(model: Union[M.ScModel, M.STModel], loss_tracker.write_history() except KeyboardInterrupt: - print(' '.join(["\n\nPress Ctrl+C again", - "to interrupt whole process", - ] - ) - ) + print('\n\nPress Ctrl+C again to interrupt whole process') def fit_st_data(st_data: D.CountData, @@ -200,9 +196,7 @@ def fit_st_data(st_data: D.CountData, try: st_model.load_state_dict(t.load(st_from_model)) except: - print(' '.join(["Could not load state", - "dict from >> {st_from_model}"], - ), + print('Could not load state dict from >> {}'.format(st_from_model), file=sys.stderr ) diff --git a/stsc/look.py b/stsc/look.py index 29b6d43..f2267e4 100755 --- a/stsc/look.py +++ b/stsc/look.py @@ -5,14 +5,6 @@ import matplotlib import matplotlib.pyplot as plt import matplotlib.patches as mpatches - -plt.rcParams.update({ - "figure.max_open_warning" : 200, - "font.size" : 15, - "font.family": "calibri", # use serif/main font for text elements -}) - - import sys import re import os @@ -20,34 +12,32 @@ import argparse as arp import warnings from scipy import interpolate - from sklearn.manifold import TSNE from sklearn.decomposition import PCA - from PIL.JpegImagePlugin import JpegImageFile import PIL.Image as Image - - import umap from numba.core.errors import NumbaDeprecationWarning, NumbaWarning +from stsc.utils import make_joint_matrix, split_joint_matrix + +plt.rcParams.update({ + "figure.max_open_warning": 200, + "font.size": 15, + "font.family": "calibri", # use serif/main font for text elements +}) warnings.simplefilter('ignore', category=NumbaDeprecationWarning) warnings.simplefilter('ignore', category=NumbaWarning) -from stsc.utils import make_joint_matrix, split_joint_matrix - -#%% Funtions ------------------------------- -def spltstr(string,size = 20): - rxseps = [' ','-','\\.','_'] +def spltstr(string, size=20): + rxseps = [' ', '-', '\\.', '_'] if len(string) > size: match = re.search('|'.join(rxseps), string[size::]) if match: pos = size + match.start() - strout = spltstr(string[0:pos]) + \ - '\n' + \ - spltstr(string[pos+1::]) + strout = spltstr(string[0:pos]) + '\n' + spltstr(string[pos + 1::]) return strout else: return string @@ -61,27 +51,29 @@ def pd2np(func): to numpy array. 
""" - def wrapper(*args,**kwargs): + + def wrapper(*args, **kwargs): cargs = list() ckwargs = dict() for arg in args: - if isinstance(arg,pd.DataFrame): + if isinstance(arg, pd.DataFrame): cargs.append(arg.values) else: cargs.append(arg) if len(kwargs) > 0: - for k,v in kwargs.items(): - if isinstance(v,pd.DataFrame): - ckwargs.update({k:v.values}) + for k, v in kwargs.items(): + if isinstance(v, pd.DataFrame): + ckwargs.update({k: v.values}) else: - ckwargs.update({k:v}) + ckwargs.update({k: v}) - return func(*cargs,**ckwargs) + return func(*cargs, **ckwargs) return wrapper -def rotation(v,theta): + +def rotation(v, theta): """Rotation in 3D Rotates a vector in 3D around axis (1,1,1)^T. @@ -92,19 +84,20 @@ def rotation(v,theta): visualization. """ - k = np.ones((3,1)) / np.sqrt(3) - K = np.cross(np.eye(3),k.T) - vrot = v + (np.sin(theta)*np.dot(K,v)) + \ - (1-np.cos(theta))*np.dot(np.dot(K,K),v) + k = np.ones((3, 1)) / np.sqrt(3) + K = np.cross(np.eye(3), k.T) + vrot = v + (np.sin(theta) * np.dot(K, v)) + \ + (1 - np.cos(theta)) * np.dot(np.dot(K, K), v) return vrot + @pd2np -def relfreq(x, ax = 1): +def relfreq(x, ax=1): xs = x.sum(axis=ax) xs[xs == 0] = np.nan - ns = ((-1,1) if ax == 1 else (1,-1)) - f =np.divide(x,xs.reshape(ns)) + ns = ((-1, 1) if ax == 1 else (1, -1)) + f = np.divide(x, xs.reshape(ns)) f[np.isnan(f)] = 0.0 return f @@ -114,29 +107,28 @@ def ax_prop(ta1, x, y, pp, - ms = 80, - ec = "black", - cm = plt.cm.Blues, - mx = [35,35], - mn = [0,0], - alpha = 1, - vmin = 0, - vmax = 1, - threshold = None, + ms=80, + ec="black", + cm=plt.cm.Blues, + mx=[35, 35], + mn=[0, 0], + alpha=1, + vmin=0, + vmax=1, + threshold=None, ): - ta1.set_aspect('equal') ta1.set_xticks([]) ta1.set_yticks([]) - ta1.set_xlim([mn[0]-1,mx[0] +1]) - ta1.set_ylim([mn[1]-1,mx[1]+1]) + ta1.set_xlim([mn[0] - 1, mx[0] + 1]) + ta1.set_ylim([mn[1] - 1, mx[1] + 1]) ta1.autoscale(False) - plot_prop = dict(edgecolors = ec, - s = ms, - vmin = vmin, - vmax = vmax, - ) + plot_prop = dict(edgecolors=ec, + s=ms, + vmin=vmin, + vmax=vmax, + ) if threshold is not None: sup_thrs = pp >= threshold @@ -144,33 +136,33 @@ def ax_prop(ta1, else: sub_thrs = np.array([False]) sup_thrs = np.ones(x.shape[0], - dtype = np.bool) + dtype=np.bool) if sup_thrs.sum() > 0: - ta1.scatter(x = x[sup_thrs], - y = y[sup_thrs], - c = pp[sup_thrs], - cmap = cm, - alpha = alpha, + ta1.scatter(x=x[sup_thrs], + y=y[sup_thrs], + c=pp[sup_thrs], + cmap=cm, + alpha=alpha, **plot_prop, ) if sub_thrs.sum() > 0: - gc = np.zeros((sub_thrs.sum(),4)) - gc[:,3] = pp[sub_thrs] - ta1.scatter(x = x[sub_thrs], - y = y[sub_thrs], - c = gc, + gc = np.zeros((sub_thrs.sum(), 4)) + gc[:, 3] = pp[sub_thrs] + ta1.scatter(x=x[sub_thrs], + y=y[sub_thrs], + c=gc, **plot_prop, ) for v in ta1.axes.spines.values(): - v.set_edgecolor('none') + v.set_edgecolor('none') return ta1 def hide_spines(ax_obj): - if not isinstance(ax_obj,np.ndarray): + if not isinstance(ax_obj, np.ndarray): ax = np.array(ax_obj) else: ax = ax_obj.flatten() @@ -180,27 +172,28 @@ def hide_spines(ax_obj): p.set_xticks([]) p.set_yticks([]) for v in p.axes.spines.values(): - v.set_edgecolor('none') + v.set_edgecolor('none') return ax -def compress(x,method = 'pca'): + +def compress(x, method='pca'): if method.lower() == 'tsne': - dimred = TSNE(n_components = 3, - perplexity = 20, - n_iter = 5000, - learning_rate=10, - n_iter_without_progress= 200, - ) + dimred = TSNE(n_components=3, + perplexity=20, + n_iter=5000, + learning_rate=10, + n_iter_without_progress=200, + ) reduced = 
dimred.fit_transform(x.values) - elif method.lower() =="umap": - dimred = umap.UMAP(n_neighbors = 25, - n_components = 3) + elif method.lower() == "umap": + dimred = umap.UMAP(n_neighbors=25, + n_components=3) reduced = dimred.fit_transform(x.values) else: - dimred = PCA(n_components = 3) + dimred = PCA(n_components=3) reduced = dimred.fit_transform(x.values) reduced = rgb_transform(reduced) @@ -214,8 +207,7 @@ def ax_compressed(ta3, y, v, hexagonal=False, - marker_size =10): - + marker_size=10): if not hexagonal: xx = x.round(0).astype(int) yy = y.round(0).astype(int) @@ -223,30 +215,29 @@ def ax_compressed(ta3, minX, minY = xx.min(), yy.min() xx = xx - minX yy = yy - minY - maxX, maxY = np.max(xx),np.max(yy) + maxX, maxY = np.max(xx), np.max(yy) - z = np.ones((maxX + 2 ,maxY + 2,3)) + z = np.ones((maxX + 2, maxY + 2, 3)) for ii in range(xx.shape[0]): - z[xx[ii]+1,yy[ii]+1] = v[ii,:] + z[xx[ii] + 1, yy[ii] + 1] = v[ii, :] - ta3.imshow(np.transpose(z,axes=(1,0,2)), - interpolation = 'nearest', - origin = 'lower', - aspect = 'equal') + ta3.imshow(np.transpose(z, axes=(1, 0, 2)), + interpolation='nearest', + origin='lower', + aspect='equal') ta3.grid(False) - else : + else: ta3.scatter(x, y, - c = v, - s = marker_size, - edgecolor = "none", + c=v, + s=marker_size, + edgecolor="none", ) - ta3.set_xticks([]) ta3.set_yticks([]) @@ -255,14 +246,14 @@ def ax_compressed(ta3, return ta3 + def ax_hard(fig, ax, x, y, pp, - marker_size = 10, + marker_size=10, ): - n_types = pp.shape[1] if n_types <= plt.cm.Dark2.N: cmap = plt.cm.Dark2 @@ -271,14 +262,14 @@ def ax_hard(fig, else: cmap = plt.cm.rainbow - max_type = np.argmax(pp.values,axis=1) - color = cmap(max_type /n_types ) + max_type = np.argmax(pp.values, axis=1) + color = cmap(max_type / n_types) ax[0].scatter(x, - y, - c = color, - s = marker_size, - ) + y, + c=color, + s=marker_size, + ) for ii in range(2): ax[ii].set_aspect("equal") @@ -290,71 +281,72 @@ def ax_hard(fig, patches = list() for c in range(n_types): - patches.append(mpatches.Patch(color=cmap(c/n_types), + patches.append(mpatches.Patch(color=cmap(c / n_types), label=pp.columns.values[c]) ) ax[1].legend(handles=patches) - return fig,ax + return fig, ax @pd2np def rgb_transform(y): - eps = 10e-12 - mn = y.min() mx = y.max() - - nm = (y - mn + eps ) / (mx - mn + eps) + nm = (y - mn + eps) / (mx - mn + eps) return nm + def read_file(pth): f = pd.read_csv(pth, - sep = '\t', - header = 0, - index_col = 0) + sep='\t', + header=0, + index_col=0) return f -def get_crd(w,as_he = False): - crd = [x.replace('X','').split('x') for x in w.index.values] +def get_crd(w, as_he=False): + crd = [x.replace('X', '').split('x') for x in w.index.values] crd = np.array(crd).astype(float) if as_he: - crd = crd[:,[1,0]] - crd[:,1] = crd[:,1].max() - crd[:,1] + crd[:,1].min() + crd = crd[:, [1, 0]] + crd[:, 1] = crd[:, 1].max() - crd[:, 1] + crd[:, 1].min() return crd -def resize_by_factor(w,h,f): - wn = int(np.round(w/f)) - hn = int(np.round(h/f)) - return (wn,hn) -def map1d2d(s,n_cols): +def resize_by_factor(w, h, f): + wn = int(np.round(w / f)) + hn = int(np.round(h / f)) + return wn, hn + + +def map1d2d(s, n_cols): j = s % n_cols i = (s - j) / n_cols - return int(i),int(j) + return int(i), int(j) -def crd2array(rgb,crd,w,h,fill= np.nan): - xx = crd[:,0] - yy = crd[:,1] + +def crd2array(rgb, crd, w, h, fill=np.nan): + xx = crd[:, 0] + yy = crd[:, 1] nx = np.arange(h) ny = np.arange(w) - nx, ny = np.meshgrid(nx,ny) - arr = interpolate.griddata((xx,yy), - values = rgb, - xi = (nx,ny), - method = 'cubic', - 
fill_value = fill + nx, ny = np.meshgrid(nx, ny) + arr = interpolate.griddata((xx, yy), + values=rgb, + xi=(nx, ny), + method='cubic', + fill_value=fill ) return arr -def look(args,): +def look(args, ): get_id = lambda x: '.'.join(osp.basename(x).split('.')[0:-1]) tag = "stsc_viz" @@ -366,107 +358,105 @@ def look(args,): cmap = plt.cm.Blues proppaths = args.proportions_path - if not isinstance(proppaths,list): + if not isinstance(proppaths, list): proppaths = [proppaths] if args.output: odirs = [args.output for x in range(len(proppaths))] else: - odirs = [osp.join(osp.dirname(x),tag) for x in proppaths] - + odirs = [osp.join(osp.dirname(x), tag) for x in proppaths] basenames = [get_id(x) for x in proppaths] snames = [osp.basename(osp.dirname(osp.abspath(pp))) for pp in proppaths] - sortsynonyms = dict(section = 'section', - s = 'section', - i = "internal", - internal = "internal", - celltype = 'ct', + sortsynonyms = dict(section='section', + s='section', + i="internal", + internal="internal", + celltype='ct', ct='ct', - type = 'ct') + type='ct') sort_by = sortsynonyms[args.sort_by] scale_by = sortsynonyms[args.scale_by] allwmat = make_joint_matrix(proppaths) allwmat[allwmat < 0] = 0.0 - allwmat.loc[:,:] = relfreq(allwmat) + allwmat.loc[:, :] = relfreq(allwmat) wlist = split_joint_matrix(allwmat) - crdlist = [get_crd(w,as_he = args.image_orientation) for w in wlist] + crdlist = [get_crd(w, as_he=args.image_orientation) for w in wlist] celltypes = allwmat.columns.tolist() n_sections = len(proppaths) n_celltypes = allwmat.shape[1] n_cols = args.n_cols - # Visualize Cell Type Distribution --------- if sort_by == 'ct': - n_rows = np.max((np.ceil(n_sections / n_cols).astype(int),1)) + n_rows = np.max((np.ceil(n_sections / n_cols).astype(int), 1)) titles = snames outer = n_celltypes inner = n_sections - fignames = [osp.join(odirs[0],''.join([celltypes[x],'.png']))\ + fignames = [osp.join(odirs[0], ''.join([celltypes[x], '.png'])) \ for x in range(n_celltypes)] suptitles = celltypes else: - n_rows = np.max((np.ceil(n_celltypes / n_cols).astype(int),1)) + n_rows = np.max((np.ceil(n_celltypes / n_cols).astype(int), 1)) titles = celltypes outer = n_sections inner = n_celltypes - fignames = [osp.join(odirs[x],''.join([snames[x],'.png'])) \ + fignames = [osp.join(odirs[x], ''.join([snames[x], '.png'])) \ for x in range(n_sections)] suptitles = snames - mxcrd = [np.max(x,axis = 0) for x in crdlist] - mncrd = [np.min(x,axis = 0) for x in crdlist] + mxcrd = [np.max(x, axis=0) for x in crdlist] + mncrd = [np.min(x, axis=0) for x in crdlist] - figsize = ((n_cols + 1) * args.side_size, (n_rows +1 ) * args.side_size + 5) + figsize = ((n_cols + 1) * args.side_size, (n_rows + 1) * args.side_size + 5) for outside in range(outer): - fig,ax = plt.subplots(n_rows,n_cols,figsize = figsize,squeeze = False) - if not isinstance(ax,np.ndarray): + fig, ax = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False) + if not isinstance(ax, np.ndarray): ax = np.array(ax) else: ax = ax.flatten() for inside in range(inner): - section_id = (inside if sort_by == 'ct' else outside ) + section_id = (inside if sort_by == 'ct' else outside) celltype_id = (outside if sort_by == 'ct' else inside) alpha = 0.00 - vmin = (np.quantile(wlist[section_id].iloc[:,celltype_id],alpha) if \ - args.scale_by == 'i' else np.quantile(wlist[section_id].values,alpha)) + vmin = (np.quantile(wlist[section_id].iloc[:, celltype_id], alpha) if \ + args.scale_by == 'i' else np.quantile(wlist[section_id].values, alpha)) - vmax = 
(np.quantile(wlist[section_id].iloc[:,celltype_id],1-alpha) if \ - args.scale_by == 'i' else np.quantile(wlist[section_id].values,1-alpha)) + vmax = (np.quantile(wlist[section_id].iloc[:, celltype_id], 1 - alpha) if \ + args.scale_by == 'i' else np.quantile(wlist[section_id].values, 1 - alpha)) if args.alpha is not None: if args.alpha_vector: - alpha_vec = wlist[section_id].iloc[:,celltype_id] * args.alpha + alpha_vec = wlist[section_id].iloc[:, celltype_id] * args.alpha else: alpha_vec = args.alpha * np.ones(crdlist[section_id].shape[0]) ax_prop(ax[inside], - crdlist[section_id][:,0], - crdlist[section_id][:,1], - pp = wlist[section_id].iloc[:,celltype_id], - mn = mncrd[section_id], - mx = mxcrd[section_id], - ms = args.marker_size, - cm = cmap, - ec = args.edgecolor, - alpha = args.alpha, - vmin = vmin, - vmax = vmax, - threshold = args.threshold, + crdlist[section_id][:, 0], + crdlist[section_id][:, 1], + pp=wlist[section_id].iloc[:, celltype_id], + mn=mncrd[section_id], + mx=mxcrd[section_id], + ms=args.marker_size, + cm=cmap, + ec=args.edgecolor, + alpha=args.alpha, + vmin=vmin, + vmax=vmax, + threshold=args.threshold, ) ax[inside].set_title(spltstr(titles[inside])) @@ -482,26 +472,26 @@ def look(args,): if args.hard_type: for s in range(n_sections): figpth = osp.join(odirs[s], - '.'.join([snames[s], - 'hard-type.' +\ - args.image_type])) + '.'.join([snames[s], + 'hard-type.' + \ + args.image_type])) if not osp.isdir(odirs[s]): os.mkdir(odirs[s]) - figsize = (args.side_size * 2,args.side_size) - fig,ax = plt.subplots(1,2,figsize = figsize) + figsize = (args.side_size * 2, args.side_size) + fig, ax = plt.subplots(1, 2, figsize=figsize) try: - fig,ax = ax_hard(fig, - ax, - crdlist[s][:,0], - crdlist[s][:,1], - wlist[s], - marker_size = args.marker_size, - ) + fig, ax = ax_hard(fig, + ax, + crdlist[s][:, 0], + crdlist[s][:, 1], + wlist[s], + marker_size=args.marker_size, + ) if args.flip_y: - ax.invert_yaxis() + ax.invert_yaxis() fig.savefig(figpth) except UserWarning: @@ -510,46 +500,44 @@ def look(args,): if args.compress_method is not None: cmpr = compress(allwmat, - method = args.compress_method) + method=args.compress_method) if args.hue_rotate > 0: theta = np.deg2rad(args.hue_rotate) - cmpr = rotation(cmpr.T,theta = theta).T + cmpr = rotation(cmpr.T, theta=theta).T cmpr = rgb_transform(cmpr) if args.shuffle_rgb: pos = np.arange(cmpr.shape[1]) np.random.shuffle(pos) - cmpr = cmpr[:,pos] + cmpr = cmpr[:, pos] - cmpr = pd.DataFrame(dict(r = cmpr[:,0], - g = cmpr[:,1], - b = cmpr[:,2], + cmpr = pd.DataFrame(dict(r=cmpr[:, 0], + g=cmpr[:, 1], + b=cmpr[:, 2], ), - index = allwmat.index, - ) + index=allwmat.index, + ) scmpr = split_joint_matrix(cmpr) ct_cols = n_cols ct_skip = 1 - if not args.gathered_compr or\ - len(args.proportions_path) < 2: + if not args.gathered_compr or \ + len(args.proportions_path) < 2: for s in range(n_sections): figpth = osp.join(odirs[s], - '.'.join([snames[s], - 'compressed.' +\ - args.image_type])) + '.'.join([snames[s], 'compressed.' 
+ args.image_type])) if not osp.isdir(odirs[s]): os.mkdir(odirs[s]) - figsize = (args.side_size,args.side_size) - fig,ax = plt.subplots(1,1,figsize = figsize) - ax_compressed(ax,crdlist[s][:,0],crdlist[s][:,1], + figsize = (args.side_size, args.side_size) + fig, ax = plt.subplots(1, 1, figsize=figsize) + ax_compressed(ax, crdlist[s][:, 0], crdlist[s][:, 1], scmpr[s], - hexagonal = args.hexagonal, - marker_size = args.marker_size, + hexagonal=args.hexagonal, + marker_size=args.marker_size, ) if args.flip_y: ax.invert_yaxis() @@ -559,31 +547,32 @@ def look(args,): else: n_cmpr_cols = args.n_cols n_cmpr_rows = np.ceil(n_sections / n_cmpr_cols).astype(int) - figpth = osp.join(odirs[0],'joint.compressed.' + args.image_type) - if not osp.exists(odirs[0]): os.mkdir(odirs[0]) + figpth = osp.join(odirs[0], 'joint.compressed.' + args.image_type) + if not osp.exists(odirs[0]): + os.mkdir(odirs[0]) - figsize = ((n_cmpr_cols + 1) * args.side_size, (n_cmpr_rows +1) * args.side_size + 1.0) + figsize = ((n_cmpr_cols + 1) * args.side_size, (n_cmpr_rows + 1) * args.side_size + 1.0) - fig, ax = plt.subplots(n_cmpr_rows,n_cmpr_cols, - figsize = figsize, - constrained_layout = True) + fig, ax = plt.subplots(n_cmpr_rows, n_cmpr_cols, + figsize=figsize, + constrained_layout=True) ax = ax.reshape(n_cmpr_rows, n_cmpr_cols) for s in range(n_sections): - r,c = map1d2d(s,n_cmpr_cols) - ax_compressed(ax[r,c], - crdlist[s][:,0], - crdlist[s][:,1], + r, c = map1d2d(s, n_cmpr_cols) + ax_compressed(ax[r, c], + crdlist[s][:, 0], + crdlist[s][:, 1], scmpr[s], hexagonal=args.hexagonal, - marker_size = args.marker_size, + marker_size=args.marker_size, ) - ax[r,c].set_title(spltstr(snames[s]), - fontsize = 10) + ax[r, c].set_title(spltstr(snames[s]), + fontsize=10) - if args.flip_y: ax[r,c].invert_yaxis() + if args.flip_y: ax[r, c].invert_yaxis() hide_spines(ax) fig.savefig(figpth) diff --git a/stsc/models.py b/stsc/models.py index 9a3d2e5..d8053f0 100755 --- a/stsc/models.py +++ b/stsc/models.py @@ -217,7 +217,7 @@ def forward(self, # account for gene specific bias and add noise self.Rhat = t.cat((t.mul(self.beta_trans(self.beta), self.R), self.eps), dim=1) - # combinde rates for all cell types + # combine rates for all cell types self.r = t.einsum('gz,zs->gs', [self.Rhat, self.v[:, self.gidx]]) # get loss for current parameters diff --git a/stsc/parser.py b/stsc/parser.py index ee8245f..2c1be96 100755 --- a/stsc/parser.py +++ b/stsc/parser.py @@ -4,22 +4,21 @@ def make_parser(): - parser = arp.ArgumentParser() - subparsers = parser.add_subparsers(dest = 'command') + subparsers = parser.add_subparsers(dest='command') run_parser = subparsers.add_parser("run", - formatter_class=arp\ + formatter_class=arp \ .ArgumentDefaultsHelpFormatter) look_parser = subparsers.add_parser("look", - formatter_class=arp\ + formatter_class=arp \ .ArgumentDefaultsHelpFormatter) test_parser = subparsers.add_parser("test") progress_parser = subparsers.add_parser('progress', - formatter_class=arp\ + formatter_class=arp \ .ArgumentDefaultsHelpFormatter) -# Run Parser Arguments --------------------------------------------- + # Run Parser Arguments --------------------------------------------- run_parser.add_argument('-scc', '--sc_cnt', required=False, @@ -204,7 +203,7 @@ def make_parser(): required=False, default=None, type=str, - help= 'path to list of genes to use in the analysis') + help='path to list of genes to use in the analysis') run_parser.add_argument('-sub', '--sc_upper_bound', required=False, @@ -222,195 +221,179 @@ def make_parser(): 
run_parser.add_argument('-fb', '--freeze_beta', default=False, action='store_true', - help= 'freeze beta parameter') - - -# Look Parser Arguments ----------------------------------------------- - - look_parser.add_argument("-pp","--proportions_path", - type = str, - nargs = '+', - required = True, - help = ''.join([f"Path to proportions", - f" file generated by", - f" st2sc. Named W*.tsv"]) - ) - - - look_parser.add_argument("-c","--compress_method", - type = str, - required = False, - default = None, - help = ''.join([f"method to be used", - f" for compression of", - f" information."]), - ) - - look_parser.add_argument("-ms","--marker_size", - type = int, - required = False, - default = 100, - help = ''.join([f"size of scatterplot", - f" markers." - ])) - - look_parser.add_argument("-o","--output", - type = str, - required = False, - default = '', - help = ''.join([f"Path to output", - f" can either be", - f" a directory or", - f" filename. If only", - f" dir is given, same", - f" basename os for pp", - f" is used."]) - ) - - look_parser.add_argument("-nc","--n_cols", - default = 2, - type = int, - required = False, - ) - - look_parser.add_argument("-y","--flip_y", - required = False, - default = False, - action = 'store_true', - ) - - look_parser.add_argument("-sb","--sort_by", - required = False, - default = 'ct', - type = str) - - look_parser.add_argument("-sc","--scale_by", - required = False, - default = 'ct', - type = str) - - - look_parser.add_argument("-gb","--gathered_compr", - required = False, - default = False, - action = 'store_true', - ) - - look_parser.add_argument("-sf","--scaling_factor", - required = False, - type = float, - default = 4.0, - help = ''.join([]), - ) - - look_parser.add_argument("-hu","--hue_rotate", - required = False, - type = float, - default = -1.0, - help = ''.join([]), - ) - - - look_parser.add_argument("-hex","--hexagonal", - required = False, - default = False, - action = 'store_true', - help = ''.join([]), - ) - - look_parser.add_argument("-ec","--edgecolor", - required =False, - type = str, - default = 'black', - help = "spot edgecolor", - ) - - look_parser.add_argument("-ext","--image_type", - required =False, - type = str, - default = 'png', - help = "image file type", - ) - - look_parser.add_argument("-al","--alpha", - required =False, - type = float, - default = 1, - help = "facecolor alpha", - ) - look_parser.add_argument("-av","--alpha_vector", - required =False, - default = False, - action = 'store_true', - help = "use value based alpha", - ) - - look_parser.add_argument("-cm","--colormap", - required =False, - default = "Blues", - type = str, - help = "name of matplotlib"\ - " colormap to use." 
- , - ) - - look_parser.add_argument("-thr","--threshold", - required =False, - default = None, - type = float, - help = "threshold value for"\ - "proportion visualization", - ) - - look_parser.add_argument("-ht","--hard_type", - required =False, - default = False, - action = "store_true", - help = "make hard type plot", - ) - - look_parser.add_argument("-io","--image_orientation", - required =False, - default = False, - action = "store_true", - help = "arrange capture locations"\ - " in same orientation as HE image", - ) - - - look_parser.add_argument("-ss","--side_size", - required =False, - default = 350, - type = float, - help = "subplot side size", - ) - - - look_parser.add_argument("-shu","--shuffle_rgb", - required = False, - default = False, - action = 'store_true', - help = ''.join(["Shuffle RGB colors", - " in the compressed", - " visualization", - ]), - ) - - progress_parser.add_argument("-lf",'--loss_file', - required = True, - help = ''.join(['path to loss', - 'data', - ] - ), - ) - - progress_parser.add_argument("-ws",'--window_size', - required = False, - default = 11, - help = ''.join(['window size for', - 'rolling average', - ] - ), - ) + help='freeze beta parameter') + + # Look Parser Arguments ----------------------------------------------- + + look_parser.add_argument('-pp', '--proportions_path', + type=str, + nargs='+', + required=True, + help='path to proportions ' + 'file generated by ' + 'stereoscope run. Named W*.tsv' + ) + # TODO add choices here + look_parser.add_argument('-c', '--compress_method', + type=str, + required=False, + default=None, + help='method to be used ' + 'for compression of ' + 'information' + ) + + look_parser.add_argument('-ms', '--marker_size', + type=int, + required=False, + default=100, + help='size of scatter plot markers' + ) + + look_parser.add_argument('-o', '--output', + type=str, + required=False, + default='', + help='path to output ' + 'can either be ' + 'a directory or a ' + 'filename. 
If only ' + 'a dir is given, same ' + 'basename the the plots is used' + ) + + look_parser.add_argument('-nc', '--n_cols', + default=2, + type=int, + required=False + ) + + look_parser.add_argument('-y', '--flip_y', + required=False, + default=False, + action='store_true' + ) + + look_parser.add_argument('-sb', '--sort_by', + required=False, + default='ct', + type=str + ) + + look_parser.add_argument('-sc', '--scale_by', + required=False, + default='ct', + type=str + ) + + look_parser.add_argument('-gb', '--gathered_compr', + required=False, + default=False, + action='store_true' + ) + + look_parser.add_argument('-sf', '--scaling_factor', + required=False, + type=float, + default=4.0 + ) + + look_parser.add_argument('-hu', '--hue_rotate', + required=False, + type=float, + default=-1.0 + ) + + look_parser.add_argument('-hex', '--hexagonal', + required=False, + default=False, + action='store_true' + ) + + look_parser.add_argument('-ec', '--edgecolor', + required=False, + type=str, + default='black', + help="spot edgecolor" + ) + + look_parser.add_argument('-ext', '--image_type', + required=False, + type=str, + default='png', + help='image file type' + ) + + look_parser.add_argument('-al', '--alpha', + required=False, + type=float, + default=1, + help='facecolor alpha', + ) + + look_parser.add_argument('-av', '--alpha_vector', + required=False, + default=False, + action='store_true', + help='use value based alpha', + ) + + look_parser.add_argument('-cm', '--colormap', + required=False, + default="Blues", + type=str, + help='name of matplotlib colormap to use' + ) + + look_parser.add_argument('-thr', '--threshold', + required=False, + default=None, + type=float, + help='threshold value for proportion visualization' + ) + + look_parser.add_argument('-ht', '--hard_type', + required=False, + default=False, + action="store_true", + help='make hard type plot' + ) + + look_parser.add_argument('-io', '--image_orientation', + required=False, + default=False, + action="store_true", + help='arrange capture locations ' + 'in same orientation as HE image' + ) + + look_parser.add_argument('-ss', '--side_size', + required=False, + default=350, + type=float, + help='subplot side size', + ) + + look_parser.add_argument('-shu', '--shuffle_rgb', + required=False, + default=False, + action='store_true', + help='shuffle RGB colors ' + 'in the compressed ' + 'visualization' + ) + + # Progress Parser Arguments ----------------------------------------------- + + progress_parser.add_argument('-lf', '--loss_file', + required=True, + help='path to loss data' + ) + + progress_parser.add_argument('-ws', '--window_size', + required=False, + default=11, + help='window size for rolling average' + ) return parser diff --git a/stsc/progress.py b/stsc/progress.py index d701553..a53c1bb 100755 --- a/stsc/progress.py +++ b/stsc/progress.py @@ -53,11 +53,7 @@ def get_loss_data(loss_file: str """ # exit if loss file does not exist if not osp.exists(loss_file): - print(' '.join([f"ERROR : the file {loss_file}", - "does not exist", - ] - ) - ) + print('ERROR : the file {} does not exist'.format(loss_file)) sys.exit(-1) # read loss files @@ -71,7 +67,7 @@ def get_loss_data(loss_file: str # generate epoch values epoch = np.arange(1, loss_history.shape[0] + 1) - return (epoch, loss_history) + return epoch, loss_history def progress(loss_file: str, diff --git a/stsc/run.py b/stsc/run.py index 069230f..c67fea7 100755 --- a/stsc/run.py +++ b/stsc/run.py @@ -79,8 +79,8 @@ def run(prs: arp.ArgumentParser, # control that paths to sc data 
exists if not all([osp.exists(args.sc_cnt)]): - log.error(' '.join(["One or more of the specified paths to", - "the sc data does not exist"])) + log.error('One or more of the specified paths to ' + 'the sc data does not exist') sys.exit(-1) # load pre-fitted model if provided @@ -153,7 +153,7 @@ def run(prs: arp.ArgumentParser, R = utils.read_file(args.sc_fit[0]) logits = utils.read_file(args.sc_fit[1]) - # If ST data is provided estiamte proportions + # If ST data is provided estimate proportions if args.st_cnt[0] is not None: # generate identifying tag for each section sectiontag = list(map(lambda x: '.'.join(osp.basename(x).split('.')[0:-1]), args.st_cnt)) diff --git a/stsc/test.py b/stsc/test.py index db1b224..c02e6a0 100755 --- a/stsc/test.py +++ b/stsc/test.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 + def test(): print("Successfully installed stereoscope CLI") - - diff --git a/stsc/utils.py b/stsc/utils.py index 5a15b6f..7b8b581 100755 --- a/stsc/utils.py +++ b/stsc/utils.py @@ -116,11 +116,9 @@ def split_joint_matrix(jmat: pd.DataFrame try: idx, name = zip(*[idx.split('&-') for idx in jmat.index]) except: - print("_".join([f"Matrix provided is not", - f"a joint matrix generated", - f"by make_joint_matrix", - ] - ) + print("Matrix provided is not " + "a joint matrix generated " + "by make_joint_matrix" ) # convert names to pandas index @@ -357,8 +355,8 @@ def grab_anndata_counts(x: ad.AnnData elif isinstance(x.X, pd.DataFrame): return x.X.values else: - print("[ERROR] : unsupported" \ - " data format : {}. Exiting.".format(type(x.X)), + print("[ERROR] : unsupported " + "data format : {}. Exiting.".format(type(x.X)), ) sys.exit(-1) @@ -490,12 +488,10 @@ def read_file(file_name: str, supported = ['tsv', 'gz'] if extension not in supported: - print(' '.join([f"ERROR: File format {extension}", - f"is not yet supported. Please", - f"use any of {' '.join(supported)}", - f"formats instead" - ] - ) + print('"ERROR: File format {} ' + 'is not yet supported. Please ' + 'use any of these {} ' + 'formats instead'.format(extension, ' '.join(supported)) ) sys.exit(-1) @@ -511,11 +507,8 @@ def read_file(file_name: str, return file except: - print(' '.join([f"Something went wrong", - f"when trying to read", - f"file >> {file_name}" - ] - ) + print('Something went wrong ' + 'when trying to read file >> {}'.format(file_name) ) @@ -539,11 +532,8 @@ def write_file(file: pd.DataFrame, header=True, sep='\t') except: - print(' '.join([f"An error occured", - f"while trying to write", - f"file >> {opth}", - ] - ) + print('An error occurred ' + 'while trying to write file >> {}'.format(opth) )
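
For reference, a minimal usage sketch of the I/O helpers reorganised in `stsc/utils.py` above. This is illustrative only and not part of the patch: the file paths are placeholders, and it assumes the patched `stsc.utils` module is importable.

```python
# Illustrative sketch only (paths are hypothetical); assumes the patched
# stsc.utils module is on the Python path.
from stsc.utils import read_file, write_file

# read_file accepts tab-separated .tsv or gzipped .gz matrices and expects
# a header row and an index column (spots/cells x genes).
cnt = read_file("st_counts.tsv")

# ... filter or transform the matrix as needed ...

# write_file saves the DataFrame back as tab-separated text with index and header.
write_file(cnt, "st_counts.filtered.tsv")
```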