Skip to content

Commit

Permalink
Adding retrieval model to DSL
Browse files Browse the repository at this point in the history
  • Loading branch information
seanchatmangpt committed Mar 18, 2024
1 parent 5d172bc commit 5d908fc
Show file tree
Hide file tree
Showing 11 changed files with 374 additions and 73 deletions.
100 changes: 99 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry] # https://python-poetry.org/docs/pyproject/
name = "dspygen"
version = "2024.3.14.2"
version = "2024.3.17"
description = "A Ruby on Rails style framework for the DSPy (Demonstrate, Search, Predict) project for Language Models like GPT, BERT, and LLama."
authors = ["Sean Chatman <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -43,6 +43,7 @@ paho-mqtt = "^2.0.0"
psutil = "^5.9.8"
st-pages = "^0.4.5"
pykka = "^4.0.2"
ijson = "^3.2.3"

[tool.poetry.group.test.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/
coverage = { extras = ["toml"], version = ">=7.2.5" }
Expand Down
13 changes: 8 additions & 5 deletions src/dspygen/dsl/dsl_pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dspygen.dsl.utils.dsl_language_model_utils import _get_language_model_instance
from dspygen.dsl.dsl_pydantic_models import PipelineDSLModel, LanguageModelConfig
from dspygen.dsl.utils.dsl_module_utils import _get_module_instance
from dspygen.dsl.utils.dsl_retrieval_model_utils import _get_retrieval_model_instance
from dspygen.dsl.utils.dsl_signature_utils import _create_signature_from_model
from dspygen.typetemp.functional import render

Expand Down Expand Up @@ -45,16 +46,18 @@ def _execute_step(pipeline, step):
Execute a step in a pipeline. Creates the LM, renders the args using Jinja2,
runs the module, and updates the context.
"""
if not pipeline.models:
pipeline.models = [LanguageModelConfig(label="default", name="OpenAI", args={})]
if not pipeline.lm_models:
pipeline.lm_models = [LanguageModelConfig(label="default", name="OpenAI", args={})]

rendered_args = {arg: render(str(value), **pipeline.context) for arg, value in step.args.items()}

module_inst = _get_module_instance(pipeline, rendered_args, step)

lm_inst = _get_language_model_instance(pipeline, step)

with dspy.context(lm=lm_inst):
rm_inst = _get_retrieval_model_instance(pipeline, step)

with dspy.context(lm=lm_inst, rm=rm_inst):
module_output = module_inst.forward(**rendered_args)

pipeline.context[step.module] = module_output
Expand Down Expand Up @@ -90,8 +93,8 @@ async def run_pipeline(request: PipelineRequest):


def main():
# context = execute_pipeline('examples/blog_pipeline.yaml')
context = execute_pipeline('/Users/candacechatman/dev/dspygen/pipeline.yaml', {"news": "$12,500 Retainer Contract"})
context = execute_pipeline('/Users/candacechatman/dev/dspygen/src/dspygen/dsl/examples/example_pipeline.yaml')
# context = execute_pipeline('/Users/candacechatman/dev/dspygen/pipeline.yaml', {"news": "$12,500 Retainer Contract"})
# context = execute_pipeline('examples/example_pipeline.yaml')

print(context)
Expand Down
Loading

0 comments on commit 5d908fc

Please sign in to comment.