Skip to content

Commit

Permalink
Reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexejPenner committed Nov 2, 2024
1 parent 0cd7be8 commit a0535df
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 370 deletions.
61 changes: 29 additions & 32 deletions llm-complete-guide/gh_action_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@

import click
import yaml
from pipelines.llm_basic_rag import llm_basic_rag
from zenml.client import Client
from zenml.exceptions import ZenKeyError

from pipelines.llm_basic_rag import llm_basic_rag


@click.command(
help="""
Expand All @@ -39,7 +38,6 @@
default=False,
help="Disable cache.",
)

@click.option(
"--create-template",
"create_template",
Expand All @@ -51,26 +49,26 @@
"--config",
"config",
default="rag_local_dev.yaml",
help="Specify a configuration file"
help="Specify a configuration file",
)
@click.option(
"--service-account-id",
"service_account_id",
default=None,
help="Specify a service account ID"
help="Specify a service account ID",
)
@click.option(
"--event-source-id",
"event_source_id",
default=None,
help="Specify an event source ID"
help="Specify an event source ID",
)
def main(
no_cache: bool = False,
config: Optional[str]= "rag_local_dev.yaml",
config: Optional[str] = "rag_local_dev.yaml",
create_template: bool = False,
service_account_id: Optional[str] = None,
event_source_id: Optional[str] = None
event_source_id: Optional[str] = None,
):
"""
Executes the pipeline to train a basic RAG model.
Expand All @@ -86,43 +84,43 @@ def main(
client = Client()
config_path = Path(__file__).parent / "configs" / config

with (open(config_path,"r") as file):
with open(config_path, "r") as file:
config = yaml.safe_load(file)

if create_template:

# run pipeline
run = llm_basic_rag.with_options(
config_path=str(config_path),
enable_cache=not no_cache
config_path=str(config_path), enable_cache=not no_cache
)()
# create new run template
rt = client.create_run_template(
name=f"production-llm-complete-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
deployment_id=run.deployment_id
deployment_id=run.deployment_id,
)

try:
# Check if an action ahs already be configured for this pipeline
action = client.get_action(
name_id_or_prefix="LLM Complete (production)",
allow_name_prefix_match=True
allow_name_prefix_match=True,
)
except ZenKeyError:
if not event_source_id:
raise RuntimeError("An event source is required for this workflow.")
raise RuntimeError(
"An event source is required for this workflow."
)

if not service_account_id:
service_account_id = client.create_service_account(
name="github-action-sa",
description="To allow triggered pipelines to run with M2M authentication."
description="To allow triggered pipelines to run with M2M authentication.",
).id

action_id = client.create_action(
name="LLM Complete (production)",
configuration={
"template_id": str(rt.id),
"run_config": pop_restricted_configs(config)
"run_config": pop_restricted_configs(config),
},
service_account_id=service_account_id,
auth_window=0,
Expand All @@ -132,7 +130,7 @@ def main(
event_source_id=UUID(event_source_id),
event_filter={"event_type": "tag_event"},
action_id=action_id,
description="Trigger pipeline to reindex everytime the docs are updated through git."
description="Trigger pipeline to reindex everytime the docs are updated through git.",
)
else:
# update the action with the new template
Expand All @@ -141,14 +139,13 @@ def main(
name_id_or_prefix=action.id,
configuration={
"template_id": str(rt.id),
"run_config": pop_restricted_configs(config)
}
"run_config": pop_restricted_configs(config),
},
)

else:
llm_basic_rag.with_options(
config_path=str(config_path),
enable_cache=not no_cache
config_path=str(config_path), enable_cache=not no_cache
)()


Expand All @@ -162,22 +159,22 @@ def pop_restricted_configs(run_configuration: dict) -> dict:
Modified dictionary with restricted items removed
"""
# Pop top-level restricted items
run_configuration.pop('parameters', None)
run_configuration.pop('build', None)
run_configuration.pop('schedule', None)
run_configuration.pop("parameters", None)
run_configuration.pop("build", None)
run_configuration.pop("schedule", None)

# Pop docker settings if they exist
if 'settings' in run_configuration:
run_configuration['settings'].pop('docker', None)
if "settings" in run_configuration:
run_configuration["settings"].pop("docker", None)

# Pop docker settings from steps if they exist
if 'steps' in run_configuration:
for step in run_configuration['steps'].values():
if 'settings' in step:
step['settings'].pop('docker', None)
if "steps" in run_configuration:
for step in run_configuration["steps"].values():
if "settings" in step:
step["settings"].pop("docker", None)

return run_configuration


if __name__ == "__main__":
main()
main()
Loading

0 comments on commit a0535df

Please sign in to comment.