From ecf9a29ad5a8f31307ffdd1498dd51b1d163a49b Mon Sep 17 00:00:00 2001 From: Vyacheslav Morov Date: Tue, 11 Feb 2025 12:35:55 +0700 Subject: [PATCH] Add OpenAi Scoring method to ContextRelevance descriptor. --- examples/context_relevance_example.ipynb | 149 ++++++++---------- .../future/descriptors/_context_relevance.py | 56 ++++++- 2 files changed, 121 insertions(+), 84 deletions(-) diff --git a/examples/context_relevance_example.ipynb b/examples/context_relevance_example.ipynb index 00891d7545..e70de08cb5 100644 --- a/examples/context_relevance_example.ipynb +++ b/examples/context_relevance_example.ipynb @@ -6,8 +6,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-02-10T15:35:19.308198Z", - "start_time": "2025-02-10T15:35:18.117683Z" + "end_time": "2025-02-11T05:30:57.579647Z", + "start_time": "2025-02-11T05:30:56.506620Z" } }, "source": [ @@ -24,8 +24,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:35:19.712850Z", - "start_time": "2025-02-10T15:35:19.310939Z" + "end_time": "2025-02-11T05:30:57.644470Z", + "start_time": "2025-02-11T05:30:57.582732Z" } }, "cell_type": "code", @@ -40,8 +40,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:35:19.736783Z", - "start_time": "2025-02-10T15:35:19.715223Z" + "end_time": "2025-02-11T05:30:57.664808Z", + "start_time": "2025-02-11T05:30:57.646202Z" } }, "cell_type": "code", @@ -57,8 +57,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:35:19.762914Z", - "start_time": "2025-02-10T15:35:19.738867Z" + "end_time": "2025-02-11T05:30:57.688615Z", + "start_time": "2025-02-11T05:30:57.667921Z" } }, "cell_type": "code", @@ -331,8 +331,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:35:19.789857Z", - "start_time": "2025-02-10T15:35:19.768453Z" + "end_time": "2025-02-11T05:32:53.120960Z", + "start_time": "2025-02-11T05:32:53.113134Z" } }, "cell_type": "code", @@ -360,34 +360,31 @@ "Name: context, Length: 379, dtype: object" ] }, - "execution_count": 5, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 5 + "execution_count": 12 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:35:24.855002Z", - "start_time": "2025-02-10T15:35:19.792836Z" + "end_time": "2025-02-11T05:33:01.550044Z", + "start_time": "2025-02-11T05:33:01.546983Z" } }, "cell_type": "code", "source": "from evidently.future.datasets import Dataset", "id": "956a112e02bf7108", "outputs": [], - "execution_count": 6 + "execution_count": 13 }, { - "metadata": { - "ExecuteTime": { - "end_time": "2025-02-10T15:36:29.925073Z", - "start_time": "2025-02-10T15:36:24.606246Z" - } - }, + "metadata": {}, "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "\n", "\n", @@ -399,28 +396,16 @@ " data_definition=DataDefinition(\n", " text_columns=[\"question\"],\n", " ),\n", - " descriptors=[ContextRelevance('question', 'context'\n", - " )],\n", + " descriptors=[ContextRelevance('question', 'context', method=\"openai\")],\n", ")" ], - "id": "1c353c389fd7f2b7", - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/liraim/projects/evidently/.venv/lib/python3.12/site-packages/sentence_transformers/SentenceTransformer.py:587: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n", - " sentences_sorted = [sentences[idx] for idx in length_sorted_idx]\n" - ] - } - ], - "execution_count": 11 + "id": "254ba529f5519d4b" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:36:29.948860Z", - "start_time": "2025-02-10T15:36:29.927753Z" + "end_time": "2025-02-11T05:33:05.908701Z", + "start_time": "2025-02-11T05:33:05.886591Z" } }, "cell_type": "code", @@ -497,29 +482,29 @@ "\n", " context \\\n", "index \n", - "2024-04-08 03:55:48.128469 [How do I request medical leave through the em... \n", - "2024-04-08 03:59:47.756913 [To update your direct deposit information on ... \n", - "2024-04-08 06:19:47.717513 [How do I generate payroll reports in the acco... \n", - "2024-04-08 08:22:47.717513 [Information about the company's financial for... \n", - "2024-04-08 08:32:47.717513 [Handling fixed asset acquisitions and disposa... \n", - "2024-04-08 10:52:47.756913 [How can I provide feedback or suggestions for... \n", - "2024-04-08 10:57:47.756913 [If you're having trouble accessing the traini... \n", - "2024-04-08 12:33:47.717513 [Information about the company's expense appro... \n", - "2024-04-08 12:39:47.756913 [If the employee portal is not displaying your... \n", - "2024-04-08 13:43:47.756913 [To submit a request for IT support through th... \n", + "2024-04-08 03:55:48.128469 [What resources are available for improving wo... \n", + "2024-04-08 03:59:47.756913 [How do I create and manage sales forecasts in... \n", + "2024-04-08 06:19:47.717513 [Information about the company's financial per... \n", + "2024-04-08 08:22:47.717513 [Information about the company's revenue recog... \n", + "2024-04-08 08:32:47.717513 [How do I segment customers in the CRM system ... \n", + "2024-04-08 10:52:47.756913 [How do I access my performance review in the ... \n", + "2024-04-08 10:57:47.756913 [How do I create and manage product catalogs i... \n", + "2024-04-08 12:33:47.717513 [Requesting assistance with resolving conflict... \n", + "2024-04-08 12:39:47.756913 [How do I access training sessions or workshop... \n", + "2024-04-08 13:43:47.756913 [To access training materials on the employee ... \n", "\n", " Ranking for question with context: aggregate score \n", "index \n", - "2024-04-08 03:55:48.128469 0.945920 \n", - "2024-04-08 03:59:47.756913 0.960915 \n", - "2024-04-08 06:19:47.717513 0.968615 \n", - "2024-04-08 08:22:47.717513 0.924486 \n", - "2024-04-08 08:32:47.717513 0.966529 \n", - "2024-04-08 10:52:47.756913 0.961541 \n", - "2024-04-08 10:57:47.756913 0.974138 \n", - "2024-04-08 12:33:47.717513 0.952817 \n", - "2024-04-08 12:39:47.756913 0.960415 \n", - "2024-04-08 13:43:47.756913 0.969710 " + "2024-04-08 03:55:48.128469 0.05 \n", + "2024-04-08 03:59:47.756913 0.00 \n", + "2024-04-08 06:19:47.717513 0.05 \n", + "2024-04-08 08:22:47.717513 0.05 \n", + "2024-04-08 08:32:47.717513 0.00 \n", + "2024-04-08 10:52:47.756913 0.00 \n", + "2024-04-08 10:57:47.756913 0.00 \n", + "2024-04-08 12:33:47.717513 0.00 \n", + "2024-04-08 12:39:47.756913 0.00 \n", + "2024-04-08 13:43:47.756913 0.05 " ], "text/html": [ "
\n", @@ -579,8 +564,8 @@ " EU-Spain\n", " production\n", " none\n", - " [How do I request medical leave through the em...\n", - " 0.945920\n", + " [What resources are available for improving wo...\n", + " 0.05\n", " \n", " \n", " 2024-04-08 03:59:47.756913\n", @@ -593,8 +578,8 @@ " EU-Germany\n", " production\n", " none\n", - " [To update your direct deposit information on ...\n", - " 0.960915\n", + " [How do I create and manage sales forecasts in...\n", + " 0.00\n", " \n", " \n", " 2024-04-08 06:19:47.717513\n", @@ -607,8 +592,8 @@ " US-west\n", " production\n", " none\n", - " [How do I generate payroll reports in the acco...\n", - " 0.968615\n", + " [Information about the company's financial per...\n", + " 0.05\n", " \n", " \n", " 2024-04-08 08:22:47.717513\n", @@ -621,8 +606,8 @@ " EU-Germany\n", " production\n", " upvote\n", - " [Information about the company's financial for...\n", - " 0.924486\n", + " [Information about the company's revenue recog...\n", + " 0.05\n", " \n", " \n", " 2024-04-08 08:32:47.717513\n", @@ -635,8 +620,8 @@ " EU-Spain\n", " production\n", " downvote\n", - " [Handling fixed asset acquisitions and disposa...\n", - " 0.966529\n", + " [How do I segment customers in the CRM system ...\n", + " 0.00\n", " \n", " \n", " 2024-04-08 10:52:47.756913\n", @@ -649,8 +634,8 @@ " UK\n", " production\n", " downvote\n", - " [How can I provide feedback or suggestions for...\n", - " 0.961541\n", + " [How do I access my performance review in the ...\n", + " 0.00\n", " \n", " \n", " 2024-04-08 10:57:47.756913\n", @@ -663,8 +648,8 @@ " US-east\n", " production\n", " none\n", - " [If you're having trouble accessing the traini...\n", - " 0.974138\n", + " [How do I create and manage product catalogs i...\n", + " 0.00\n", " \n", " \n", " 2024-04-08 12:33:47.717513\n", @@ -677,8 +662,8 @@ " EU-Spain\n", " production\n", " none\n", - " [Information about the company's expense appro...\n", - " 0.952817\n", + " [Requesting assistance with resolving conflict...\n", + " 0.00\n", " \n", " \n", " 2024-04-08 12:39:47.756913\n", @@ -691,8 +676,8 @@ " US-east\n", " production\n", " none\n", - " [If the employee portal is not displaying your...\n", - " 0.960415\n", + " [How do I access training sessions or workshop...\n", + " 0.00\n", " \n", " \n", " 2024-04-08 13:43:47.756913\n", @@ -705,33 +690,33 @@ " UK\n", " production\n", " none\n", - " [To submit a request for IT support through th...\n", - " 0.969710\n", + " [To access training materials on the employee ...\n", + " 0.05\n", " \n", " \n", "\n", "
" ] }, - "execution_count": 12, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 12 + "execution_count": 16 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-02-10T15:35:38.533606Z", - "start_time": "2025-02-10T15:35:38.530048Z" + "end_time": "2025-02-11T05:31:07.114555Z", + "start_time": "2025-02-11T05:31:07.112144Z" } }, "cell_type": "code", "source": "", "id": "99b8d9e0987cc1f6", "outputs": [], - "execution_count": 8 + "execution_count": 9 } ], "metadata": { diff --git a/src/evidently/future/descriptors/_context_relevance.py b/src/evidently/future/descriptors/_context_relevance.py index 1b883aef18..10af05c394 100644 --- a/src/evidently/future/descriptors/_context_relevance.py +++ b/src/evidently/future/descriptors/_context_relevance.py @@ -8,12 +8,15 @@ from evidently import ColumnType from evidently.base_metric import DisplayName +from evidently.features.llm_judge import BinaryClassificationPromptTemplate from evidently.future.datasets import Dataset from evidently.future.datasets import DatasetColumn from evidently.future.datasets import Descriptor +from evidently.options.base import Options +from evidently.utils.llm.wrapper import OpenAIWrapper -def semantic_similarity_scoring(question: DatasetColumn, context: DatasetColumn) -> DatasetColumn: +def semantic_similarity_scoring(question: DatasetColumn, context: DatasetColumn, options: Options) -> DatasetColumn: from sentence_transformers import SentenceTransformer model_id: str = "all-MiniLM-L6-v2" @@ -41,12 +44,61 @@ def normalized_cosine_distance(left, right): ) +def openai_scoring(question: DatasetColumn, context: DatasetColumn, options: Options) -> DatasetColumn: + # unwrap data to rows + context_column = context.data.name + no_index_context = context.data.reset_index() + context_rows = no_index_context.explode(context_column).reset_index() + + # do scoring + llm_wrapper = OpenAIWrapper("gpt-4o-mini", options) + template = BinaryClassificationPromptTemplate( + criteria="""A "RELEVANT" refers to CONTEXT is relevant to QUESTION. + + "IRRELEVANT" refers to CONTEXT is contradictory or irrelevant to QUESTION. + + Here is a QUESTION + -----question_starts----- + {input} + -----question_ends----- + + Here is a CONTEXT + -----context_starts----- + {context} + -----context_ends----- + + """, + target_category="RELEVANT", + non_target_category="IRRELEVANT", + uncertainty="unknown", + include_reasoning=True, + include_score=True, + pre_messages=[("system", "You are a judge which evaluates text.")], + ) + df = pd.DataFrame({"input": question.data, "context": context.data}).explode("context").reset_index() + questions = template.iterate_messages(df, {"input": "input", "context": "context"}) + results = llm_wrapper.run_batch_sync(questions) + result_data = pd.DataFrame(results) + # wrap scoring to lists back + scind = pd.DataFrame(data={"ind": context_rows["index"], "scores": result_data["score"]}) + rsd = pd.Series( + [list(scind.iloc[x]["scores"].astype(float)) for x in scind.groupby("ind").groups.values()], + index=question.data.index, + ) + + return DatasetColumn( + ColumnType.List, + rsd, + ) + + def mean(scores: List[float]) -> float: return float(np.average(scores)) METHODS = { "semantic_similarity": (semantic_similarity_scoring, mean), + "openai": (openai_scoring, mean), } @@ -83,7 +135,7 @@ def generate_data(self, dataset: Dataset) -> Union[DatasetColumn, Dict[DisplayNa if aggregation_method is None: raise ValueError(f"Aggregation method {self.aggregation_method} not found") - scored_contexts = method(dataset.column(self.input), data) + scored_contexts = method(dataset.column(self.input), data, Options()) aggregated_scores = scored_contexts.data.apply(aggregation_method) result = { f"{self.alias}: aggregate score": DatasetColumn(ColumnType.Numerical, aggregated_scores),