Skip to content

Commit

Permalink
Model file parameterization
Browse files (browse the repository at this point in the history)
  • Loading branch information
asistradition committed Jul 3, 2023
1 parent 8be8622 commit fe8a653
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
5 changes: 4 additions & 1 deletion inferelator/utils/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -104,8 +104,11 @@ def load_data_h5ad(
gene_name_column
)

if use_layer is None:
use_layer = "X"

# Make sure layer is in the anndata object
if use_layer is not None and use_layer not in data.layers:
if use_layer != "X" and use_layer not in data.layers:
raise ValueError(
f"Layer {use_layer} is not in {h5ad_file}"
)
Expand Down
30 changes: 23 additions & 7 deletions inferelator/workflows/workflow_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,6 +34,7 @@

_VALID_FILE_TYPES = [_TSV, _H5AD, _HDF5, _MTX, _TENX]


class WorkflowBaseLoader(object):
"""
WorkflowBaseLoader is the class to load raw data.
Expand Down Expand Up @@ -251,7 +252,13 @@ def set_expression_file(
:type h5_layer: str, optional
"""

nones = [tsv is None, hdf5 is None, h5ad is None, tenx_path is None, mtx is None]
nones = [
tsv is None,
hdf5 is None,
h5ad is None,
tenx_path is None,
mtx is None
]

if all(nones):
Debug.vprint("No file provided", level=0)
Expand Down Expand Up @@ -570,9 +577,11 @@ def append_to_path(
"""
Add a string to an existing path variable
:param var_name: The name of the path variable (`input_dir` or `output_dir`)
:param var_name: The name of the path variable
(`input_dir` or `output_dir`)
:type var_name: str
:param to_append: The path to join to the end of the existing path variable
:param to_append: The path to join to the end of
the existing path variable
:type to_append: str
"""
Expand Down Expand Up @@ -756,7 +765,7 @@ def read_genes(self, file=None):
assert genes.shape[1] == 1
self.gene_names = genes.values.flatten().tolist()

# Use the gene names in the data file if no restrictive list is provided
# Use the gene names in the data file if no list is provided
if self.gene_names is None and self.data is not None:
self.gene_names = self.data.gene_names.copy()

Expand Down Expand Up @@ -1119,8 +1128,8 @@ def set_crossvalidation_parameters(

if not self.split_gold_standard_for_crossvalidation:
warnings.warn(
"The split_gold_standard_for_crossvalidation flag is not set. "
"Other options may be ignored"
"The split_gold_standard_for_crossvalidation "
"flag is not set. Other options may be ignored."
)

def set_shuffle_parameters(
Expand Down Expand Up @@ -1194,10 +1203,12 @@ def set_output_file_names(
confidence_file_name="",
nonzero_coefficient_file_name="",
pdf_curve_file_name="",
curve_data_file_name=""
curve_data_file_name="",
model_h5_file_name=""
):
"""
Set output file names. File names that end in '.gz' will be gzipped.
Set any file name to None to prevent it from being generated
:param network_file_name: Long-format network TSV file with
TF->Gene edge information.
Expand All @@ -1217,6 +1228,9 @@ def set_output_file_names(
:param curve_data_file_name: TSV file with the data used to plot
curves. Default is None (this file is not produced).
:type curve_data_file_name: str
:param model_h5_file_name: H5 file with model priors, coefficients,
and run parameters saved
:type model_h5_file_name: str
"""

if network_file_name != "":
Expand All @@ -1229,6 +1243,8 @@ def set_output_file_names(
IR.curve_file_name = pdf_curve_file_name
if curve_data_file_name != "":
IR.curve_data_file_name = curve_data_file_name
if model_h5_file_name != "":
IR.model_file_name = model_h5_file_name

def set_run_parameters(
self,
Expand Down

0 comments on commit fe8a653

Please sign in to comment.