Skip to content

Commit

Permalink
Add Vertex AI fluency metric (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
tagboola authored May 2, 2024
1 parent 2afc739 commit 5a6fa82
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 34 deletions.
5 changes: 3 additions & 2 deletions docs/plugins/vertex-ai.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ through the [Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/doc

It also provides access to a subset of evaluation metrics through the Vertex AI [Rapid Evaluation API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation).

- [BLEU](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#bleuinput)
- [ROUGE](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#rougeinput)
- [Fluency](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#fluencyinput)
- [Safety](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#safetyinput)
- [Groundedness](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#groundednessinput)
- [ROUGE](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#rougeinput)
- [BLEU](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#bleuinput)

## Installation

Expand Down
83 changes: 57 additions & 26 deletions js/plugins/vertexai/src/evaluation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/

import { BaseDataPoint } from '@genkit-ai/ai/evaluator';
import { Action } from '@genkit-ai/core';
import { GoogleAuth } from 'google-auth-library';
import { JSONClient } from 'google-auth-library/build/src/auth/googleauth';
Expand All @@ -27,10 +26,11 @@ import { EvaluatorFactory } from './evaluator_factory';
*/
export enum VertexAIEvaluationMetricType {
  // Update genkit/docs/plugins/vertex-ai.md when modifying the list of enums
  // Each member's string value is kept identical to its name so the value
  // can be used directly wherever the metric has to be referred to as text.
  BLEU = 'BLEU',
  ROUGE = 'ROUGE',
  // Was misspelled as 'FLEUNCY'; every other member's value equals its name.
  FLUENCY = 'FLUENCY',
  SAFETY = 'SAFETY',
  GROUNDEDNESS = 'GROUNDEDNESS',
}

/**
Expand Down Expand Up @@ -66,6 +66,9 @@ export function vertexEvaluators(
case VertexAIEvaluationMetricType.ROUGE: {
return createRougeEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.FLUENCY: {
return createFluencyEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.SAFETY: {
return createSafetyEvaluator(factory, metricSpec);
}
Expand Down Expand Up @@ -118,12 +121,9 @@ function createBleuEvaluator(
},
};
},
(response, datapoint) => {
(response) => {
return {
testCaseId: datapoint.testCaseId,
evaluation: {
score: response.bleuResults.bleuMetricValues[0].score,
},
score: response.bleuResults.bleuMetricValues[0].score,
};
}
);
Expand Down Expand Up @@ -163,11 +163,48 @@ function createRougeEvaluator(
},
};
},
(response, datapoint) => {
(response) => {
return {
testCaseId: datapoint.testCaseId,
evaluation: {
score: response.rougeResults.rougeMetricValues[0].score,
score: response.rougeResults.rougeMetricValues[0].score,
};
}
);
}

// Zod schema for the response body of a fluency evaluation request.
// The payload is expected to be wrapped in a single `fluencyResult` object.
const FluencyResponseSchema = z.object({
fluencyResult: z.object({
// Numeric fluency score; surfaced as the evaluator's `score`.
score: z.number(),
// Textual explanation of the score; surfaced as `details.reasoning`.
explanation: z.string(),
// Required by the schema, but not read by the fluency response handler in this file.
confidence: z.number(),
}),
});

function createFluencyEvaluator(
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
{
metric: VertexAIEvaluationMetricType.FLUENCY,
displayName: 'Fluency',
definition: 'Assesses the language mastery of an output',
responseSchema: FluencyResponseSchema,
},
(datapoint) => {
return {
fluencyInput: {
metricSpec,
instance: {
prediction: datapoint.output as string,
},
},
};
},
(response) => {
return {
score: response.fluencyResult.score,
details: {
reasoning: response.fluencyResult.explanation,
},
};
}
Expand Down Expand Up @@ -203,14 +240,11 @@ function createSafetyEvaluator(
},
};
},
(response, datapoint: BaseDataPoint) => {
(response) => {
return {
testCaseId: datapoint.testCaseId,
evaluation: {
score: response.safetyResult?.score,
details: {
reasoning: response.safetyResult?.explanation,
},
score: response.safetyResult.score,
details: {
reasoning: response.safetyResult.explanation,
},
};
}
Expand Down Expand Up @@ -248,14 +282,11 @@ function createGroundednessEvaluator(
},
};
},
(response, datapoint: BaseDataPoint) => {
(response) => {
return {
testCaseId: datapoint.testCaseId,
evaluation: {
score: response.groundednessResult?.score,
details: {
reasoning: response.groundednessResult?.explanation,
},
score: response.groundednessResult.score,
details: {
reasoning: response.groundednessResult.explanation,
},
};
}
Expand Down
12 changes: 6 additions & 6 deletions js/plugins/vertexai/src/evaluator_factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { BaseDataPoint, defineEvaluator } from '@genkit-ai/ai/evaluator';
import { BaseDataPoint, defineEvaluator, Score } from '@genkit-ai/ai/evaluator';
import { Action, GENKIT_CLIENT_HEADER } from '@genkit-ai/core';
import { runInNewSpan } from '@genkit-ai/core/tracing';
import { GoogleAuth } from 'google-auth-library';
Expand All @@ -37,10 +37,7 @@ export class EvaluatorFactory {
responseSchema: ResponseType;
},
toRequest: (datapoint: BaseDataPoint) => any,
responseHandler: (
response: z.infer<ResponseType>,
datapoint: BaseDataPoint
) => any
responseHandler: (response: z.infer<ResponseType>) => Score
): Action {
return defineEvaluator(
{
Expand All @@ -55,7 +52,10 @@ export class EvaluatorFactory {
responseSchema
);

return responseHandler(response, datapoint);
return {
evaluation: responseHandler(response),
testCaseId: datapoint.testCaseId,
};
}
);
}
Expand Down

0 comments on commit 5a6fa82

Please sign in to comment.