-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from experimaestro/features/experiments
Features/experiments
- Loading branch information
Showing
6 changed files
with
353 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .cli import experiments_cli, ExperimentHelper, ExperimentCallable # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
import inspect | ||
import json | ||
import logging | ||
import sys | ||
from pathlib import Path | ||
from typing import Any, List, Optional, Protocol, Tuple | ||
from functools import cached_property | ||
|
||
import click | ||
import omegaconf | ||
import yaml | ||
from experimaestro import LauncherRegistry, RunMode, experiment | ||
from experimaestro.settings import get_workspace | ||
from omegaconf import OmegaConf, SCMode | ||
from termcolor import cprint | ||
|
||
from xpmir.learning.optim import TensorboardService | ||
|
||
|
||
class ExperimentHelper: | ||
"""Helper for experiments""" | ||
|
||
xp: experiment | ||
callable: "ExperimentCallable" | ||
|
||
def __init__(self, callable: "ExperimentCallable"): | ||
self.callable = callable | ||
|
||
@cached_property | ||
def tensorboard_service(self): | ||
return self.xp.add_service(TensorboardService(self.xp.resultspath / "runs")) | ||
|
||
"""Handles extra arguments""" | ||
|
||
def run(self, args: List[str], configuration: Any): | ||
assert len(args) == 0 | ||
self.callable.run(self, configuration) | ||
|
||
|
||
class ExperimentCallable(Protocol): | ||
def __call__(self, helper: ExperimentHelper, configuration: Any): | ||
... | ||
|
||
|
||
def load(yaml_file: Path): | ||
"""Loads a YAML file, and parents one if they exist""" | ||
if not yaml_file.exists() and yaml_file.suffix != ".yaml": | ||
yaml_file = yaml_file.with_suffix(".yaml") | ||
|
||
with yaml_file.open("rt") as fp: | ||
_data = yaml.full_load(fp) | ||
data = [_data] | ||
if parent := _data.get("parent", None): | ||
data.extend(load(yaml_file.parent / parent)) | ||
del _data["parent"] | ||
|
||
return data | ||
|
||
|
||
@click.option("--debug", is_flag=True, help="Print debug information") | ||
@click.option("--show", is_flag=True, help="Print configuration and exits") | ||
@click.option( | ||
"--env", | ||
help="Define one environment variable", | ||
type=(str, str), | ||
multiple=True, | ||
) | ||
@click.option( | ||
"--host", | ||
type=str, | ||
default=None, | ||
help="Server hostname (default to localhost," | ||
" not suitable if your jobs are remote)", | ||
) | ||
@click.option( | ||
"--run-mode", | ||
type=click.Choice(RunMode), | ||
default=RunMode.NORMAL, | ||
help="Sets the run mode", | ||
) | ||
@click.option( | ||
"--xpm-config-dir", | ||
type=Path, | ||
default=None, | ||
help="Path for the experimaestro config directory " | ||
"(if not specified, use $HOME/.config/experimaestro)", | ||
) | ||
@click.option( | ||
"--port", | ||
type=int, | ||
default=None, | ||
help="Port for monitoring (can be defined in the settings.yaml file)", | ||
) | ||
@click.option( | ||
"--file", "xp_file", help="The file containing the main experimental code" | ||
) | ||
@click.option( | ||
"--workdir", | ||
type=str, | ||
default=None, | ||
help="Working directory - if None, uses the default XPM " "working directory", | ||
) | ||
@click.option("--conf", "-c", "extra_conf", type=str, multiple=True) | ||
@click.argument("args", nargs=-1, type=click.UNPROCESSED) | ||
@click.argument("yaml_file", metavar="YAML file", type=str) | ||
@click.command() | ||
def experiments_cli( | ||
yaml_file: List[str], | ||
xp_file: str, | ||
host: str, | ||
port: int, | ||
xpm_config_dir: Path, | ||
workdir: Optional[Path], | ||
env: List[Tuple[str, str]], | ||
run_mode: RunMode, | ||
extra_conf: List[str], | ||
args: List[str], | ||
show: bool, | ||
debug: bool, | ||
): | ||
"""Run an experiment""" | ||
# --- Set the logger | ||
logging.getLogger().setLevel(logging.DEBUG if debug else logging.INFO) | ||
logging.getLogger("xpm.hash").setLevel(logging.INFO) | ||
|
||
# --- Loads the YAML | ||
yamls = load(Path(yaml_file)) | ||
|
||
# --- Get the XP file | ||
if xp_file is None: | ||
for data in yamls: | ||
if xp_file := data.get("file"): | ||
del data["file"] | ||
break | ||
|
||
if xp_file is None: | ||
raise ValueError("No experiment file given") | ||
|
||
# --- Set some options | ||
|
||
if xpm_config_dir is not None: | ||
assert xpm_config_dir.is_dir() | ||
LauncherRegistry.set_config_dir(xpm_config_dir) | ||
|
||
# --- Loads the XP file | ||
xp_file = Path(xp_file) | ||
if not xp_file.exists() and xp_file.suffix != ".py": | ||
xp_file = xp_file.with_suffix(".py") | ||
xp_file = Path(yaml_file).parent / xp_file | ||
|
||
with open(xp_file, "r") as f: | ||
source = f.read() | ||
if sys.version_info < (3, 9): | ||
the__file__ = str(xp_file) | ||
else: | ||
the__file__ = str(xp_file.absolute()) | ||
|
||
code = compile(source, filename=the__file__, mode="exec") | ||
_locals = {} | ||
|
||
sys.path.append(str(xp_file.parent.absolute())) | ||
try: | ||
exec(code, _locals, _locals) | ||
finally: | ||
sys.path.pop() | ||
|
||
# --- ... and runs it | ||
helper = _locals.get("run", None) | ||
if helper is None: | ||
raise ValueError(f"Could not find run function in {the__file__}") | ||
|
||
if not isinstance(helper, ExperimentHelper): | ||
helper = ExperimentHelper(helper) | ||
|
||
parameters = inspect.signature(helper.callable).parameters | ||
list_parameters = list(parameters.values()) | ||
assert len(list_parameters) == 2, ( | ||
"Callable function should only " | ||
f"have two arguments (got {len(list_parameters)})" | ||
) | ||
|
||
schema = list_parameters[1].annotation | ||
omegaconf_schema = OmegaConf.structured(schema()) | ||
|
||
configuration = OmegaConf.merge(*yamls) | ||
if extra_conf: | ||
configuration.merge_with(OmegaConf.from_dotlist(extra_conf)) | ||
if omegaconf_schema is not None: | ||
try: | ||
configuration = OmegaConf.merge(omegaconf_schema, configuration) | ||
except omegaconf.errors.ConfigKeyError as e: | ||
cprint(f"Error in configuration:\n\n{e}", "red", file=sys.stderr) | ||
sys.exit(1) | ||
|
||
# Move to an object container | ||
configuration: schema = OmegaConf.to_container( | ||
configuration, structured_config_mode=SCMode.INSTANTIATE | ||
) | ||
|
||
if show: | ||
print(json.dumps(OmegaConf.to_container(configuration))) # noqa: T201 | ||
sys.exit(0) | ||
|
||
# Get the working directory | ||
if workdir is None or not Path(workdir).is_dir(): | ||
workdir = get_workspace(workdir).path.expanduser().resolve() | ||
logging.info("Using working directory %s", workdir) | ||
|
||
# --- Runs the experiment | ||
with experiment( | ||
workdir, configuration.id, host=host, port=port, run_mode=run_mode | ||
) as xp: | ||
# Set up the environment | ||
for key, value in env: | ||
xp.setenv(key, value) | ||
|
||
# Run the experiment | ||
helper.xp = xp | ||
helper.run(list(args), configuration) | ||
|
||
# ... and wait | ||
xp.wait() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from typing import Any, List, Optional, Dict | ||
import logging | ||
from pathlib import Path | ||
import click | ||
import inspect | ||
import io | ||
|
||
import docstring_parser | ||
from experimaestro import RunMode, Config | ||
from xpmir.evaluation import EvaluationsCollection | ||
from xpmir.models import XPMIRHFHub | ||
from xpmir.papers.results import PaperResults | ||
from xpmir.experiments import ExperimentHelper | ||
|
||
|
||
class UploadToHub: | ||
def __init__(self, model_id: Optional[str], doc: docstring_parser.Docstring): | ||
self.model_id = model_id | ||
self.doc = doc | ||
|
||
def send_scorer( | ||
self, | ||
models: Dict[str, Config], | ||
*, | ||
evaluations: Optional[EvaluationsCollection] = None, | ||
tb_logs: Dict[str, Path], | ||
add_des: str = "", | ||
): | ||
"""Upload the scorer(s) to the HuggingFace Hub | ||
:param models: The models to upload, each with a key | ||
:param tb_logs: The tensorboard logs | ||
:param evaluations: Models evaluations, defaults to None | ||
:param add_des: Extra documentation, defaults to "" | ||
""" | ||
if self.model_id is None: | ||
return | ||
|
||
out = io.StringIO() | ||
out.write( | ||
"""--- | ||
library_name: xpmir | ||
--- | ||
""" | ||
) | ||
|
||
out.write(f"{self.doc}\n\n") | ||
out.write(f"{add_des}\n\n") | ||
|
||
out.write("\n## Using the model") | ||
out.write( | ||
f""" | ||
The model can be loaded with [experimaestro | ||
IR](https://experimaestro-ir.readthedocs.io/en/latest/) | ||
```py from xpmir.models import AutoModel | ||
from xpmir.models import AutoModel | ||
# Model that can be re-used in experiments | ||
model, init_tasks = AutoModel.load_from_hf_hub("{self.model_id}") | ||
# Use this if you want to actually use the model | ||
model = AutoModel.load_from_hf_hub("{self.model_id}", as_instance=True) | ||
model.rsv("walgreens store sales average", "The average Walgreens salary ranges...") | ||
``` | ||
""" | ||
) | ||
|
||
assert len(models) == 1, "Cannot deal with more than one variant" | ||
((key, model),) = list(models.items()) | ||
|
||
if evaluations is not None: | ||
out.write("\n## Results\n\n") | ||
evaluations.output_model_results(key, file=out) | ||
|
||
readme_md = out.getvalue() | ||
|
||
logging.info("Uploading to HuggingFace Hub") | ||
XPMIRHFHub(model, readme=readme_md, tb_logs=tb_logs).push_to_hub( | ||
repo_id=self.model_id, config={} | ||
) | ||
|
||
|
||
class IRExperimentHelper(ExperimentHelper): | ||
def run(self, extra_args: List[str], configuration: Any): | ||
@click.option("--upload-to-hub", type=str) | ||
@click.command() | ||
def cli(upload_to_hub: str): | ||
results = self.callable(self, configuration) | ||
self.xp.wait() | ||
|
||
if isinstance(results, PaperResults) and self.xp.run_mode == RunMode.NORMAL: | ||
if upload_to_hub is not None: | ||
if configuration.title == "" and configuration.description == "": | ||
doc = docstring_parser.parse(self.callable.__doc__) | ||
else: | ||
doc = f"# {configuration.title}\n{configuration.description}" | ||
upload = UploadToHub(upload_to_hub, doc) | ||
|
||
upload.send_scorer( | ||
results.models, | ||
evaluations=results.evaluations, | ||
tb_logs=results.tb_logs, | ||
) | ||
|
||
print(results.evaluations.to_dataframe()) # noqa: T201 | ||
|
||
return cli(extra_args, standalone_mode=False) | ||
|
||
|
||
def ir_experiment(*args, **kwargs): | ||
"""Wraps an experiment into an IR experiment | ||
1. Upload to github (if requested) | ||
2. Print the results | ||
""" | ||
if len(args) == 1 and len(kwargs) == 0 and inspect.isfunction(args[0]): | ||
return IRExperimentHelper(callable) | ||
|
||
def wrapper(callable): | ||
return IRExperimentHelper(callable) | ||
|
||
return wrapper |
Oops, something went wrong.