Add Structured Output #1443
base: main
Codecov Report: All modified and coverable lines are covered by tests ✅
Force-pushed from a7ed072 to f74e2e6, from e96721f to d2c0827, and from d2c0827 to d07d699.
@dylanholmes @vachillo I'm still working through the docs and some final cleanup, but the PR is developed enough for a first pass of reviews.
```python
def _add_native_schema_to_prompt_stack(self, stack: PromptStack, rulesets: list[Ruleset]) -> None:
    # Need to separate JsonSchemaRules from other rules, removing them in the process
    json_schema_rules = [rule for ruleset in rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule)]
    non_json_schema_rules = [
        [rule for rule in ruleset.rules if not isinstance(rule, JsonSchemaRule)] for ruleset in rulesets
    ]
    for ruleset, non_json_rules in zip(rulesets, non_json_schema_rules):
        ruleset.rules = non_json_rules

    schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)]

    if len(json_schema_rules) != len(schemas):
        warnings.warn(
            "Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`.",
            stacklevel=2,
        )

    if schemas:
        stack.output_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas))
```
Really don't like this method. Very open to suggestions for improvements.
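One possible direction, sketched here only as an illustration: split the rules in a pure function that returns the output schema together with filtered *copies* of the rulesets, so the caller's `Ruleset.rules` are never reassigned. `Rule`, `JsonSchemaRule`, `Ruleset`, and `split_schema_rules` below are hypothetical stand-ins, not griptape's real classes. Schemas are plain JSON Schema dicts, and JSON Schema's `anyOf` is swapped in for `schema.Or`.

```python
from __future__ import annotations

import warnings
from dataclasses import dataclass, replace
from typing import Any


# Hypothetical stand-ins for griptape's Rule / JsonSchemaRule / Ruleset,
# reduced to the fields this sketch needs.
@dataclass(frozen=True)
class Rule:
    value: Any


@dataclass(frozen=True)
class JsonSchemaRule(Rule):
    pass


@dataclass(frozen=True)
class Ruleset:
    name: str
    rules: tuple[Rule, ...] = ()


def split_schema_rules(rulesets: list[Ruleset]) -> tuple[dict | None, list[Ruleset]]:
    """Return (output_schema, rulesets_without_schema_rules) without mutating the inputs."""
    schema_rules = [r for rs in rulesets for r in rs.rules if isinstance(r, JsonSchemaRule)]
    # In this sketch only dict values count as schemas (a stand-in for `schema.Schema`).
    schemas = [r.value for r in schema_rules if isinstance(r.value, dict)]
    if len(schemas) != len(schema_rules):
        warnings.warn("Some JsonSchemaRules do not hold a schema and will be ignored.", stacklevel=2)

    # Build new Ruleset copies instead of reassigning `ruleset.rules` in place.
    remaining = [
        replace(rs, rules=tuple(r for r in rs.rules if not isinstance(r, JsonSchemaRule)))
        for rs in rulesets
    ]

    if not schemas:
        return None, remaining
    # JSON Schema's `anyOf` plays the role that `schema.Or` plays in the PR.
    output_schema = schemas[0] if len(schemas) == 1 else {"anyOf": schemas}
    return output_schema, remaining
```

The caller would then set `stack.output_schema` and render the returned rulesets, which keeps the side effect in one visible place instead of hiding it inside the helper.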
```diff
-system_template = self.generate_system_template(self)
-if system_template:
-    stack.add_system_message(system_template)
+rulesets = self.rulesets
+system_artifacts = [TextArtifact(self.generate_system_template(self))]
+if self.prompt_driver.use_native_structured_output:
+    self._add_native_schema_to_prompt_stack(stack, rulesets)
+
+# Ensure there is at least one Ruleset that has non-empty `rules`.
+if any(len(ruleset.rules) for ruleset in rulesets):
+    system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets)))
+
+# Ensure there is at least one system Artifact that has a non-empty value.
+has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts)
+if has_system_artifacts:
+    stack.add_system_message(ListArtifact(system_artifacts))
```
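The system-message assembly in this hunk can be approximated in isolation as follows. This is a simplified sketch under stated assumptions: `Ruleset`, `render_rulesets`, and `build_system_message` are hypothetical, plain strings replace `TextArtifact`/`ListArtifact`, and a hand-written formatter replaces the `rulesets/rulesets.j2` template. Unlike the hunk, it also drops an empty template part instead of keeping it in the list.

```python
from __future__ import annotations

from dataclasses import dataclass


# Hypothetical stand-in for griptape's Ruleset: a name plus plain-text rules.
@dataclass
class Ruleset:
    name: str
    rules: list[str]


def render_rulesets(rulesets: list[Ruleset]) -> str:
    # Simplified stand-in for rendering `rulesets/rulesets.j2`; skips empty rulesets.
    blocks = []
    for rs in rulesets:
        if rs.rules:
            lines = "\n".join(f"- {rule}" for rule in rs.rules)
            blocks.append(f'Ruleset "{rs.name}":\n{lines}')
    return "\n\n".join(blocks)


def build_system_message(template: str, rulesets: list[Ruleset]) -> str | None:
    """Return the system message text, or None when there is nothing to send."""
    parts = [template]
    # Render rulesets only if at least one still has rules, e.g. after
    # schema rules were moved out into the output schema.
    if any(rs.rules for rs in rulesets):
        parts.append(render_rulesets(rulesets))
    # Skip the system message entirely when every part is empty.
    non_empty = [p for p in parts if p]
    return "\n\n".join(non_empty) if non_empty else None
```

The ordering matters: because rendering happens after the native-schema step, any `JsonSchemaRule`s that were moved into `output_schema` are no longer present when the rulesets are turned into prompt text.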
Describe your changes
Issue ticket number and link
Closes #1468
Closes #1467