diff --git a/src/grelu/resources/__init__.py b/src/grelu/resources/__init__.py index 362cc4d..aa1efd1 100644 --- a/src/grelu/resources/__init__.py +++ b/src/grelu/resources/__init__.py @@ -2,6 +2,7 @@ import importlib_resources from tempfile import TemporaryDirectory from pathlib import Path +from typing import Optional, List, Dict, Any import wandb from grelu.lightning import LightningModel @@ -10,16 +11,16 @@ DEFAULT_WANDB_HOST = 'https://api.wandb.ai' -def get_meme_file_path(meme_motif_db): +def get_meme_file_path(meme_motif_db: str) -> str: """ Return the path to a MEME file. Args: - meme_motif_db (str): Path to a MEME file or the name of a MEME file included with gReLU. + meme_motif_db: Path to a MEME file or the name of a MEME file included with gReLU. Current name options are "jaspar" and "consensus". Returns: - (str): Path to the specified MEME file. + Path to the specified MEME file. """ if meme_motif_db == "jaspar": meme_motif_db = ( @@ -41,13 +42,17 @@ def get_meme_file_path(meme_motif_db): raise Exception(f"{meme_motif_db} is not a valid file.") -def get_default_config_file(): - config = importlib_resources.files("grelu") / "resources" / "default_config.yaml" - assert config.exists() - return str(config) +def get_blacklist_file(genome: str) -> str: + """ + Return the path to a blacklist file + Args: + genome: Name of a genome whose blacklist file is included with gReLU. + Current name options are "hg19", "hg38" and "mm10". -def get_blacklist_file(genome): + Returns: + Path to the specified blacklist file. + """ blacklist = ( importlib_resources.files("grelu") / "resources" @@ -59,11 +64,26 @@ def get_blacklist_file(genome): return str(blacklist) -def _check_wandb(host=DEFAULT_WANDB_HOST): - assert wandb.login(host=host), f'Weights & Biases (wandb) is not configured, see {DEFAULT_WANDB_HOST}/authorize' +def _check_wandb(host:str=DEFAULT_WANDB_HOST) -> None: + """ + Check that the user is logged into Weights and Biases + + Args: + host: URL of the Weights & Biases host + """ + assert wandb.login(host=host, anonymous="allow"), f'Weights & Biases (wandb) is not configured, see {host}/authorize' + + +def projects(host: str=DEFAULT_WANDB_HOST) -> List[str]: + """ + List all projects in the model zoo + Args: + host: URL of the Weights & Biases host -def projects(host=DEFAULT_WANDB_HOST): + Returns: + List of project names + """ _check_wandb(host=host) api = wandb.Api() @@ -71,7 +91,19 @@ def projects(host=DEFAULT_WANDB_HOST): return [p.name for p in projects] -def artifacts(project, host=DEFAULT_WANDB_HOST, type_is=None, type_contains=None): +def artifacts(project: str, host: str=DEFAULT_WANDB_HOST, type_is: Optional[str]=None, type_contains: Optional[str]=None) -> List[str]: + """ + List all artifacts associated with a project in the model zoo + + Args: + project: Name of the project to search + host: URL of the Weights & Biases host + type_is: Return only artifacts with this type + type_contains: Return only artifacts whose type contains this string + + Returns: + List of artifact names + """ _check_wandb(host) project_path = f'{DEFAULT_WANDB_ENTITY}/{project}' @@ -90,15 +122,47 @@ def artifacts(project, host=DEFAULT_WANDB_HOST, type_is=None, type_contains=None return arts -def models(project, host=DEFAULT_WANDB_HOST): +def models(project:str, host:str=DEFAULT_WANDB_HOST) -> List[str]: + """ + List all models associated with a project in the model zoo + + Args: + project: Name of the project to search + host: URL of the Weights & Biases host + + Returns: + List of model names + """ return artifacts(project, host=host, type_contains='model') -def datasets(project, host=DEFAULT_WANDB_HOST): +def datasets(project:str, host:str=DEFAULT_WANDB_HOST) -> List[str]: + """ + List all datasets associated with a project in the model zoo + + Args: + project: Name of the project to search + host: URL of the Weights & Biases host + + Returns: + List of dataset names + """ return artifacts(project, host=host, type_contains='dataset') -def runs(project, host=DEFAULT_WANDB_HOST, field='id', filters=None): +def runs(project:str, host:str=DEFAULT_WANDB_HOST, field:str='id', filters: Optional[Dict[str, Any]]=None) -> List[str]: + """ + List attributes of all runs associated with a project in the model zoo + + Args: + project: Name of the project to search + host: URL of the Weights & Biases host + field: Field to return from the run metadata + filters: Dictionary of filters to pass to `api.runs` + + Returns: + List of run attributes + """ _check_wandb(host=host) project_path = f'{DEFAULT_WANDB_ENTITY}/{project}' @@ -106,30 +170,78 @@ def runs(project, host=DEFAULT_WANDB_HOST, field='id', filters=None): return [getattr(run, field) for run in api.runs(project_path, filters=filters)] -def get_artifact(name, project, alias='latest'): - _check_wandb() +def get_artifact(name:str, project:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest'): + """ + Retrieve an artifact associated with a project in the model zoo + + Args: + name: Name of the artifact + project: Name of the project containing the artifact + host: URL of the Weights & Biases host + alias: Alias of the artifact + + Returns: + The specific artifact + """ + _check_wandb(host=host) project_path = f'{DEFAULT_WANDB_ENTITY}/{project}' api = wandb.Api() return api.artifact(f'{project_path}/{name}:{alias}') -def get_dataset_by_model(model_name, project, alias='latest'): - art = get_artifact(model_name, project, alias=alias) +def get_dataset_by_model(model_name:str, project:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest') -> List[str]: + """ + List all datasets associated with a model in the model zoo + + Args: + model_name: Name of the model + project: Name of the project containing the model + host: URL of the Weights & Biases host + alias: Alias of the model artifact + + Returns: + A list containing the names of all datasets linked to the model + """ + art = get_artifact(model_name, project, host=host, alias=alias) run = art.logged_by() return [x.name for x in run.used_artifacts()] -def get_model_by_dataset(dataset_name, project, alias='latest'): - art = get_artifact(dataset_name, project, alias=alias) +def get_model_by_dataset(dataset_name:str, project:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest') -> List[str]: + """ + List all models associated with a dataset in the model zoo + + Args: + dataset_name: Name of the dataset + project: Name of the project containing the dataset + host: URL of the Weights & Biases host + alias: Alias of the dataset artifact + + Returns: + A list containing the names of all models linked to the dataset + """ + art = get_artifact(dataset_name, project, host=host, alias=alias) runs = art.used_by() assert len(runs) > 0 return [x.name for x in runs[0].logged_artifacts()] -def load_model(project, model_name, alias='latest', checkpoint_file='model.ckpt'): +def load_model(project:str, model_name:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest', checkpoint_file:str='model.ckpt') -> LightningModel: + """ + Download and load a model from the model zoo - art = get_artifact(model_name, project, alias=alias) + Args: + project: Name of the project containing the model + model_name: Name of the model + host: URL of the Weights & Biases host + alias: Alias of the model artifact + checkpoint_file: Name of the checkpoint file contained in the model artifact + + Returns: + A LightningModel object + """ + art = get_artifact(model_name, project, host=host, alias=alias) with TemporaryDirectory() as d: art.download(d)