Skip to content

Commit

Permalink
Merge pull request #13 from Genentech/add_host
Browse files Browse the repository at this point in the history
added host and docstrings to resources module
  • Loading branch information
avantikalal authored Jul 17, 2024
2 parents 0a0f060 + cbafec5 commit b7336fa
Showing 1 changed file with 135 additions and 23 deletions.
158 changes: 135 additions & 23 deletions src/grelu/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib_resources
from tempfile import TemporaryDirectory
from pathlib import Path
from typing import Optional, List, Dict, Any

import wandb
from grelu.lightning import LightningModel
Expand All @@ -10,16 +11,16 @@
DEFAULT_WANDB_HOST = 'https://api.wandb.ai'


def get_meme_file_path(meme_motif_db):
def get_meme_file_path(meme_motif_db: str) -> str:
"""
Return the path to a MEME file.
Args:
meme_motif_db (str): Path to a MEME file or the name of a MEME file included with gReLU.
meme_motif_db: Path to a MEME file or the name of a MEME file included with gReLU.
Current name options are "jaspar" and "consensus".
Returns:
(str): Path to the specified MEME file.
Path to the specified MEME file.
"""
if meme_motif_db == "jaspar":
meme_motif_db = (
Expand All @@ -41,13 +42,17 @@ def get_meme_file_path(meme_motif_db):
raise Exception(f"{meme_motif_db} is not a valid file.")


def get_default_config_file():
config = importlib_resources.files("grelu") / "resources" / "default_config.yaml"
assert config.exists()
return str(config)
def get_blacklist_file(genome: str) -> str:
"""
Return the path to a blacklist file
Args:
genome: Name of a genome whose blacklist file is included with gReLU.
Current name options are "hg19", "hg38" and "mm10".
def get_blacklist_file(genome):
Returns:
Path to the specified blacklist file.
"""
blacklist = (
importlib_resources.files("grelu")
/ "resources"
Expand All @@ -59,19 +64,46 @@ def get_blacklist_file(genome):
return str(blacklist)


def _check_wandb(host=DEFAULT_WANDB_HOST):
assert wandb.login(host=host), f'Weights & Biases (wandb) is not configured, see {DEFAULT_WANDB_HOST}/authorize'
def _check_wandb(host:str=DEFAULT_WANDB_HOST) -> None:
"""
Check that the user is logged into Weights and Biases
Args:
host: URL of the Weights & Biases host
"""
assert wandb.login(host=host, anonymous="allow"), f'Weights & Biases (wandb) is not configured, see {host}/authorize'


def projects(host: str=DEFAULT_WANDB_HOST) -> List[str]:
"""
List all projects in the model zoo
Args:
host: URL of the Weights & Biases host
def projects(host=DEFAULT_WANDB_HOST):
Returns:
List of project names
"""
_check_wandb(host=host)

api = wandb.Api()
projects = api.projects(DEFAULT_WANDB_ENTITY)
return [p.name for p in projects]


def artifacts(project, host=DEFAULT_WANDB_HOST, type_is=None, type_contains=None):
def artifacts(project: str, host: str=DEFAULT_WANDB_HOST, type_is: Optional[str]=None, type_contains: Optional[str]=None) -> List[str]:
"""
List all artifacts associated with a project in the model zoo
Args:
project: Name of the project to search
host: URL of the Weights & Biases host
type_is: Return only artifacts with this type
type_contains: Return only artifacts whose type contains this string
Returns:
List of artifact names
"""
_check_wandb(host)
project_path = f'{DEFAULT_WANDB_ENTITY}/{project}'

Expand All @@ -90,46 +122,126 @@ def artifacts(project, host=DEFAULT_WANDB_HOST, type_is=None, type_contains=None
return arts


def models(project, host=DEFAULT_WANDB_HOST):
def models(project:str, host:str=DEFAULT_WANDB_HOST) -> List[str]:
"""
List all models associated with a project in the model zoo
Args:
project: Name of the project to search
host: URL of the Weights & Biases host
Returns:
List of model names
"""
return artifacts(project, host=host, type_contains='model')


def datasets(project, host=DEFAULT_WANDB_HOST):
def datasets(project:str, host:str=DEFAULT_WANDB_HOST) -> List[str]:
"""
List all datasets associated with a project in the model zoo
Args:
project: Name of the project to search
host: URL of the Weights & Biases host
Returns:
List of dataset names
"""
return artifacts(project, host=host, type_contains='dataset')


def runs(project, host=DEFAULT_WANDB_HOST, field='id', filters=None):
def runs(project:str, host:str=DEFAULT_WANDB_HOST, field:str='id', filters: Optional[Dict[str, Any]]=None) -> List[str]:
"""
List attributes of all runs associated with a project in the model zoo
Args:
project: Name of the project to search
host: URL of the Weights & Biases host
field: Field to return from the run metadata
filters: Dictionary of filters to pass to `api.runs`
Returns:
List of run attributes
"""
_check_wandb(host=host)
project_path = f'{DEFAULT_WANDB_ENTITY}/{project}'

api = wandb.Api()
return [getattr(run, field) for run in api.runs(project_path, filters=filters)]


def get_artifact(name, project, alias='latest'):
_check_wandb()
def get_artifact(name:str, project:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest'):
"""
Retrieve an artifact associated with a project in the model zoo
Args:
name: Name of the artifact
project: Name of the project containing the artifact
host: URL of the Weights & Biases host
alias: Alias of the artifact
Returns:
The specific artifact
"""
_check_wandb(host=host)
project_path = f'{DEFAULT_WANDB_ENTITY}/{project}'

api = wandb.Api()
return api.artifact(f'{project_path}/{name}:{alias}')


def get_dataset_by_model(model_name, project, alias='latest'):
art = get_artifact(model_name, project, alias=alias)
def get_dataset_by_model(model_name:str, project:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest') -> List[str]:
"""
List all datasets associated with a model in the model zoo
Args:
model_name: Name of the model
project: Name of the project containing the model
host: URL of the Weights & Biases host
alias: Alias of the model artifact
Returns:
A list containing the names of all datasets linked to the model
"""
art = get_artifact(model_name, project, host=host, alias=alias)
run = art.logged_by()
return [x.name for x in run.used_artifacts()]


def get_model_by_dataset(dataset_name, project, alias='latest'):
art = get_artifact(dataset_name, project, alias=alias)
def get_model_by_dataset(dataset_name:str, project:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest') -> List[str]:
"""
List all models associated with a dataset in the model zoo
Args:
dataset_name: Name of the dataset
project: Name of the project containing the dataset
host: URL of the Weights & Biases host
alias: Alias of the dataset artifact
Returns:
A list containing the names of all models linked to the dataset
"""
art = get_artifact(dataset_name, project, host=host, alias=alias)
runs = art.used_by()
assert len(runs) > 0
return [x.name for x in runs[0].logged_artifacts()]


def load_model(project, model_name, alias='latest', checkpoint_file='model.ckpt'):
def load_model(project:str, model_name:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest', checkpoint_file:str='model.ckpt') -> LightningModel:
"""
Download and load a model from the model zoo
art = get_artifact(model_name, project, alias=alias)
Args:
project: Name of the project containing the model
model_name: Name of the model
host: URL of the Weights & Biases host
alias: Alias of the model artifact
checkpoint_file: Name of the checkpoint file contained in the model artifact
Returns:
A LightningModel object
"""
art = get_artifact(model_name, project, host=host, alias=alias)

with TemporaryDirectory() as d:
art.download(d)
Expand Down

0 comments on commit b7336fa

Please sign in to comment.