Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(evals): Make context support any type. #1517

Merged
merged 10 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion genkit-tools/cli/src/commands/eval-extract-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export const evalExtractData = new Command('eval:extractData')
testCaseId: generateTestCaseId(),
input: extractors.input(trace),
output: extractors.output(trace),
context: JSON.parse(extractors.context(trace)) as string[],
context: toArray(extractors.context(trace)),
// The trace (t) does not contain the traceId, so we have to pull it out of the
// spans, de-dupe, and turn it back into an array.
traceIds: Array.from(
Expand Down Expand Up @@ -105,3 +105,7 @@ export const evalExtractData = new Command('eval:extractData')
}
});
});

/**
 * Normalizes an extracted value into an array.
 *
 * An input that is already an array is handed back unchanged (same
 * reference); any other value, including `undefined`, is wrapped in a
 * single-element array.
 */
function toArray(input: any) {
  if (Array.isArray(input)) {
    return input;
  }
  return [input];
}
7 changes: 3 additions & 4 deletions genkit-tools/common/src/eval/evaluate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ async function gatherEvalInput(params: {
input,
output,
error,
context: JSON.parse(context) as string[],
context: Array.isArray(context) ? context : [context],
reference: state.reference,
traceIds: [traceId],
};
Expand All @@ -395,12 +395,11 @@ function getSpanErrorMessage(span: SpanData): string | undefined {
}
}

function getErrorFromModelResponse(output: string): string | undefined {
const obj = JSON.parse(output);
function getErrorFromModelResponse(obj: any): string | undefined {
const response = GenerateResponseSchema.parse(obj);

if (!response || !response.candidates || response.candidates.length === 0) {
return `No response was extracted from the output. '${output}'`;
return `No response was extracted from the output. '${JSON.stringify(obj)}'`;
}

// We currently only support the first candidate
Expand Down
2 changes: 1 addition & 1 deletion genkit-tools/common/src/plugin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const EvaluationExtractorSchema = z.record(
z.union([
z.string(), // specify the displayName (default to output)
StepSelectorSchema, //, {inputOf: 'my-step-name'}
z.function().args(TraceDataSchema).returns(z.string()), // custom trace extractor
z.function().args(TraceDataSchema).returns(z.any()), // custom trace extractor
])
);
export type EvaluationExtractor = z.infer<typeof EvaluationExtractorSchema>;
Expand Down
2 changes: 1 addition & 1 deletion genkit-tools/common/src/types/eval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export const EvalInputSchema = z.object({
input: z.any(),
output: z.any(),
error: z.string().optional(),
context: z.array(z.string()).optional(),
context: z.array(z.any()).optional(),
reference: z.any().optional(),
traceIds: z.array(z.string()),
});
Expand Down
61 changes: 34 additions & 27 deletions genkit-tools/common/src/utils/eval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import { NestedSpanData, TraceData } from '../types/trace';
import { logger } from './logger';
import { stackTraceSpans } from './trace';

export type EvalExtractorFn = (t: TraceData) => string;
const JSON_EMPTY_STRING = '""';
export type EvalExtractorFn = (t: TraceData) => any;

export const EVALUATOR_ACTION_PREFIX = '/evaluator';

Expand Down Expand Up @@ -78,30 +77,39 @@ function getRootSpan(trace: TraceData): NestedSpanData | undefined {
return stackTraceSpans(trace);
}

/**
 * Best-effort JSON decoding of a string value.
 *
 * @param value JSON text to decode; may be absent.
 * @returns The decoded value, or an empty string when `value` is missing,
 *   empty, or not valid JSON. Parse failures are intentionally not rethrown.
 */
function safeParse(value?: string) {
  // Nothing to decode: undefined and '' both fall back to the empty string.
  if (!value) {
    return '';
  }
  try {
    return JSON.parse(value);
  } catch {
    // Malformed JSON is treated the same as a missing value.
    return '';
  }
}

const DEFAULT_INPUT_EXTRACTOR: EvalExtractorFn = (trace: TraceData) => {
const rootSpan = getRootSpan(trace);
return (rootSpan?.attributes['genkit:input'] as string) || JSON_EMPTY_STRING;
return safeParse(rootSpan?.attributes['genkit:input'] as string);
};
const DEFAULT_OUTPUT_EXTRACTOR: EvalExtractorFn = (trace: TraceData) => {
const rootSpan = getRootSpan(trace);
return (rootSpan?.attributes['genkit:output'] as string) || JSON_EMPTY_STRING;
return safeParse(rootSpan?.attributes['genkit:output'] as string);
};
const DEFAULT_CONTEXT_EXTRACTOR: EvalExtractorFn = (trace: TraceData) => {
return JSON.stringify(
Object.values(trace.spans)
.filter((s) => s.attributes['genkit:metadata:subtype'] === 'retriever')
.flatMap((s) => {
const output: RetrieverResponse = JSON.parse(
s.attributes['genkit:output'] as string
);
if (!output) {
return [];
}
return output.documents.flatMap((d: DocumentData) =>
d.content.map((c) => c.text).filter((text): text is string => !!text)
);
})
);
return Object.values(trace.spans)
.filter((s) => s.attributes['genkit:metadata:subtype'] === 'retriever')
.flatMap((s) => {
const output: RetrieverResponse = safeParse(
s.attributes['genkit:output'] as string
);
if (!output) {
return [];
}
return output.documents.flatMap((d: DocumentData) =>
d.content.map((c) => c.text).filter((text): text is string => !!text)
);
});
};

const DEFAULT_FLOW_EXTRACTORS: Record<EvalField, EvalExtractorFn> = {
Expand All @@ -113,29 +121,29 @@ const DEFAULT_FLOW_EXTRACTORS: Record<EvalField, EvalExtractorFn> = {
const DEFAULT_MODEL_EXTRACTORS: Record<EvalField, EvalExtractorFn> = {
input: DEFAULT_INPUT_EXTRACTOR,
output: DEFAULT_OUTPUT_EXTRACTOR,
context: () => JSON.stringify([]),
context: () => [],
};

function getStepAttribute(
trace: TraceData,
stepName: string,
attributeName?: string
): string {
) {
// Default to output
const attr = attributeName ?? 'genkit:output';
const values = Object.values(trace.spans)
.filter((step) => step.displayName === stepName)
.flatMap((step) => {
return JSON.parse(step.attributes[attr] as string);
return safeParse(step.attributes[attr] as string);
});
if (values.length === 0) {
return JSON_EMPTY_STRING;
return '';
}
if (values.length === 1) {
return JSON.stringify(values[0]);
return values[0];
}
// Return array if multiple steps have the same name
return JSON.stringify(values);
return values;
}

function getExtractorFromStepName(stepName: string): EvalExtractorFn {
Expand All @@ -159,7 +167,7 @@ function getExtractorFromStepSelector(
selectedAttribute = 'genkit:output';
}
if (!stepName) {
return JSON_EMPTY_STRING;
return '';
} else {
return getStepAttribute(trace, stepName, selectedAttribute);
}
Expand Down Expand Up @@ -196,7 +204,6 @@ export async function getEvalExtractors(
return Promise.resolve(DEFAULT_MODEL_EXTRACTORS);
}
const config = await findToolsConfig();
logger.info(`Found tools config... ${JSON.stringify(config)}`);
const extractors = config?.evaluators
?.filter((e) => e.actionRef === actionRef)
.map((e) => e.extractors);
Expand Down
76 changes: 33 additions & 43 deletions genkit-tools/common/tests/utils/eval_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ describe('eval utils', () => {
expect(Object.keys(extractors).sort()).toEqual(
['input', 'output', 'context'].sort()
);
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
expect(extractors.output(trace)).toEqual(JSON.stringify('My output'));
expect(extractors.context(trace)).toEqual(JSON.stringify([]));
expect(extractors.input(trace)).toEqual('My input');
expect(extractors.output(trace)).toEqual('My output');
expect(extractors.context(trace)).toEqual([]);
});
});

Expand All @@ -63,9 +63,9 @@ describe('eval utils', () => {
expect(Object.keys(extractors).sort()).toEqual(
['input', 'output', 'context'].sort()
);
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
expect(extractors.output(trace)).toEqual(JSON.stringify('My output'));
expect(extractors.context(trace)).toEqual(JSON.stringify(CONTEXT_TEXTS));
expect(extractors.input(trace)).toEqual('My input');
expect(extractors.output(trace)).toEqual('My output');
expect(extractors.context(trace)).toEqual(CONTEXT_TEXTS);
});

it('returns custom extractors by stepName', async () => {
Expand Down Expand Up @@ -100,11 +100,9 @@ describe('eval utils', () => {

const extractors = await getEvalExtractors('/flow/multiSteps');

expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
expect(extractors.output(trace)).toEqual(
JSON.stringify({ out: 'my-object-output' })
);
expect(extractors.context(trace)).toEqual(JSON.stringify(CONTEXT_TEXTS));
expect(extractors.input(trace)).toEqual('My input');
expect(extractors.output(trace)).toEqual({ out: 'my-object-output' });
expect(extractors.context(trace)).toEqual(CONTEXT_TEXTS);
});

it('returns custom extractors by stepSelector', async () => {
Expand Down Expand Up @@ -146,11 +144,9 @@ describe('eval utils', () => {

const extractors = await getEvalExtractors('/flow/multiSteps');

expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
expect(extractors.output(trace)).toEqual(JSON.stringify('step2-input'));
expect(extractors.context(trace)).toEqual(
JSON.stringify(['Hello', 'World'])
);
expect(extractors.input(trace)).toEqual('My input');
expect(extractors.output(trace)).toEqual('step2-input');
expect(extractors.context(trace)).toEqual(['Hello', 'World']);
});

it('returns custom extractors by trace function', async () => {
Expand All @@ -160,23 +156,21 @@ describe('eval utils', () => {
actionRef: '/flow/multiSteps',
extractors: {
input: (trace: TraceData) => {
return JSON.stringify(
Object.values(trace.spans)
.filter(
(s) =>
s.attributes['genkit:type'] === 'action' &&
s.attributes['genkit:metadata:subtype'] !== 'retriever'
)
.map((s) => {
const inputValue = JSON.parse(
s.attributes['genkit:input'] as string
).start.input;
if (!inputValue) {
return '';
}
return inputValue + ' TEST TEST TEST';
})
);
return Object.values(trace.spans)
.filter(
(s) =>
s.attributes['genkit:type'] === 'action' &&
s.attributes['genkit:metadata:subtype'] !== 'retriever'
)
.map((s) => {
const inputValue = JSON.parse(
s.attributes['genkit:input'] as string
).start.input;
if (!inputValue) {
return '';
}
return inputValue + ' TEST TEST TEST';
});
},
output: { inputOf: 'step2' },
context: { outputOf: 'step3-array' },
Expand Down Expand Up @@ -211,13 +205,9 @@ describe('eval utils', () => {

const extractors = await getEvalExtractors('/flow/multiSteps');

expect(extractors.input(trace)).toEqual(
JSON.stringify(['My input TEST TEST TEST'])
);
expect(extractors.output(trace)).toEqual(JSON.stringify('step2-input'));
expect(extractors.context(trace)).toEqual(
JSON.stringify(['Hello', 'World'])
);
expect(extractors.input(trace)).toEqual(['My input TEST TEST TEST']);
expect(extractors.output(trace)).toEqual('step2-input');
expect(extractors.context(trace)).toEqual(['Hello', 'World']);
});

it('returns runs default extractors when trace fails', async () => {
Expand All @@ -239,8 +229,8 @@ describe('eval utils', () => {
expect(Object.keys(extractors).sort()).toEqual(
['input', 'output', 'context'].sort()
);
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
expect(extractors.output(trace)).toEqual(JSON.stringify(''));
expect(extractors.context(trace)).toEqual(JSON.stringify(CONTEXT_TEXTS));
expect(extractors.input(trace)).toEqual('My input');
expect(extractors.output(trace)).toEqual('');
expect(extractors.context(trace)).toEqual(CONTEXT_TEXTS);
});
});
34 changes: 24 additions & 10 deletions js/plugins/evaluators/src/metrics/answer_relevancy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import { getDirName } from './helper.js';

const AnswerRelevancyResponseSchema = z.object({
question: z.string(),
answered: z.literal(0).or(z.literal(1)),
noncommittal: z.literal(0).or(z.literal(1)),
answered: z.enum(['0', '1'] as const),
noncommittal: z.enum(['0', '1'] as const),
});

export async function answerRelevancyScore<
Expand All @@ -40,12 +40,26 @@ export async function answerRelevancyScore<
embedderOptions?: z.infer<CustomEmbedderOptions>
): Promise<Score> {
try {
if (!dataPoint.context?.length) {
throw new Error('Context was not provided');
if (!dataPoint.input) {
throw new Error('Input was not provided');
}
if (!dataPoint.output) {
throw new Error('Output was not provided');
}
if (!dataPoint.context?.length) {
throw new Error('Context was not provided');
}

const input =
typeof dataPoint.input === 'string'
? dataPoint.input
: JSON.stringify(dataPoint.input);
const output =
typeof dataPoint.output === 'string'
? dataPoint.output
: JSON.stringify(dataPoint.output);
const context = dataPoint.context.map((i) => JSON.stringify(i));

const prompt = await loadPromptFile(
ai.registry,
path.resolve(getDirName(), '../../prompts/answer_relevancy.prompt')
Expand All @@ -54,9 +68,9 @@ export async function answerRelevancyScore<
model: judgeLlm,
config: judgeConfig,
prompt: prompt.renderText({
question: dataPoint.input as string,
answer: dataPoint.output as string,
context: dataPoint.context.join(' '),
question: input,
answer: output,
context: context.join(' '),
}),
output: {
schema: AnswerRelevancyResponseSchema,
Expand All @@ -68,7 +82,7 @@ export async function answerRelevancyScore<

const questionEmbed = await ai.embed({
embedder,
content: dataPoint.input as string,
content: input,
options: embedderOptions,
});
const genQuestionEmbed = await ai.embed({
Expand All @@ -77,8 +91,8 @@ export async function answerRelevancyScore<
options: embedderOptions,
});
const score = cosineSimilarity(questionEmbed, genQuestionEmbed);
const answered = response.output?.answered === 1;
const isNonCommittal = response.output?.noncommittal === 1;
const answered = response.output?.answered === '1' ? 1 : 0;
const isNonCommittal = response.output?.noncommittal === '1' ? 1 : 0;
const answeredPenalty = !answered ? 0.5 : 0;
const adjustedScore =
score - answeredPenalty < 0 ? 0 : score - answeredPenalty;
Expand Down
Loading
Loading