diff --git a/langgraph_agent_app_sample_code/01_data_pipeline.ipynb b/langgraph_agent_app_sample_code/01_data_pipeline.ipynb new file mode 100644 index 0000000..5ae00ae --- /dev/null +++ b/langgraph_agent_app_sample_code/01_data_pipeline.ipynb @@ -0,0 +1,831 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "7c756f50-2063-4a07-b964-e5d6de29abb4", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "# Unstructured data pipeline for the Agent's Retriever\n", + "\n", + "By the end of this notebook, you will have transformed your unstructured documents into a vector index that can be queried by your Agent.\n", + "\n", + "This means:\n", + "- Documents are loaded into a Delta table.\n", + "- Documents are chunked.\n", + "- Chunks are embedded with an embedding model and stored in a vector index.\n", + "\n", + "The important resulting artifact of this notebook is the chunked vector index. This will be used in the next notebook to power our Retriever." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "d3777205-4dfe-418c-9d21-c67961a18070", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 👉 START HERE: How to Use This Notebook\n", + "\n", + "Follow these steps to build and refine your data pipeline's quality:\n", + "\n", + "1. **Build a v0 index with default settings**\n", + " - Configure the data source and destination tables in the `1️⃣ 📂 Data source & destination configuration` cells\n", + " - Press `Run All` to create the vector index.\n", + "\n", + " *Note: While you can adjust the other settings and modify the parsing/chunking code, we suggest doing so only after evaluating your Agent's quality so you can make improvements that specifically address root causes of quality issues.*\n", + "\n", + "2. **Use later notebooks to integrate the retriever into the agent and evaluate the agent/retriever's quality.**\n", + "\n", + "3. 
**If the evaluation results show retrieval issues as a root cause, use this notebook to iterate on your data pipeline's code & config.** Below are some potential fixes you can try; see the AI Cookbook's [debugging retrieval issues](https://ai-cookbook.io/nbs/5-hands-on-improve-quality-step-1-retrieval.html) section for details.\n", + " - Add missing, but relevant, source documents to the index.\n", + " - Resolve any conflicting information in source documents.\n", + " - Adjust the data pipeline configuration:\n", + " - Modify chunk size or overlap.\n", + " - Experiment with different embedding models.\n", + " - Adjust the data pipeline code:\n", + " - Create a custom parser or use different parsing libraries.\n", + " - Develop a custom chunker or use different chunking techniques.\n", + " - Extract additional metadata for each document.\n", + " - Adjust the Agent's code/config in subsequent notebooks:\n", + " - Change the number of documents retrieved (K).\n", + " - Try a re-ranker.\n", + " - Use hybrid search.\n", + " - Apply extracted metadata as filters.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "1a6053b9-3135-4097-9ed0-64bdb03a6b9f", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "**Important note:** Throughout this notebook, we indicate which cells you:\n", + "- ✅✏️ *should* customize - these cells contain code & config with business logic that you should edit to meet your requirements & tune quality\n", + "- 🚫✏️ *typically will not* customize - these cells contain boilerplate code required to execute the pipeline\n", + "\n", + "*Cells that don't require customization still need to be run! 
You CAN change these cells, but if this is the first time using this notebook, we suggest not doing so.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "16b35cfd-7c99-4419-8978-33939faf24a6", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### Install Python libraries (Databricks Notebook only)\n", + "\n", + "🚫✏️ Only modify this cell if your changes to the document parsing or chunking logic require additional packages.\n", + "\n", + "Versions of Databricks code are not locked since Databricks ensures changes are backwards compatible.\n", + "Versions of open source packages are locked since package authors often make backwards incompatible changes." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6b4eebb3-448a-4236-99fb-19e44858e3c6", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install -qqqq -U -r requirements.txt\n", + "%pip install -qqqq -U -r requirements_datapipeline.txt\n", + "%pip install python-box\n", + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "b9d8671d-5f12-4c52-a537-513ea4a10dbb", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### Connect to Databricks (Local IDE only)\n", + "\n", + "If running from an IDE with [`databricks-connect`](https://docs.databricks.com/en/dev-tools/databricks-connect/python/index.html), connect to a Spark session & install the necessary packages on that cluster." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "23a0815a-d578-454e-8e9f-b39bcbe413aa", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.databricks_utils import get_cluster_url\n", + "from cookbook.databricks_utils import get_active_cluster_id\n", + "from cookbook.databricks_utils.install_cluster_library import install_requirements\n", + "\n", + "# UNCOMMENT TO INSTALL PACKAGES ON THE ACTIVE CLUSTER; note that this code path is not heavily battle-tested.\n", + "# cluster_id = get_active_cluster_id()\n", + "# print(f\"Installing packages on the active cluster: {get_cluster_url(cluster_id)}\")\n", + "\n", + "\n", + "# install_requirements(cluster_id, \"requirements.txt\")\n", + "# install_requirements(cluster_id, \"requirements_datapipeline.txt\")\n", + "\n", + "# THIS MUST BE DONE MANUALLY! 
TODO: Automate it.\n", + "# - Go to openai_sdk_agent_app_sample_code/\n", + "# - Run `poetry build`\n", + "# - Copy the wheel file to a UC Volume or Workspace folder\n", + "# - Go to the cluster's Libraries page and install the wheel file as a new library\n", + "\n", + "# Get Spark session if using Databricks Connect from an IDE\n", + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if not du.is_in_databricks_notebook():\n", + " from databricks.connect import DatabricksSession\n", + "\n", + " spark = DatabricksSession.builder.getOrCreate()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "257fc9e5-4968-469d-98c1-f11af124c92e", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### 🚫✏️ Load data pipeline configuration from a YAML\n", + "\n", + "This allows the configuration to be loaded and referenced by the Agent's notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a97a379d-e933-45f6-ae7e-6c317e43d009", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# load config\n", + "import yaml\n", + "from pathlib import Path\n", + "from box import Box\n", + "\n", + "data_pipeline_conf = Box(yaml.safe_load(Path(\"./configs/data_pipeline_config.yaml\").read_text()))\n", + "print(data_pipeline_conf)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "a28cbf99-c4ca-4adc-905a-e7ebfe015730", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### 🛑 If you are running your initial data pipeline, you do not need to configure anything else; you can simply `Run All` the remaining notebook cells. You can modify these cells later to tune the quality of your data pipeline, e.g., by changing the parsing logic." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "95b6971b-b00b-4f42-bbe8-cc64eea2fff8", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "## 3️⃣ ⌨️ Data pipeline code\n", + "\n", + "The code below executes the data pipeline. You can modify the below code as indicated to implement different parsing or chunking strategies or to extract additional metadata fields." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "4f9ebf83-9536-48cc-b087-2c8e324ff542", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### 🛑 Make sure to populate the volume with source files before running the parsing code below." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "c85ddc92-10c5-405c-ae78-8ded5462333e", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### Pipeline step 1: Load & parse documents into a Delta Table\n", + "\n", + "In this step, we'll load files from the UC Volume defined in the config's `source` section into the Delta Table named by `output.parsed_docs_table`.\n",
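+ "\n", + "*The volume and table names used throughout this pipeline come from the YAML config loaded above. For reference, here is a sketch of the fields this notebook reads from that config - the example values are assumptions; set the real ones in `configs/data_pipeline_config.yaml`:*\n", + "\n", + "```python\n", + "# Fields read from the loaded config (hypothetical example values):\n", + "data_pipeline_conf.source.uc_catalog_name  # e.g. \"my_catalog\"\n", + "data_pipeline_conf.source.uc_schema_name  # e.g. \"my_schema\"\n", + "data_pipeline_conf.source.volume_path  # e.g. \"/Volumes/my_catalog/my_schema/raw_docs\"\n", + "data_pipeline_conf.output.parsed_docs_table  # e.g. \"parsed_docs\"\n", + "data_pipeline_conf.output.chunked_docs_table  # e.g. \"chunked_docs\"\n", + "data_pipeline_conf.output.vector_index  # e.g. \"chunked_docs_index\"\n", + "data_pipeline_conf.output.vector_search_endpoint  # e.g. \"my_vector_search_endpoint\"\n", + "data_pipeline_conf.chunking_config.embedding_model_endpoint  # e.g. \"databricks-gte-large-en\"\n", + "data_pipeline_conf.chunking_config.chunk_size_tokens  # e.g. 1024\n", + "data_pipeline_conf.chunking_config.chunk_overlap_tokens  # e.g. 256\n", + "```\n",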
+ "\n", + "The contents of each file will become a separate row in our Delta table.\n", + "\n", + "The path to the source document will be used as the `doc_uri`, which is displayed to your end users in the Agent Evaluation web application.\n", + "\n", + "After you test your POC with stakeholders, you can return here to change the parsing logic or extract additional metadata about the documents to help improve the quality of your retriever." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "27466460-1ee7-4fe4-8faf-da9ddff11847", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "##### ✅✏️ Customize the parsing function\n", + "\n", + "This default implementation parses PDF, HTML, and DOCX files using open source libraries. Adjust `file_parser(...)` and `ParserReturnValue` in `cookbook/data_pipeline/default_parser.py` to change the parsing logic, add support for more file types, or extract additional metadata about each document." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d09fd38c-5b7b-47c5-aa6a-ff571ce2f83b", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.data_pipeline.default_parser import file_parser, ParserReturnValue\n", + "\n", + "# Print the code of file_parser function for inspection\n", + "import inspect\n", + "print(inspect.getsource(ParserReturnValue))\n", + "print(inspect.getsource(file_parser))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "61034803-4bdd-4f0b-b173-a82448ee1790", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "The below cell is debugging code to test your parsing function on a single record. " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "48a3ab67-2e30-4e39-b05e-3a8ff304fd5b", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import os\n", + "from cookbook.data_pipeline.parse_docs import load_files_to_df\n", + "from pyspark.sql import functions as F\n", + "\n", + "\n", + "source_conf = data_pipeline_conf.source\n", + "\n", + "raw_files_df = load_files_to_df(\n", + " spark=spark,\n", + " source_path=source_conf.volume_path,\n", + ")\n", + "\n", + "print(f\"Loaded {raw_files_df.count()} files from {source_conf.volume_path}. Files: {os.listdir(source_conf.volume_path)}\")\n", + "\n", + "test_records_dict = raw_files_df.toPandas().to_dict(orient=\"records\")\n", + "\n", + "for record in test_records_dict:\n", + " print()\n", + " print(\"Testing parsing for file: \", record[\"path\"])\n", + " print()\n", + " test_result = file_parser(raw_doc_contents_bytes=record['content'], doc_path=record['path'], modification_time=record['modificationTime'], doc_bytes_length=record['length'])\n", + " print(test_result)\n", + " break # Pause after the first file. 
Remove this break to test more files.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "9fb6db6c-faa0-4dac-be84-a832bbbb49b9", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "🚫✏️ The below cell is boilerplate code to apply the parsing function using Spark." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "165706b2-5824-42e7-a22b-3ca0edfd0a77", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.data_pipeline.parse_docs import (\n", + " load_files_to_df,\n", + " apply_parsing_fn,\n", + " check_parsed_df_for_errors,\n", + " check_parsed_df_for_empty_parsed_files,\n", + ")\n", + "from cookbook.data_pipeline.utils.typed_dicts_to_spark_schema import (\n", + " typed_dicts_to_spark_schema,\n", + ")\n", + "from cookbook.databricks_utils import get_table_url\n", + "\n", + "output_config = data_pipeline_conf.output\n", + "# Tune this parameter to optimize performance. More partitions will improve performance, but may cause out of memory errors if your cluster is too small.\n", + "NUM_PARTITIONS = 50\n", + "\n", + "# Load the UC Volume files into a Spark DataFrame\n", + "raw_files_df = load_files_to_df(\n", + " spark=spark,\n", + " source_path=source_conf.volume_path,\n", + ").repartition(NUM_PARTITIONS)\n", + "\n", + "# Apply the parsing UDF to the Spark DataFrame\n", + "parsed_files_df = apply_parsing_fn(\n", + " raw_files_df=raw_files_df,\n", + " # Modify this function to change the parser, extract additional metadata, etc\n", + " parse_file_fn=file_parser,\n", + " # The schema of the resulting Delta Table will follow the schema defined in ParserReturnValue\n", + " parsed_df_schema=typed_dicts_to_spark_schema(ParserReturnValue),\n", + ")\n", + "\n", + "parsed_data_fully_qualified_table_name = f\"{source_conf.uc_catalog_name}.{source_conf.uc_schema_name}.{output_config.parsed_docs_table}\"\n", + "\n", + "# Write to a Delta Table\n", + "parsed_files_df.write.mode(\"overwrite\").option(\"overwriteSchema\", \"true\").saveAsTable(\n", + " parsed_data_fully_qualified_table_name\n", + ")\n", + "\n", + "# Get resulting table\n", + "parsed_files_df = spark.table(parsed_data_fully_qualified_table_name)\n", + "parsed_files_no_errors_df = parsed_files_df.filter(\n", + " parsed_files_df.parser_status == \"SUCCESS\"\n", + ")\n", + "\n", + "# Show successfully parsed documents\n", + "print(\n", + " f\"Parsed {parsed_files_no_errors_df.count()} / {parsed_files_df.count()} documents successfully. Inspect `parsed_files_no_errors_df` or visit {get_table_url(parsed_data_fully_qualified_table_name)} to see all parsed documents, including any errors.\"\n", + ")\n", + "display(parsed_files_no_errors_df.toPandas())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "73eba7c4-fe30-4599-918c-12ec9e2039a9", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Show any parsing failures or successfully parsed files that resulted in an empty document."
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "9cfac97c-9f01-46d9-a281-4d18cbd3208d", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "\n", + "# Any documents that failed to parse\n", + "is_error, msg, failed_docs_df = check_parsed_df_for_errors(parsed_files_df)\n", + "if is_error:\n", + " display(failed_docs_df.toPandas())\n", + " raise Exception(msg)\n", + " \n", + "# Any documents that returned empty parsing results\n", + "is_error, msg, empty_docs_df = check_parsed_df_for_empty_parsed_files(parsed_files_df)\n", + "if is_error:\n", + " display(empty_docs_df.toPandas())\n", + " raise Exception(msg)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "e21c84e8-7682-4a7a-86fc-7f4f990bb490", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### Pipeline step 2: Compute chunks of documents\n", + "\n", + "In this step, we will split our documents into smaller chunks so they can be indexed in our vector database." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "eecd460c-f287-47ce-98f1-cea78a1f3f64", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "\n", + "##### ✅✏️ Chunking logic\n", + "\n", + "We provide a default implementation of a recursive text splitter. To create your own chunking logic, adapt the `get_recursive_character_text_splitter()` function inside `cookbook/data_pipeline/recursive_character_text_splitter.py`." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "02c40228-f933-4af8-9121-ed2efa0985dd", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.data_pipeline.recursive_character_text_splitter import (\n", + " get_recursive_character_text_splitter,\n", + ")\n", + "\n", + "chunking_conf = data_pipeline_conf.chunking_config\n", + "\n", + "# Get the chunking function\n", + "recursive_character_text_splitter_fn = get_recursive_character_text_splitter(\n", + " model_serving_endpoint=chunking_conf.embedding_model_endpoint,\n", + " chunk_size_tokens=chunking_conf.chunk_size_tokens,\n", + " chunk_overlap_tokens=chunking_conf.chunk_overlap_tokens,\n", + ")\n", + "\n", + "# Determine which columns to propagate from the docs table to the chunks table.\n", + "\n", + "# Get the columns from the parser except for the content\n", + "# You can modify this to adjust which fields are propagated from the docs table to the chunks table.\n", + "propagate_columns = [\n", + " field.name\n", + " for field in typed_dicts_to_spark_schema(ParserReturnValue).fields\n", + " if field.name != \"content\"\n", + "]\n", + "\n", + "# If you want to implement retrieval strategies such as presenting the entire document vs. the chunk to the LLM, include `content`, which contains the doc's full parsed text. By default this is not included because `content` can be quite large and cause performance issues.\n", + "# propagate_columns = [\n", + "# field.name\n", + "# for field in typed_dicts_to_spark_schema(ParserReturnValue).fields\n", + "# ]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "b17add2c-e7f0-4903-8ae9-40ca0633a8d5", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "🚫✏️ Run the chunking function within Spark" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0dfa90f8-c4dc-4485-8fa8-dcd4c7d40618", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.data_pipeline.chunk_docs import apply_chunking_fn\n", + "from cookbook.databricks_utils import get_table_url\n", + "\n", + "# Tune this parameter to optimize performance. More partitions will improve performance, but may cause out of memory errors if your cluster is too small.\n", + "NUM_PARTITIONS = 50\n", + "\n", + "# Load parsed docs\n", + "parsed_files_df = spark.table(parsed_data_fully_qualified_table_name).repartition(NUM_PARTITIONS)\n", + "\n", + "chunked_docs_df = apply_chunking_fn(\n", + " # The source documents table.\n", + " parsed_docs_df=parsed_files_df,\n", + " # The chunking function that takes a string (document) and returns a list of strings (chunks).\n", + " chunking_fn=recursive_character_text_splitter_fn,\n", + " # Choose which columns to propagate from the docs table to the chunks table. The `doc_uri` column is required so we can propagate the original document URL to the Agent's web app.\n", + " propagate_columns=propagate_columns,\n", + ")\n", + "\n", + "chunked_data_fully_qualified_table_name = f\"{source_conf.uc_catalog_name}.{source_conf.uc_schema_name}.{output_config.chunked_docs_table}\"\n", + "\n", + "\n", + "# Write to Delta Table\n", + "chunked_docs_df.write.mode(\"overwrite\").option(\n", + " \"overwriteSchema\", \"true\"\n", + ").saveAsTable(chunked_data_fully_qualified_table_name)\n", + "\n", + "# Get resulting table\n", + "chunked_docs_df = spark.table(chunked_data_fully_qualified_table_name)\n", + "\n", + "# Show number of chunks created\n", + "print(f\"Created {chunked_docs_df.count()} chunks. Inspect `chunked_docs_df` or visit {get_table_url(chunked_data_fully_qualified_table_name)} to see the results.\")\n", + "\n", + "# Enable the Change Data Feed so the Vector Search index can sync from this table\n", + "cdc_results = spark.sql(f\"ALTER TABLE {chunked_data_fully_qualified_table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)\")\n", + "\n", + "# Show chunks\n", + "display(chunked_docs_df.toPandas())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "9fe923a8-89c2-4852-9cea-98074b3ce404", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### 🚫✏️ Pipeline step 3: Create the vector index\n", + "\n", + "In this step, we'll embed the document chunks to compute the vector index and create the retriever index that will be used to find documents relevant to the user's question.\n",
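+ "\n", + "Once the sync has completed, you can sanity-check the index with a quick similarity search. A minimal sketch - the endpoint, index, and column names below are assumptions, so match them to your config and chunks table schema:\n", + "\n", + "```python\n", + "# Hypothetical sanity check, run after the index has finished syncing.\n", + "from databricks.vector_search.client import VectorSearchClient\n", + "\n", + "vsc = VectorSearchClient()\n", + "index = vsc.get_index(\n", + "    endpoint_name=\"my_vector_search_endpoint\",  # assumption: your VS endpoint\n", + "    index_name=\"my_catalog.my_schema.chunked_docs_index\",  # assumption: your index\n", + ")\n", + "results = index.similarity_search(\n", + "    query_text=\"How do I create a Spark DataFrame?\",\n", + "    columns=[\"chunk_id\", \"content_chunked\"],  # assumption: columns in your chunks table\n", + "    num_results=3,\n", + ")\n", + "print(results)\n", + "```\n",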
+ "\n", + "The embedding pipeline is handled within Databricks Vector Search using [Delta Sync](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d53faa42-2a65-40b0-8fc1-6c27e88df6d0", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.data_pipeline.build_retriever_index import build_retriever_index\n", + "from cookbook.databricks_utils import get_table_url\n", + "\n", + "vector_search_fully_qualified_index_name = f\"{source_conf.uc_catalog_name}.{source_conf.uc_schema_name}.{output_config.vector_index}\"\n", + "is_error, msg = build_retriever_index(\n", + " # Spark requires backticks to escape names with special chars; the VS client does not.\n", + " chunked_docs_table_name=chunked_data_fully_qualified_table_name,\n", + " vector_search_endpoint=output_config.vector_search_endpoint,\n", + " vector_search_index_name=vector_search_fully_qualified_index_name,\n", + "\n", + " # Must match the embedding endpoint you used to chunk your documents\n", + " embedding_endpoint_name=chunking_conf.embedding_model_endpoint,\n", + "\n", + " # Set to True to delete and re-create the vector search index when re-running the data pipeline. If kept as False, syncing will not work when you re-run the pipeline and change the schema of chunked_docs_table_name; however, keeping it False allows Vector Search to avoid recomputing embeddings for any row with a chunk_id that was previously computed.\n", + " force_delete_index_before_create=False,\n", + ")\n", + "if is_error:\n", + " raise Exception(msg)\n", + "else:\n", + " print(\"NOTE: This cell will complete before the vector index has finished syncing/embedding your chunks & is ready for queries!\")\n", + " print(f\"View sync status here: {get_table_url(vector_search_fully_qualified_index_name)}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "1a1ad14b-2573-4485-8369-d417f7a548f6", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### 🚫✏️ Print links to view the resulting tables/index" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0cd40431-4cd3-4cc9-b38d-5ab817c40043", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.databricks_utils import get_table_url\n", + "\n", + "print()\n", + "print(f\"Parsed docs table: {get_table_url(parsed_data_fully_qualified_table_name)}\\n\")\n", + "print(f\"Chunked docs table: {get_table_url(chunked_data_fully_qualified_table_name)}\\n\")\n", + "print(f\"Vector search index: {get_table_url(vector_search_fully_qualified_index_name)}\\n\")" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "computePreferences": null, + "dashboards": [], + "environmentMetadata": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 2 + }, + "notebookName": "01_data_pipeline", + "widgets": {} + }, + "kernelspec": { + "display_name": "genai-cookbook-T2SdtsNM-py3.11",
+ "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/langgraph_agent_app_sample_code/02_create_synthetic_eval.ipynb b/langgraph_agent_app_sample_code/02_create_synthetic_eval.ipynb new file mode 100644 index 0000000..6521cb2 --- /dev/null +++ b/langgraph_agent_app_sample_code/02_create_synthetic_eval.ipynb @@ -0,0 +1,350 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "698f4adb-d383-4a76-86b7-94716c60a479", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "## 👉 START HERE: How to use this notebook\n", + "\n", + "### Step 1: Create synthetic evaluation data\n", + "\n", + "To measure your Agent's quality, you need a diverse, representative evaluation set. This notebook turns your unstructured documents into a high-quality synthetic evaluation set so that you can start to evaluate and improve your Agent's quality before subject matter experts are available to label data.\n", + "\n", + "This notebook does the following:\n", + "1. \n", + "\n", + "THIS DOES NOT WORK FROM LOCAL IDE YET." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "27267d1c-d1a4-4acb-a1bf-381cb2bc542b", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "**Important note:** Throughout this notebook, we indicate which cells you:\n", + "- ✅✏️ *should* customize - these cells contain config settings to change\n", + "- 🚫✏️ *typically will not* customize - these cells contain code that is parameterized by your configuration.\n", + "\n", + "*Cells that don't require customization still need to be run!*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "c79b4f3b-e8b3-4bc1-ae73-7b3dfd3abdcd", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Install Python libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "cbcdef70-657e-4f12-b564-90d0f5b74e42", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install -qqqq -U -r requirements.txt\n", + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "5aa247a6-8af1-4a28-8353-1c59b2319978", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Connect to Databricks\n", + "\n", + "If running locally in an IDE using Databricks Connect, connect the Spark client & configure MLflow to use Databricks Managed MLflow. If this running in a Databricks Notebook, these values are already set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3ff747f0-5989-40ca-8a75-10c1fcb6ede6", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if not du.is_in_databricks_notebook():\n", + " from databricks.connect import DatabricksSession\n", + " import os\n", + "\n", + " spark = DatabricksSession.builder.getOrCreate()\n", + " os.environ[\"MLFLOW_TRACKING_URI\"] = \"databricks\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "76e8dfcc-88e9-4f63-8b18-6f58861d0498", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Load the Agent's storage locations\n", + "\n", + "This notebook writes to the evaluation set table that you specified in the [Agent setup](02_agent_setup.ipynb) notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d50957ea-3e35-4131-a217-05df65a0d9be", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import yaml\n", + "from pathlib import Path\n", + "from box import Box\n", + "from cookbook.databricks_utils import get_table_url\n", + "\n", + "# Load the Agent's storage configuration\n", + "agent_storage_config = Box(yaml.safe_load(Path(\"./configs/agent_storage_config.yaml\").read_text()))\n", + "print(agent_storage_config)\n", + "\n", + "# Check if the evaluation set already exists\n", + "try:\n", + " eval_dataset = spark.table(agent_storage_config.evaluation_set_uc_table)\n", + " if eval_dataset.count() > 0:\n", + " print(f\"Evaluation set {get_table_url(agent_storage_config.evaluation_set_uc_table)} already exists! By default, this notebook will append to the evaluation dataset. If you would like to overwrite the existing evaluation set, please delete the table before running this notebook.\")\n", + " else:\n", + " print(f\"Evaluation set {get_table_url(agent_storage_config.evaluation_set_uc_table)} exists, but is empty! By default, this notebook will NOT change the schema of this table - if you experience schema related errors, drop this table before running this notebook so it can be recreated with the correct schema.\")\n", + "except Exception:\n", + " print(f\"Evaluation set `{agent_storage_config.evaluation_set_uc_table}` does not exist. 
This notebook will create a new Delta Table at this location.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "87da7b3f-bc19-4423-98f7-fa11e7d3b2b6", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### ✅✏️ Load the source documents for synthetic evaluation data generation\n", + "\n", + "Most often, this will be the same as the document output table from the [data pipeline](01_data_pipeline.ipynb).\n", + "\n", + "Here, we provide code to load the documents table that was created in the [data pipeline](01_data_pipeline.ipynb).\n", + "\n", + "Alternatively, this can be a Spark DataFrame, Pandas DataFrame, or list of dictionaries with the following keys/columns:\n", + "- `doc_uri`: A URI pointing to the document.\n", + "- `content`: The content of the document." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "4f17c6d0-633e-444f-b7e6-cff3b28f6e18", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# load config\n", + "import yaml\n", + "from pathlib import Path\n", + "from box import Box\n", + "\n", + "data_pipeline_config = Box(\n", + " yaml.safe_load(Path(\"./configs/data_pipeline_config.yaml\").read_text())\n", + ")\n", + "print(data_pipeline_config)\n", + "\n", + "source_config = data_pipeline_config.source\n", + "output_config = data_pipeline_config.output\n", + "\n", + "parsed_data_fully_qualified_table_name = f\"{source_config.uc_catalog_name}.{source_config.uc_schema_name}.{output_config.parsed_docs_table}\"\n", + "\n", + "source_documents = spark.table(parsed_data_fully_qualified_table_name)\n", + "\n", + "display(source_documents.toPandas())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "21a5ef1f-9d3e-4822-ba0f-44eab36f24bb", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### ✅✏️ Run the synthetic evaluation data generation\n", + "\n", + "Optionally, you can customize the guidelines to guide the synthetic data generation. By default, guidelines are not applied - to apply the guidelines, uncomment `guidelines=guidelines` in the `generate_evals_df(...)` call. See our [documentation](https://docs.databricks.com/en/generative-ai/agent-evaluation/synthesize-evaluation-set.html) for more details." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a7cb950a-84b1-4e1d-a7fb-5179a0aa69de", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from databricks.agents.eval import generate_evals_df\n", + "\n", + "# NOTE: The guidelines you provide are a free-form string. The markdown string below is the suggested formatting for the set of guidelines; however, you are free\n", + "# to add your own sections here. 
Note that this will be prompt-engineering an LLM that generates the synthetic data, so you may have to iterate on these guidelines before\n", + "# you get the results you desire.\n", + "guidelines = \"\"\"\n", + "# Task Description\n", + "The Agent is a RAG chatbot that answers questions about using Spark on Databricks. The Agent has access to a corpus of Databricks documents, and its task is to answer the user's questions by retrieving the relevant docs from the corpus and synthesizing a helpful, accurate response. The corpus covers a lot of info, but the Agent is specifically designed to interact with Databricks users who have questions about Spark. So questions outside of this scope are considered irrelevant.\n", + "\n", + "# User personas\n", + "- A developer who is new to the Databricks platform\n", + "- An experienced, highly technical Data Scientist or Data Engineer\n", + "\n", + "# Example questions\n", + "- what API lets me parallelize operations over rows of a delta table?\n", + "- Which cluster settings will give me the best performance when using Spark?\n", + "\n", + "# Additional Guidelines\n", + "- Questions should be succinct, and human-like\n", + "\"\"\"\n", + "\n", + "synthesized_evals_df = generate_evals_df(\n", + " docs=source_documents,\n", + " # The total number of evaluations to generate across all documents.\n", + " num_evals=10,\n", + " # An optional set of guidelines that help guide the synthetic generation. This is a free-form string that will be used to prompt the generation.\n", + " # guidelines=guidelines\n", + ")\n", + "\n", + "# Write the synthetic evaluation data to the evaluation set table\n", + "spark.createDataFrame(synthesized_evals_df).write.format(\"delta\").mode(\"append\").saveAsTable(agent_storage_config.evaluation_set_uc_table)\n", + "\n", + "# Display the synthetic evaluation data\n", + "eval_set_df = spark.table(agent_storage_config.evaluation_set_uc_table)\n", + "display(eval_set_df.toPandas())" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "computePreferences": null, + "dashboards": [], + "environmentMetadata": { + "base_environment": "", + "client": "1" + }, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "02_create_synthetic_eval", + "widgets": {} + }, + "kernelspec": { + "display_name": "genai-cookbook-T2SdtsNM-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/langgraph_agent_app_sample_code/03_create_tools.ipynb b/langgraph_agent_app_sample_code/03_create_tools.ipynb new file mode 100644 index 0000000..334ad70 --- /dev/null +++ b/langgraph_agent_app_sample_code/03_create_tools.ipynb @@ -0,0 +1,1333 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "31661828-f9bb-4fc2-a1bd-94424a27ed52", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "## 👉 START HERE: How to use this notebook\n", + "\n", + "### Step 2: Create tools for your Agent\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "5d9f685a-fdb7-49a4-9e3a-a4a9e964d045", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "\n", + "**Important note:** Throughout this notebook, we indicate which cells' code 
you:\n", + "- ✅✏️ should customize - these cells contain code & config with business logic that you should edit to meet your requirements & tune quality.\n", + "- 🚫✏️ should not customize - these cells contain boilerplate code required to load/save/execute your Agent\n", + "\n", + "*Cells that don't require customization still need to be run! You CAN change these cells, but if this is the first time using this notebook, we suggest not doing so.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "bb4f8cc0-1797-4beb-a9f2-df21a9db79f0", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Install Python libraries\n", + "\n", + "You do not need to modify this cell unless you need additional Python packages in your Agent." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6d4030e8-ae97-4351-bebd-9651d283578f", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install -qqqq -U -r requirements.txt\n", + "# Restart to load the packages into the Python environment\n", + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "5fa580a1-524b-4058-830d-7d3a72169fde", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Connect to Databricks\n", + "\n", + "If running locally in an IDE using Databricks Connect, connect the Spark client & configure MLflow to use Databricks Managed MLflow. If this is running in a Databricks Notebook, these values are already set." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a08da9d7-152a-44c8-9609-2ffebcc1ad25", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if not du.is_in_databricks_notebook():\n", + " from databricks.connect import DatabricksSession\n", + " import os\n", + "\n", + " spark = DatabricksSession.builder.getOrCreate()\n", + " os.environ[\"MLFLOW_TRACKING_URI\"] = \"databricks\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "6be96e1a-d2fa-4acd-b54e-d7fe85d0034d", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Load the Agent's UC storage locations; set up MLflow experiment\n", + "\n", + "This notebook uses the UC model, MLflow Experiment, and Evaluation Set that you specified in the [Agent setup](02_agent_setup.ipynb) notebook.\n",
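+ "\n", + "For reference, here are the fields this notebook reads from that config - a sketch; the example values are assumptions and the real values live in `./configs/agent_storage_config.yaml`:\n", + "\n", + "```python\n", + "# Hypothetical example values:\n", + "agent_storage_config.evaluation_set_uc_table  # e.g. \"my_catalog.my_schema.my_agent_eval_set\"\n", + "agent_storage_config.mlflow_experiment_name  # e.g. \"/Users/you@example.com/my_agent\"\n", + "```"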
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "413ef673-1912-4600-bde2-1beaaa4f9919", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import yaml\n", + "from pathlib import Path\n", + "import mlflow\n", + "from box import Box\n", + "from cookbook.databricks_utils import get_table_url\n", + "from cookbook.databricks_utils import get_mlflow_experiment_url\n", + "\n", + "# Load the Agent's storage configuration\n", + "agent_storage_config = Box(yaml.safe_load(Path(\"./configs/agent_storage_config.yaml\").read_text()))\n", + "print(agent_storage_config)\n", + "\n", + "\n", + "# set the MLflow experiment\n", + "experiment_info = mlflow.set_experiment(agent_storage_config.mlflow_experiment_name)\n", + "# If running in a local IDE, set the MLflow experiment name as an environment variable\n", + "os.environ[\"MLFLOW_EXPERIMENT_NAME\"] = agent_storage_config.mlflow_experiment_name\n", + "\n", + "print(f\"View the MLflow Experiment `{agent_storage_config.mlflow_experiment_name}` at {get_mlflow_experiment_url(experiment_info.experiment_id)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "d3bfc354-785f-4e0a-8245-fe05389769d7", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "# Create tools\n", + "\n", + "- We will store all tools in the `tools` folder.\n", + "- First, create a local function & test it with pytest.\n", + "- Then, deploy it as a UC tool & test it with pytest.\n", + "- Then, add the tool to the Agent." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "1c054346-e4f6-42a7-9403-9882f2450b2b", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Always reload the tool's code:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "320ca179-916a-4ec7-b012-dd9915adbbc2", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 3" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "7b0d70d2-17b1-4fdf-a5f9-da75332c4a5c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "## Let's do an example of a simple, but fake, tool that translates old SKUs to new SKUs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "54ae1351-99f9-4e42-85a1-e56b7882509b", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "1. Create the Python function that will become your UC function. You need to annotate the function with docstrings & type hints - these are used to create the tool's metadata in UC.\n",
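+ "\n", + "A minimal sketch of the expected shape (a hypothetical tool; the real example follows in the next cell):\n", + "\n", + "```python\n", + "def my_tool(some_input: str) -> str:\n", + "    \"\"\"\n", + "    One-line description, used as the tool's UC comment.\n", + "\n", + "    Args:\n", + "        some_input (str): Description, used as the parameter's UC comment.\n", + "\n", + "    Returns:\n", + "        str: Description of the return value.\n", + "    \"\"\"\n", + "    return some_input.upper()\n", + "```"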
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "b1375dbe-003c-4aee-bb1d-f5922c716242", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%%writefile tools/sample_tool.py\n", + "\n", + "def sku_sample_translator(old_sku: str) -> str:\n", + " \"\"\"\n", + " Translates a pre-2024 SKU formatted as \"OLD-XXX-YYYY\" to the new SKU format \"NEW-YYYY-XXX\".\n", + "\n", + " Args:\n", + " old_sku (str): The old SKU in the format \"OLD-XXX-YYYY\".\n", + "\n", + " Returns:\n", + " str: The new SKU in the format \"NEW-YYYY-XXX\".\n", + "\n", + " Raises:\n", + " ValueError: If the SKU format is invalid, providing specific error details.\n", + " \"\"\"\n", + " import re\n", + "\n", + " if not isinstance(old_sku, str):\n", + " raise ValueError(\"SKU must be a string\")\n", + "\n", + " # Normalize input by removing extra whitespace and converting to uppercase\n", + " old_sku = old_sku.strip().upper()\n", + "\n", + " # Define the regex pattern for the old SKU format\n", + " pattern = r\"^OLD-([A-Z]{3})-(\\d{4})$\"\n", + "\n", + " # Match the old SKU against the pattern\n", + " match = re.match(pattern, old_sku)\n", + " if not match:\n", + " if not old_sku.startswith(\"OLD-\"):\n", + " raise ValueError(\"SKU must start with 'OLD-'\")\n", + " if not re.match(r\"^OLD-[A-Z]{3}-\\d{4}$\", old_sku):\n", + " raise ValueError(\n", + " \"SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit\"\n", + " )\n", + " raise ValueError(\"Invalid SKU format\")\n", + "\n", + " # Extract the letter code and numeric part\n", + " letter_code, numeric_part = match.groups()\n", + "\n", + " # Additional validation for numeric part\n", + " if not (1 <= int(numeric_part) <= 9999):\n", + " raise ValueError(\"Numeric part must be between 0001 and 9999\")\n", + "\n", + " # Construct the new SKU\n", + " new_sku = f\"NEW-{numeric_part}-{letter_code}\"\n", + " return new_sku\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "aa0474ad-8ee2-4148-9018-aa78b37ded7c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Now, let's import the tool and test it locally." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "034892a0-c83c-4c95-b318-83d369f9a153", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from tools.sample_tool import sku_sample_translator\n", + "\n", + "sku_sample_translator(\"OLD-XXX-1234\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "3e7b9a25-c7e6-48ac-9d3e-86ec325d524c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Now, let's write some pytest unit tests for the tool - these are just samples; you will need to write your own." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": 
"7ac4fd62-270c-443a-8b17-1b6bb8e00f58", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%%writefile tools/test_sample_tool.py\n", + "import pytest\n", + "from tools.sample_tool import sku_sample_translator\n", + "\n", + "\n", + "\n", + "def test_valid_sku_translation():\n", + " \"\"\"Test successful SKU translation with valid input.\"\"\"\n", + " assert sku_sample_translator(\"OLD-ABC-1234\") == \"NEW-1234-ABC\"\n", + " assert sku_sample_translator(\"OLD-XYZ-0001\") == \"NEW-0001-XYZ\"\n", + " assert sku_sample_translator(\"old-def-5678\") == \"NEW-5678-DEF\" # Test case insensitivity\n", + "\n", + "\n", + "def test_whitespace_handling():\n", + " \"\"\"Test that the function handles extra whitespace correctly.\"\"\"\n", + " assert sku_sample_translator(\" OLD-ABC-1234 \") == \"NEW-1234-ABC\"\n", + " assert sku_sample_translator(\"\\tOLD-ABC-1234\\n\") == \"NEW-1234-ABC\"\n", + "\n", + "\n", + "def test_invalid_input_type():\n", + " \"\"\"Test that non-string inputs raise ValueError.\"\"\"\n", + " with pytest.raises(ValueError, match=\"SKU must be a string\"):\n", + " sku_sample_translator(123)\n", + " with pytest.raises(ValueError, match=\"SKU must be a string\"):\n", + " sku_sample_translator(None)\n", + "\n", + "\n", + "def test_invalid_prefix():\n", + " \"\"\"Test that SKUs not starting with 'OLD-' raise ValueError.\"\"\"\n", + " with pytest.raises(ValueError, match=\"SKU must start with 'OLD-'\"):\n", + " sku_sample_translator(\"NEW-ABC-1234\")\n", + " with pytest.raises(ValueError, match=\"SKU must start with 'OLD-'\"):\n", + " sku_sample_translator(\"XXX-ABC-1234\")\n", + "\n", + "\n", + "def test_invalid_format():\n", + " \"\"\"Test various invalid SKU formats.\"\"\"\n", + " invalid_skus = [\n", + " \"OLD-AB-1234\", # Too few letters\n", + " \"OLD-ABCD-1234\", # Too many letters\n", + " \"OLD-123-1234\", # Numbers instead of letters\n", + " \"OLD-ABC-123\", # Too few digits\n", + " \"OLD-ABC-12345\", # Too many digits\n", + " \"OLD-ABC-XXXX\", # Letters instead of numbers\n", + " \"OLD-A1C-1234\", # Mixed letters and numbers in middle\n", + " ]\n", + "\n", + " for sku in invalid_skus:\n", + " with pytest.raises(\n", + " ValueError,\n", + " match=\"SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit\",\n", + " ):\n", + " sku_sample_translator(sku)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "8b6add46-090e-4160-afdf-1e7a9a5ff46c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "now, lets run the tests" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "da7b4dba-a9cc-4f0b-9fdd-606afd01d6b0", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import pytest\n", + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if du.is_in_databricks_notebook():\n", + " import sys\n", + " sys.dont_write_bytecode = True # Skip writing .pyc files to the bytecode cache on the cluster.\n", + "\n", + "# Run tests from test_sku_translator.py\n", + "pytest.main([\"-v\", \"tools/test_sample_tool.py\", \"-o cache_dir=/tmp/my_cache_dir\"])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + 
"cellMetadata": {}, + "inputWidgets": {}, + "nuid": "1dbdb1e7-85e6-45c8-8f39-46f982dd1464", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Now, lets deploy the tool to Unity catalog." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "85196db9-c0f8-4e86-8ad0-94653a1850f4", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from unitycatalog.ai.core.databricks import DatabricksFunctionClient\n", + "from tools.sample_tool import sku_sample_translator\n", + "\n", + "client = DatabricksFunctionClient()\n", + "CATALOG = \"shared\" # Change me!\n", + "SCHEMA = \"cookbook_langgraph_udhay\" # Change me if you want\n", + "\n", + "# this will deploy the tool to UC, automatically setting the metadata in UC based on the tool's docstring & typing hints\n", + "tool_uc_info = client.create_python_function(func=sku_sample_translator, catalog=CATALOG, schema=SCHEMA, replace=True)\n", + "\n", + "# the tool will deploy to a function in UC called `{catalog}.{schema}.{func}` where {func} is the name of the function\n", + "# Print the deployed Unity Catalog function name\n", + "print(f\"Deployed Unity Catalog function name: {tool_uc_info.full_name}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "0a9bac0a-ac9b-4857-913f-3ca9d4a0e8ba", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Now, wrap it into a UCTool that will be used by our Agent. UC tool is just a Pydnatic base model that is serializable to YAML that will load the tool's metadata from UC and wrap it in a callable object." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6f9f657c-dea4-412e-b5e3-a6fb8ca3354e", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.tools.uc_tool import UCTool\n", + "\n", + "# wrap the tool into a UCTool which can be passed to our Agent\n", + "translate_sku_tool = UCTool(uc_function_name=tool_uc_info.full_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "84efc805-cc4d-4a8d-afca-ff9b37d794eb", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Now, let's test the UC tool - the UCTool is a directly callable wrapper around the UC function, so it can be used just like a local function, but the output will be put into a dictionary with either the output in a 'value' key or an 'error' key if an error is raised.\n", + "\n", + "when an error happens, the UC tool will also return an instruction prompt to show the agent how to think about handling the error. 
this can be changed via the `error_prompt` parameter in the UCTool..\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "276fd791-ff41-4440-bb89-895e5778b64f", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# successful call\n", + "translate_sku_tool(old_sku=\"OLD-XXX-1234\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d1e1a4ce-2f4b-4222-a1fb-c9b2f7f95d23", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# unsuccessful call\n", + "translate_sku_tool(old_sku=\"OxxLD-XXX-1234\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "3b44f626-6dc9-4117-a85a-79120569f4d3", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "now, let's convert our pytests to work with the UC tool. this requires a bit of transformation to the test code to account for the fact that the output is in a dictionary & exceptions are not raised directly." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "74e2f618-ecb7-44f1-bdbd-a3badac81e14", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%%writefile tools/test_sample_tool_uc.py\n", + "import pytest\n", + "from cookbook.tools.uc_tool import UCTool\n", + "\n", + "CATALOG = \"shared\" # Change me!\n", + "SCHEMA = \"cookbook_langgraph_udhay\" # Change me if you want\n", + "\n", + "# Load the function from the UCTool versus locally\n", + "@pytest.fixture\n", + "def uc_tool():\n", + " \"\"\"Fixture to translate a UC tool into a local function.\"\"\"\n", + " UC_FUNCTION_NAME = f\"{CATALOG}.{SCHEMA}.sku_sample_translator\"\n", + " loaded_tool = UCTool(uc_function_name=UC_FUNCTION_NAME)\n", + " return loaded_tool\n", + "\n", + "\n", + "# Note: The value will be post processed into the `value` key, so we must check the returned value there.\n", + "def test_valid_sku_translation(uc_tool):\n", + " \"\"\"Test successful SKU translation with valid input.\"\"\"\n", + " assert uc_tool(old_sku=\"OLD-ABC-1234\")[\"value\"] == \"NEW-1234-ABC\"\n", + " assert uc_tool(old_sku=\"OLD-XYZ-0001\")[\"value\"] == \"NEW-0001-XYZ\"\n", + " assert (\n", + " uc_tool(old_sku=\"old-def-5678\")[\"value\"] == \"NEW-5678-DEF\"\n", + " ) # Test case insensitivity\n", + "\n", + "\n", + "# Note: The value will be post processed into the `value` key, so we must check the returned value there.\n", + "def test_whitespace_handling(uc_tool):\n", + " \"\"\"Test that the function handles extra whitespace correctly.\"\"\"\n", + " assert uc_tool(old_sku=\" OLD-ABC-1234 \")[\"value\"] == \"NEW-1234-ABC\"\n", + " assert uc_tool(old_sku=\"\\tOLD-ABC-1234\\n\")[\"value\"] == \"NEW-1234-ABC\"\n", + "\n", + "\n", + "# Note: the input validation happens BEFORE the function is called by Spark, so we will never get these exceptions from the function.\n", + "# Instead, we will get invalid 
parameters errors from Spark.\n", + "def test_invalid_input_type(uc_tool):\n", + " \"\"\"Test that non-string inputs are rejected by Spark's parameter validation.\"\"\"\n", + " assert (\n", + " uc_tool(old_sku=123)[\"error\"][\"error_message\"]\n", + " == \"\"\"Invalid parameters provided: {'old_sku': \"Parameter old_sku should be of type STRING (corresponding python type <class 'str'>), but got <class 'int'>\"}.\"\"\"\n", + " )\n", + " assert (\n", + " uc_tool(old_sku=None)[\"error\"][\"error_message\"]\n", + " == \"\"\"Invalid parameters provided: {'old_sku': \"Parameter old_sku should be of type STRING (corresponding python type <class 'str'>), but got <class 'NoneType'>\"}.\"\"\"\n", + " )\n", + "\n", + "\n", + "# Note: The errors will be post-processed into the `error_message` key inside the `error` top level key, so we must check for exceptions there.\n", + "def test_invalid_prefix(uc_tool):\n", + " \"\"\"Test that SKUs not starting with 'OLD-' raise ValueError.\"\"\"\n", + " assert (\n", + " uc_tool(old_sku=\"NEW-ABC-1234\")[\"error\"][\"error_message\"]\n", + " == \"ValueError: SKU must start with 'OLD-'\"\n", + " )\n", + " assert (\n", + " uc_tool(old_sku=\"XXX-ABC-1234\")[\"error\"][\"error_message\"]\n", + " == \"ValueError: SKU must start with 'OLD-'\"\n", + " )\n", + "\n", + "\n", + "# Note: The errors will be post-processed into the `error_message` key inside the `error` top level key, so we must check for exceptions there.\n", + "def test_invalid_format(uc_tool):\n", + " \"\"\"Test various invalid SKU formats.\"\"\"\n", + " invalid_skus = [\n", + " \"OLD-AB-1234\", # Too few letters\n", + " \"OLD-ABCD-1234\", # Too many letters\n", + " \"OLD-123-1234\", # Numbers instead of letters\n", + " \"OLD-ABC-123\", # Too few digits\n", + " \"OLD-ABC-12345\", # Too many digits\n", + " \"OLD-ABC-XXXX\", # Letters instead of numbers\n", + " \"OLD-A1C-1234\", # Mixed letters and numbers in middle\n", + " ]\n", + "\n", + " expected_error = \"ValueError: SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit\"\n", + " for sku in invalid_skus:\n", + " assert uc_tool(old_sku=sku)[\"error\"][\"error_message\"] == expected_error\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "ec50dbed-a6e0-4d6d-8bdd-8d6736aec694", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import pytest\n", + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if du.is_in_databricks_notebook():\n", + " import sys\n", + " sys.dont_write_bytecode = True # Skip writing .pyc files to the bytecode cache on the cluster.\n", + "\n", + "# Run tests from tools/test_sample_tool_uc.py\n", + "pytest.main([\"-v\", \"tools/test_sample_tool_uc.py\"])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "2a17e7b0-53b6-434f-8e17-f4d3678f7c28", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "# Now, here's another example of a tool that executes python code." 
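, + "\n\n", + "As with the SKU translator, once this tool is deployed to UC and wrapped in a `UCTool` (done later in this notebook), successful calls return the function's stdout under the `value` key and failures surface under `error.error_message`; the UC tool tests below assert exactly this contract. A minimal sketch of that calling pattern, assuming `python_exec_tool` has been created as shown further down:\n", + "\n", + "```python\n", + "result = python_exec_tool(code=\"print('hello')\")\n", + "result[\"value\"]  # \"hello\\n\" on success\n", + "\n", + "bad = python_exec_tool(code=\"x = 1 / 0\")\n", + "bad[\"error\"][\"error_message\"]  # contains \"ZeroDivisionError\"\n", + "```"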
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "cc16221b-7ad6-4cd5-84c9-702319b7c260", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%%writefile tools/code_exec.py\n", + "def python_exec(code: str) -> str:\n", + " \"\"\"\n", + " Executes Python code in the sandboxed environment and returns its stdout. The runtime is stateless and you can not read output of the previous tool executions. i.e. No such variables \"rows\", \"observation\" defined. Calling another tool inside a Python code is NOT allowed.\n", + " Use only standard python libraries and these python libraries: bleach, chardet, charset-normalizer, defusedxml, googleapis-common-protos, grpcio, grpcio-status, jmespath, joblib, numpy, packaging, pandas, patsy, protobuf, pyarrow, pyparsing, python-dateutil, pytz, scikit-learn, scipy, setuptools, six, threadpoolctl, webencodings, user-agents, cryptography.\n", + "\n", + " Args:\n", + " code (str): Python code to execute. Remember to print the final result to stdout.\n", + "\n", + " Returns:\n", + " str: The output of the executed code.\n", + " \"\"\"\n", + " import sys\n", + " from io import StringIO\n", + "\n", + " sys_stdout = sys.stdout\n", + " redirected_output = StringIO()\n", + " sys.stdout = redirected_output\n", + " try:\n", + " exec(code)\n", + " finally:\n", + " # Always restore stdout, even if the executed code raises an exception.\n", + " sys.stdout = sys_stdout\n", + " return redirected_output.getvalue()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5b583963-bb02-4b1d-947d-63ddc704b416", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from tools.code_exec import python_exec\n", + "\n", + "python_exec(\"print('hello')\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "43205daf-63a5-405e-86ee-26767c9291c3", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Test it locally" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "f9394e8f-218c-40c2-84af-085a87fdca28", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%%writefile tools/test_code_exec.py\n", + "\n", + "import pytest\n", + "from .code_exec import python_exec\n", + "\n", + "\n", + "def test_basic_arithmetic():\n", + " code = \"\"\"result = 2 + 2\\nprint(result)\"\"\"\n", + " assert python_exec(code).strip() == \"4\"\n", + "\n", + "\n", + "def test_multiple_lines():\n", + " code = \"x = 5\\n\" \"y = 3\\n\" \"result = x * y\\n\" \"print(result)\"\n", + " assert python_exec(code).strip() == \"15\"\n", + "\n", + "\n", + "def test_multiple_prints():\n", + " code = \"\"\"print('first')\\nprint('second')\\nprint('third')\\n\"\"\"\n", + " expected = \"first\\nsecond\\nthird\\n\"\n", + " assert python_exec(code) == expected\n", + "\n", + "\n", + "def test_using_pandas():\n", + " code = (\n", + " \"import pandas as pd\\n\"\n", + " \"data = {'col1': [1, 2], 'col2': [3, 4]}\\n\"\n", + " \"df = 
pd.DataFrame(data)\\n\"\n", + " \"print(df.shape)\"\n", + " )\n", + " assert python_exec(code).strip() == \"(2, 2)\"\n", + "\n", + "\n", + "def test_using_numpy():\n", + " code = \"import numpy as np\\n\" \"arr = np.array([1, 2, 3])\\n\" \"print(arr.mean())\"\n", + " assert python_exec(code).strip() == \"2.0\"\n", + "\n", + "\n", + "def test_syntax_error():\n", + " code = \"if True\\n\" \" print('invalid syntax')\"\n", + " with pytest.raises(SyntaxError):\n", + " python_exec(code)\n", + "\n", + "\n", + "def test_runtime_error():\n", + " code = \"x = 1 / 0\\n\" \"print(x)\"\n", + " with pytest.raises(ZeroDivisionError):\n", + " python_exec(code)\n", + "\n", + "\n", + "def test_undefined_variable():\n", + " code = \"print(undefined_variable)\"\n", + " with pytest.raises(NameError):\n", + " python_exec(code)\n", + "\n", + "\n", + "def test_multiline_string_manipulation():\n", + " code = \"text = '''\\n\" \"Hello\\n\" \"World\\n\" \"'''\\n\" \"print(text.strip())\"\n", + " expected = \"Hello\\nWorld\"\n", + " assert python_exec(code).strip() == expected\n", + "\n", + "# Will not fail locally, but will fail in UC.\n", + "# def test_unauthorized_flask():\n", + "# code = \"from flask import Flask\\n\" \"app = Flask(__name__)\\n\" \"print(app)\"\n", + "# with pytest.raises(ImportError):\n", + "# python_exec(code)\n", + "\n", + "\n", + "def test_no_print_statement():\n", + " code = \"x = 42\\n\" \"y = x * 2\"\n", + " assert python_exec(code) == \"\"\n", + "\n", + "\n", + "def test_calculation_without_print():\n", + " code = \"result = sum([1, 2, 3, 4, 5])\\n\" \"squared = [x**2 for x in range(5)]\"\n", + " assert python_exec(code) == \"\"\n", + "\n", + "\n", + "def test_function_definition_without_call():\n", + " code = \"def add(a, b):\\n\" \" return a + b\\n\" \"result = add(3, 4)\"\n", + " assert python_exec(code) == \"\"\n", + "\n", + "\n", + "def test_class_definition_without_instantiation():\n", + " code = (\n", + " \"class Calculator:\\n\"\n", + " \" def add(self, a, b):\\n\"\n", + " \" return a + b\\n\"\n", + " \"calc = Calculator()\"\n", + " )\n", + " assert python_exec(code) == \"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "681450a6-84a6-4a76-b734-403e569de7b3", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import pytest\n", + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if du.is_in_databricks_notebook():\n", + " import sys\n", + " sys.dont_write_bytecode = True # Skip writing .pyc files to the bytecode cache on the cluster.\n", + "\n", + "# Run tests from test_code_exec.py\n", + "pytest.main([\"-v\", \"tools/test_code_exec.py\"])\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "10a2bece-4e96-4201-b3f8-42054875f10f", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Deploy to UC" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "27f1ab60-0d37-447f-adf1-40ba52794398", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from 
unitycatalog.ai.core.databricks import DatabricksFunctionClient\n", + "from tools.code_exec import python_exec\n", + "from cookbook.tools.uc_tool import UCTool\n", + "\n", + "client = DatabricksFunctionClient()\n", + "CATALOG = \"shared\" # Change me!\n", + "SCHEMA = \"cookbook_langgraph_udhay\" # Change me if you want\n", + "\n", + "# this will deploy the tool to UC, automatically setting the metadata in UC based on the tool's docstring & typing hints\n", + "python_exec_tool_uc_info = client.create_python_function(func=python_exec, catalog=CATALOG, schema=SCHEMA, replace=True)\n", + "\n", + "# the tool will deploy to a function in UC called `{catalog}.{schema}.{func}` where {func} is the name of the function\n", + "# Print the deployed Unity Catalog function name\n", + "print(f\"Deployed Unity Catalog function name: {python_exec_tool_uc_info.full_name}\")\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "db12518d-8c11-4bc8-8d47-9996d54adc9a", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Test as UC Tool for the Agent" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "cb292ee7-b2ff-43d2-9fa1-fb84767970bf", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.tools.uc_tool import UCTool\n", + "\n", + "\n", + "# wrap the tool into a UCTool which can be passed to our Agent\n", + "python_exec_tool = UCTool(uc_function_name=python_exec_tool_uc_info.full_name)\n", + "\n", + "python_exec_tool(code=\"print('hello')\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "9210666e-1849-4885-aa1a-8e0b912f57f0", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "New tests" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "85df2d06-6978-48ba-b9de-143f02374206", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%%writefile tools/test_code_exec_as_uc_tool.py\n", + "\n", + "import pytest\n", + "from cookbook.tools.uc_tool import UCTool\n", + "\n", + "CATALOG = \"shared\" # Change me!\n", + "SCHEMA = \"cookbook_langgraph_udhay\" # Change me if you want\n", + "\n", + "\n", + "@pytest.fixture\n", + "def python_exec():\n", + " \"\"\"Fixture to provide the python_exec function from UCTool.\"\"\"\n", + " python_exec_tool = UCTool(uc_function_name=f\"{CATALOG}.{SCHEMA}.python_exec\")\n", + " return python_exec_tool\n", + "\n", + "\n", + "def test_basic_arithmetic(python_exec):\n", + " code = \"\"\"result = 2 + 2\\nprint(result)\"\"\"\n", + " assert python_exec(code=code)[\"value\"].strip() == \"4\"\n", + "\n", + "\n", + "def test_multiple_lines(python_exec):\n", + " code = \"x = 5\\n\" \"y = 3\\n\" \"result = x * y\\n\" \"print(result)\"\n", + " assert python_exec(code=code)[\"value\"].strip() == \"15\"\n", + "\n", + "\n", + "def test_multiple_prints(python_exec):\n", + " code = \"\"\"print('first')\\nprint('second')\\nprint('third')\\n\"\"\"\n", 
+ " expected = \"first\\nsecond\\nthird\\n\"\n", + " assert python_exec(code=code)[\"value\"] == expected\n", + "\n", + "\n", + "def test_using_pandas(python_exec):\n", + " code = (\n", + " \"import pandas as pd\\n\"\n", + " \"data = {'col1': [1, 2], 'col2': [3, 4]}\\n\"\n", + " \"df = pd.DataFrame(data)\\n\"\n", + " \"print(df.shape)\"\n", + " )\n", + " assert python_exec(code=code)[\"value\"].strip() == \"(2, 2)\"\n", + "\n", + "\n", + "def test_using_numpy(python_exec):\n", + " code = \"import numpy as np\\n\" \"arr = np.array([1, 2, 3])\\n\" \"print(arr.mean())\"\n", + " assert python_exec(code=code)[\"value\"].strip() == \"2.0\"\n", + "\n", + "\n", + "def test_syntax_error(python_exec):\n", + " code = \"if True\\n\" \" print('invalid syntax')\"\n", + " result = python_exec(code=code)\n", + " assert \"Syntax error at or near 'invalid'.\" in result[\"error\"][\"error_message\"]\n", + "\n", + "\n", + "def test_runtime_error(python_exec):\n", + " code = \"x = 1 / 0\\n\" \"print(x)\"\n", + " result = python_exec(code=code)\n", + " assert \"ZeroDivisionError\" in result[\"error\"][\"error_message\"]\n", + "\n", + "\n", + "def test_undefined_variable(python_exec):\n", + " code = \"print(undefined_variable)\"\n", + " result = python_exec(code=code)\n", + " assert \"NameError\" in result[\"error\"][\"error_message\"]\n", + "\n", + "\n", + "def test_multiline_string_manipulation(python_exec):\n", + " code = \"text = '''\\n\" \"Hello\\n\" \"World\\n\" \"'''\\n\" \"print(text.strip())\"\n", + " expected = \"Hello\\nWorld\"\n", + " assert python_exec(code=code)[\"value\"].strip() == expected\n", + "\n", + "\n", + "def test_unauthorized_flask(python_exec):\n", + " code = \"from flask import Flask\\n\" \"app = Flask(__name__)\\n\" \"print(app)\"\n", + " result = python_exec(code=code)\n", + " assert (\n", + " \"ModuleNotFoundError: No module named 'flask'\"\n", + " in result[\"error\"][\"error_message\"]\n", + " )\n", + "\n", + "\n", + "def test_no_print_statement(python_exec):\n", + " code = \"x = 42\\n\" \"y = x * 2\"\n", + " assert python_exec(code=code)[\"value\"] == \"\"\n", + "\n", + "\n", + "def test_calculation_without_print(python_exec):\n", + " code = \"result = sum([1, 2, 3, 4, 5])\\n\" \"squared = [x**2 for x in range(5)]\"\n", + " assert python_exec(code=code)[\"value\"] == \"\"\n", + "\n", + "\n", + "def test_function_definition_without_call(python_exec):\n", + " code = \"def add(a, b):\\n\" \" return a + b\\n\" \"result = add(3, 4)\"\n", + " assert python_exec(code=code)[\"value\"] == \"\"\n", + "\n", + "\n", + "def test_class_definition_without_instantiation(python_exec):\n", + " code = (\n", + " \"class Calculator:\\n\"\n", + " \" def add(self, a, b):\\n\"\n", + " \" return a + b\\n\"\n", + " \"calc = Calculator()\"\n", + " )\n", + " assert python_exec(code=code)[\"value\"] == \"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3571724d-9a65-4743-8f73-2c94c2d8334d", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import pytest\n", + "\n", + "if du.is_in_databricks_notebook():\n", + " import sys\n", + " sys.dont_write_bytecode = True # Skip writing .pyc files to the bytecode cache on the cluster.\n", + "\n", + "# Run tests from test_code_exec_as_uc_tool.py\n", + "pytest.main([\"-v\", \"tools/test_code_exec_as_uc_tool.py\"])\n", + "\n" + ] + } + ], + 
"metadata": { + "application/vnd.databricks.v1+notebook": { + "computePreferences": null, + "dashboards": [], + "environmentMetadata": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "03_create_tools", + "widgets": {} + }, + "kernelspec": { + "display_name": "genai-cookbook-T2SdtsNM-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/langgraph_agent_app_sample_code/04_tool_calling_agent.ipynb b/langgraph_agent_app_sample_code/04_tool_calling_agent.ipynb new file mode 100644 index 0000000..c55a23e --- /dev/null +++ b/langgraph_agent_app_sample_code/04_tool_calling_agent.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "31661828-f9bb-4fc2-a1bd-94424a27ed52", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "## 👉 START HERE: How to use this notebook\n", + "\n", + "# Step 3: Build, evaluate, & deploy your Agent\n", + "\n", + "Use this notebook to iterate on the code and configuration of your Agent.\n", + "\n", + "By the end of this notebook, you will have 1+ registered versions of your Agent, each coupled with a detailed quality evaluation.\n", + "\n", + "Optionally, you can deploy a version of your Agent that you can interact with in the [Mosiac AI Playground](https://docs.databricks.com/en/large-language-models/ai-playground.html) and let your business stakeholders who don't have Databricks accounts interact with it & provide feedback in the [Review App](https://docs.databricks.com/en/generative-ai/agent-evaluation/human-evaluation.html#review-app-ui).\n", + "\n", + "\n", + "For each version of your agent, you will have an MLflow run inside your MLflow experiment that contains:\n", + "- Your Agent's code & config\n", + "- Evaluation metrics for cost, quality, and latency" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5d9f685a-fdb7-49a4-9e3a-a4a9e964d045", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "\n", + "**Important note:** Throughout this notebook, we indicate which cell's code you:\n", + "- ✅✏️ should customize - these cells contain code & config with business logic that you should edit to meet your requirements & tune quality.\n", + "- 🚫✏️ should not customize - these cells contain boilerplate code required to load/save/execute your Agent\n", + "\n", + "*Cells that don't require customization still need to be run! You CAN change these cells, but if this is the first time using this notebook, we suggest not doing so.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "bb4f8cc0-1797-4beb-a9f2-df21a9db79f0", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Install Python libraries\n", + "\n", + "You do not need to modify this cell unless you need additional Python packages in your Agent." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6d4030e8-ae97-4351-bebd-9651d283578f", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install -qqqq -U -r requirements.txt\n", + "# Restart to load the packages into the Python environment\n", + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "9ffbdc10-6cc3-4174-9314-16f9da64f2ab", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Connect to Databricks\n", + "\n", + "If running locally in an IDE using Databricks Connect, connect the Spark client & configure MLflow to use Databricks Managed MLflow. If this is running in a Databricks Notebook, these values are already set." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "ec8db50c-9a1c-4c0d-8fe9-953d8c3c13ed", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from mlflow.utils import databricks_utils as du\n", + "\n", + "if not du.is_in_databricks_notebook():\n", + " from databricks.connect import DatabricksSession\n", + " import os\n", + "\n", + " spark = DatabricksSession.builder.getOrCreate()\n", + " os.environ[\"MLFLOW_TRACKING_URI\"] = \"databricks\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a5f3a4b3-8d55-4d2a-8f96-7bdf9480f98d", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Load the Agent's UC storage locations; set up MLflow experiment\n", + "\n", + "This notebook uses the UC model, MLflow Experiment, and Evaluation Set that you specified in the [Agent setup](02_agent_setup.ipynb) notebook." 
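, + "\n\n", + "For reference, `./configs/agent_storage_config.yaml` is a small YAML file with three keys; the values below are illustrative placeholders:\n", + "\n", + "```yaml\n", + "evaluation_set_uc_table: <catalog>.<schema>.my_agent_eval_set\n", + "mlflow_experiment_name: /Users/<you>@<company>.com/my_agent_mlflow_experiment\n", + "uc_model_name: <catalog>.<schema>.my_agent\n", + "```"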
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "56766f86-7862-45f0-bd1d-c947a1611dbe", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import yaml\n", + "from pathlib import Path\n", + "import mlflow\n", + "from box import Box\n", + "from cookbook.databricks_utils import get_table_url, get_mlflow_experiment_url\n", + "\n", + "# Load the Agent's storage configuration\n", + "agent_storage_config = Box(yaml.safe_load(Path(\"./configs/agent_storage_config.yaml\").read_text()))\n", + "print(agent_storage_config)\n", + "\n", + "# set the MLflow experiment\n", + "experiment_info = mlflow.set_experiment(agent_storage_config.mlflow_experiment_name)\n", + "# If running in a local IDE, set the MLflow experiment name as an environment variable\n", + "os.environ[\"MLFLOW_EXPERIMENT_NAME\"] = agent_storage_config.mlflow_experiment_name\n", + "\n", + "print(f\"View the MLflow Experiment `{agent_storage_config.mlflow_experiment_name}` at {get_mlflow_experiment_url(experiment_info.experiment_id)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "48c580a0-105e-4a6a-b18d-c3210deff6d5", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "### 🚫✏️ Helper method to log the Agent's code & config to MLflow\n", + "\n", + "Before we start, let's define a helper method that logs the Agent's code & config to MLflow and the Unity Catalog. It is used in evaluation, for deploying to Agent Evaluation's [Review App](https://docs.databricks.com/en/generative-ai/agent-evaluation/human-evaluation.html#review-app-ui) (a chat UI for your stakeholders to test this agent), and later, for deploying the Agent to production." 
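, + "\n\n", + "A minimal usage sketch (the evaluation cell below calls it the same way):\n", + "\n", + "```python\n", + "with mlflow.start_run():\n", + "    logged_agent_info = log_function_calling_agent_to_mlflow(agent_conf)\n", + "\n", + "# The returned URI can be passed to mlflow.evaluate(model=...) or mlflow.register_model(...)\n", + "print(logged_agent_info.model_uri)\n", + "```"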
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "dcbd54c7-61e0-480d-9b96-f579151b063c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import mlflow\n", + "from mlflow.types.llm import CHAT_MODEL_INPUT_SCHEMA\n", + "from mlflow.models.rag_signatures import StringResponse, ChatCompletionResponse\n", + "from cookbook.agents.utils.signatures import STRING_RESPONSE_WITH_MESSAGES\n", + "from mlflow.models.signature import ModelSignature\n", + "\n", + "\n", + "# This helper will log the Agent's code & config to an MLflow run and return the logged model's URI\n", + "# If run from inside an mlflow.start_run() block, it will log to that run, otherwise it will log to a new run.\n", + "# This logged Agent is ready for deployment, so if you are happy with your evaluation, it is ready to deploy!\n", + "def log_function_calling_agent_to_mlflow(agent_config):\n", + " from cookbook.agents.function_calling_agent import get_resource_dependencies\n", + "\n", + " # Get the agent's code path from the imported Agent class\n", + " agent_code_path = f\"{os.getcwd()}/cookbook/agents/function_calling_agent.py\"\n", + "\n", + " # Get the pip requirements from the requirements.txt file\n", + " with open(\"requirements.txt\", \"r\") as file:\n", + " pip_requirements = [line.strip() for line in file.readlines()] + [\n", + " \"pyspark\"\n", + " ] # manually add pyspark\n", + "\n", + " logged_agent_info = mlflow.langchain.log_model(\n", + " agent_code_path,\n", + " artifact_path=\"agent\",\n", + " input_example=agent_config.input_example,\n", + " model_config=agent_config.to_dict(),\n", + " resources=get_resource_dependencies(\n", + " agent_config\n", + " ), # This allows the agents.deploy() command to securely provision credentials for the Agent's databricks resources e.g., vector index, model serving endpoints, etc\n", + " signature=ModelSignature(\n", + " inputs=CHAT_MODEL_INPUT_SCHEMA,\n", + " # outputs=STRING_RESPONSE_WITH_MESSAGES #TODO: replace with MLflow signature\n", + " outputs=ChatCompletionResponse(),\n", + " ),\n", + " code_paths=[os.path.join(os.getcwd(), \"cookbook\")],\n", + " pip_requirements=pip_requirements,\n", + " )\n", + "\n", + " return logged_agent_info" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "9933d05f-29fa-452e-abdc-2a02328fbe22", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "\n", + "## 1️⃣ Iterate on the Agent's code & config to improve quality\n", + "\n", + "The cells below are used to execute your inner dev loop to improve the Agent's quality.\n", + "\n", + "We suggest the following process:\n", + "1. Vibe check the Agent for 5 - 10 queries to verify it works\n", + "2. Make any necessary changes to the code/config\n", + "3. Use Agent Evaluation to evaluate the Agent using your evaluation set, which will provide a quality assessment & identify the root causes of any quality issues\n", + "4. Based on that evaluation, make & test changes to the code/config to improve quality\n", + "5. 🔁 Repeat steps 3 and 4 until you are satisfied with the Agent's quality\n", + "6. 
Deploy the Agent to Agent Evaluation's [Review App](https://docs.databricks.com/en/generative-ai/agent-evaluation/human-evaluation.html#review-app-ui) for pre-production testing\n", + "7. Use the following notebooks to review that feedback (optionally adding new records to your evaluation set) & identify any further quality issues\n", + "8. 🔁 Repeat steps 3 and 4 to fix any issues identified in step 7\n", + "9. Deploy the Agent to a production-ready REST API endpoint (using the same cells in this notebook as step 6)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "908aedb9-8258-4ae6-9571-ba8523344db6", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### ✅✏️ Optionally, adjust the Agent's code\n", + "\n", + "Here, we import the Agent's code so we can run the Agent locally within the notebook. To modify the code, open the Agent's code file in a separate window, enable reload, make your changes, and re-run this cell.\n", + "\n", + "**Typically, when building the first version of your agent, we suggest first trying to tune the configuration (prompts, etc) to improve quality. If you need more control to fix quality issues, you can then modify the Agent's code.**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7bcafb24-6c6d-4176-8b0e-beb5e4377b61", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from cookbook.agents.function_calling_agent import create_function_calling_agent\n", + "import inspect\n", + "\n", + "# Print the Agent code for inspection\n", + "print(inspect.getsource(create_function_calling_agent))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5ce51671-314c-48ef-929d-c9e10f19e75c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 3" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6113559c-2fd7-4d1b-9514-b0c4be60f3d1", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### ✅✏️ 🅰 Vibe check the Agent for a single query\n", + "\n", + "Running this cell will produce an MLflow Trace that you can use to see the Agent's outputs and understand the steps it took to produce that output.\n", + "\n", + "If you are running in a local IDE, browse to the MLflow Experiment page to view the Trace (link to the Experiment UI is at the top of this notebook). If running in a Databricks Notebook, your trace will appear inline below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7d371b6b-2f46-4c47-afcb-679c7134a1fa", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import yaml\n", + "from pathlib import Path\n", + "from box import Box\n", + "from cookbook.databricks_utils import get_mlflow_experiment_traces_url\n", + "from cookbook.agents.function_calling_agent import create_function_calling_agent\n", + "\n", + "agent_conf = Box(yaml.safe_load(Path(\"configs/function_calling_agent_config.yaml\").read_text()))\n", + "\n", + "print(agent_conf)\n", + "\n", + "# Load the Agent's code with the above configuration\n", + "agent = create_function_calling_agent(agent_config=agent_conf)\n", + "\n", + "# Vibe check the Agent for a single query\n", + "output = agent.invoke(input={\"messages\": [{\"role\": \"user\", \"content\": \"How does the blender work?\"}]})\n", + "# output = agent.predict(model_input={\"messages\": [{\"role\": \"user\", \"content\": \"Translate the sku `OLD-abs-1234` to the new format\"}]})\n", + "\n", + "print(f\"View the MLflow Traces at {get_mlflow_experiment_traces_url(experiment_info.experiment_id)}\")\n", + "print(f\"Agent's final response:\\n----\\n{output['choices'][-1]['message']['content']}\\n----\")\n", + "print()\n", + "# print(f\"Agent's full message history (useful for debugging):\\n----\\n{json.dumps(output['messages'], indent=2)}\\n----\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "896a841c-aa15-43e0-af80-84fa2da5b02c", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Now, let's test a multi-turn conversation with the Agent." 
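, + "\n\n", + "To continue a conversation, the next request must carry the history: the assistant's prior reply (available as `output['choices'][-1]['message']`) followed by the new user message. A minimal sketch of the shape the next cell constructs:\n", + "\n", + "```python\n", + "second_turn = {\n", + "    \"messages\": [\n", + "        output[\"choices\"][-1][\"message\"],  # assistant's reply from the first turn\n", + "        {\"role\": \"user\", \"content\": \"How do I turn it on?\"},\n", + "    ]\n", + "}\n", + "```"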
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3ec8fac0-a3b0-4c2e-8339-2507d1d665e1", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "second_turn = {'messages': [output['choices'][-1]['message']] + [{\"role\": \"user\", \"content\": \"How do I turn it on?\"}]}\n", + "\n", + "# Run the Agent again with the conversation history plus the new user message\n", + "second_turn_output = agent.invoke(input=second_turn)\n", + "\n", + "print(f\"View the MLflow Traces at {get_mlflow_experiment_traces_url(experiment_info.experiment_id)}\")\n", + "print(f\"Agent's final response:\\n----\\n{second_turn_output['choices'][-1]['message']['content']}\\n----\")\n", + "print()\n", + "#print(f\"Agent's full message history (useful for debugging):\\n----\\n{json.dumps(second_turn_output['messages'], indent=2)}\\n----\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "c1fdead4-9366-478c-bfd6-452c4df994e3", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "#### ✅✏️ 🅱 Evaluate the Agent using your evaluation set\n", + "\n", + "Note: If you do not have an evaluation set, you can create a synthetic evaluation set by using the 03_synthetic_evaluation notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "96c94766-d832-4493-b5cc-30cfcfecdbce", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "evaluation_set = spark.table(agent_storage_config.evaluation_set_uc_table)\n", + "\n", + "with mlflow.start_run():\n", + " logged_agent_info = log_function_calling_agent_to_mlflow(agent_conf)\n", + "\n", + " # Run the agent for these queries, using Agent evaluation to parallelize the calls\n", + " eval_results = mlflow.evaluate(\n", + " model=logged_agent_info.model_uri, # use the MLflow logged Agent\n", + " data=evaluation_set, # Evaluate the Agent for every row of the evaluation set\n", + " model_type=\"databricks-agent\", # use Agent Evaluation\n", + " )\n", + "\n", + " # Show all outputs. Click on a row in this table to display the MLflow Trace.\n", + " display(eval_results.tables[\"eval_results\"])\n", + "\n", + " # Click 'View Evaluation Results' to see the Agent's inputs/outputs + quality evaluation displayed in a UI" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "92d5719a-b5ed-44a9-a758-f571e3fd0f65", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "## 2️⃣ Deploy a version of your Agent - either to the Review App or Production\n", + "\n", + "Once you have a version of your Agent that has sufficient quality, you will register the Agent's model from the MLflow Experiment into the Unity Catalog & use Agent Framework's `agents.deploy(...)` command to deploy it. 
Note these steps are the same for deploying to pre-production (e.g., the [Review App](https://docs.databricks.com/en/generative-ai/agent-evaluation/human-evaluation.html#review-app-ui)) or production.\n", + "\n", + "By the end of this step, you will have deployed a version of your Agent that you can interact with and share with your business stakeholders for feedback, even if they don't have access to your Databricks workspace:\n", + "\n", + "1. A production-ready scalable REST API deployed as a Model Serving endpoint that logs every request/response/MLflow Trace to a Delta Table.\n", + " - REST API for querying the Agent\n", + " - REST API for sending user feedback from your UI to the Agent\n", + "2. Agent Evaluation's [Review App](https://docs.databricks.com/en/generative-ai/agent-evaluation/human-evaluation.html#review-app-ui) connected to these endpoints.\n", + "3. [Mosaic AI Playground](https://docs.databricks.com/en/large-language-models/ai-playground.html) connected to these endpoints." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "469d1dbb-0a55-4a6b-9a08-46650129614e", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "source": [ + "Option 1: Deploy the last agent you logged above" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "497da3db-7394-41ce-beff-a26a89d78bab", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "from databricks import agents\n", + "\n", + "# Use Unity Catalog as the model registry\n", + "mlflow.set_registry_uri(\"databricks-uc\")\n", + "\n", + "# Register the Agent's model to the Unity Catalog\n", + "uc_registered_model_info = mlflow.register_model(\n", + " model_uri=logged_agent_info.model_uri, name=agent_storage_config.uc_model_name\n", + ")\n", + "\n", + "# Deploy the model to the review app and a model serving endpoint\n", + "agents.deploy(agent_storage_config.uc_model_name, uc_registered_model_info.version)" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "computePreferences": null, + "dashboards": [], + "environmentMetadata": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "04_tool_calling_agent", + "widgets": {} + }, + "kernelspec": { + "display_name": "genai-cookbook-T2SdtsNM-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/langgraph_agent_app_sample_code/README.md b/langgraph_agent_app_sample_code/README.md new file mode 100644 index 0000000..7c64448 --- /dev/null +++ b/langgraph_agent_app_sample_code/README.md @@ -0,0 +1,27 @@ +# How to use a local IDE + +- make sure the `DEFAULT` databricks auth profile is set up: +``` +databricks auth login +``` +- add a cluster_id in ~/.databrickscfg (if you want to use Spark code) +- add `langgraph_agent_app_sample_code/.env` to point to the MLflow experiment + Databricks tracking URI (if you want to run any agent code from the terminal and have it logged to MLflow). Make sure this MLflow experiment maps to the one in 02_agent_setup.ipynb. 
+``` +MLFLOW_TRACKING_URI=databricks +MLFLOW_EXPERIMENT_NAME=/Users/your.name@company.com/my_agent_mlflow_experiment +``` +- install poetry env & activate in your IDE +``` +poetry install +``` + +if you want to use the data pipeline code in spark, you need to build the cookbook wheel and install it in the cluster +- build cookbook wheel +``` +poetry build +``` +- install cookbook wheel in cluster + - Copy the wheel file to a UC Volume or Workspace folder + - Go to the cluster's Libraries page and install the wheel file as a new library + + diff --git a/langgraph_agent_app_sample_code/__init__.py b/langgraph_agent_app_sample_code/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/configs/README.md b/langgraph_agent_app_sample_code/configs/README.md new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/configs/agent_storage_config.yaml b/langgraph_agent_app_sample_code/configs/agent_storage_config.yaml new file mode 100644 index 0000000..9bef039 --- /dev/null +++ b/langgraph_agent_app_sample_code/configs/agent_storage_config.yaml @@ -0,0 +1,3 @@ +evaluation_set_uc_table: shared.cookbook_langgraph_udhay.my_agent_2_eval_set +mlflow_experiment_name: /Users/udhayaraj.sivalingam@databricks.com/my_agent_2_mlflow_experiment +uc_model_name: shared.cookbook_langgraph_udhay.my_agent_2 \ No newline at end of file diff --git a/langgraph_agent_app_sample_code/configs/data_pipeline_config.yaml b/langgraph_agent_app_sample_code/configs/data_pipeline_config.yaml new file mode 100644 index 0000000..16c3bf3 --- /dev/null +++ b/langgraph_agent_app_sample_code/configs/data_pipeline_config.yaml @@ -0,0 +1,15 @@ +# Choose a Unity Catalog Volume containing PDF, HTML, etc documents to be parsed/chunked/embedded. +chunking_config: + chunk_overlap_tokens: 256 + chunk_size_tokens: 1024 + embedding_model_endpoint: databricks-gte-large-en +output: + chunked_docs_table: test_product_docs_docs_chunked__v2 + parsed_docs_table: test_product_docs_docs__v2 + vector_index: test_product_docs_docs_chunked_index__v2 + vector_search_endpoint: one-env-shared-endpoint-15 +source: + uc_catalog_name: shared + uc_schema_name: cookbook_langgraph_udhay + uc_volume_name: product_docs + volume_path: /Volumes/shared/cookbook_langgraph_udhay/product_docs diff --git a/langgraph_agent_app_sample_code/configs/function_calling_agent_config.yaml b/langgraph_agent_app_sample_code/configs/function_calling_agent_config.yaml new file mode 100644 index 0000000..7e4f982 --- /dev/null +++ b/langgraph_agent_app_sample_code/configs/function_calling_agent_config.yaml @@ -0,0 +1,67 @@ +class_path: cookbook.config.agents.function_calling_agent.FunctionCallingAgentConfig +input_example: + messages: + - content: What can you help me with? + role: user +llm_config: + type: llm + llm_endpoint_name: agents-demo-gpt4o-mini + llm_parameters: + max_tokens: 1500 + temperature: 0.01 + llm_system_prompt_template: "## Role\nYou are a helpful assistant that answers questions\ + \ using a set of tools. If needed, you ask the user follow-up questions to clarify\ + \ their request.\n\n## Objective\nYour goal is to provide accurate, relevant,\ + \ and helpful response based solely on the outputs from these tools. You are concise\ + \ and direct in your responses.\n\n## Instructions\n1. **Understand the Query**:\ + \ Think step by step to analyze the user's question and determine the core need\ + \ or problem. \n\n2. 
**Assess available tools**: Think step by step to consider\ + \ each available tool and understand their capabilities in the context of the\ + \ user's query.\n\n3. **Select the appropriate tool(s) OR ask follow up questions**:\ + \ Based on your understanding of the query and the tool descriptions, decide which\ + \ tool(s) should be used to generate a response. If you do not have enough information\ + \ to use the available tools to answer the question, ask the user follow up questions\ + \ to refine their request. If you do not have a relevant tool for a question\ + \ or the outputs of the tools are not helpful, respond with: \"I'm sorry, I can't\ + \ help you with that.\"" +tool_configs: +- class_path: cookbook.tools.vector_search.VectorSearchRetrieverTool + type: vector_search + description: Use this tool to search for product documentation. + doc_similarity_threshold: 0.0 + filterable_columns: [] + name: search_product_docs + retriever_filter_parameter_prompt: optional filters to apply to the search. An array + of objects, each specifying a field name and the filters to apply to that field. + retriever_query_parameter_prompt: query to look up in retriever + vector_search_index: shared.cookbook_langgraph_udhay.test_product_docs_docs_chunked_index__v2 + vector_search_endpoint: one-env-shared-endpoint-15 + embedding_endpoint_name: databricks-gte-large-en + vector_search_parameters: + num_results: 5 + query_type: ann + vector_search_schema: + additional_metadata_columns: [] + chunk_text: content_chunked + document_uri: doc_uri + id_column: chunk_id +- class_path: cookbook.tools.uc_tool.UCTool + type: uc_function + error_prompt: 'The tool call generated an Exception, detailed in `error`. Think + step-by-step following these instructions to determine your next step. + + [1] Is the error due to a problem with the input parameters? + + [2] Could it succeed if retried with exactly the same inputs? + + [3] Could it succeed if retried with modified parameters using the input we already + have from the user? + + [4] Could it succeed if retried with modified parameters informed by collecting + additional input from the user? What specific input would we need from the user? + + Based on your thinking, if the error is due to a problem with the input parameters, + either call this tool again in a way that avoids this exception or collect additional + information from the user to modify the inputs to avoid this exception.' + function_name: shared.cookbook_local_test_udhay.sku_sample_translator + diff --git a/langgraph_agent_app_sample_code/cookbook/__init__.py b/langgraph_agent_app_sample_code/cookbook/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/cookbook/agents/__init__.py b/langgraph_agent_app_sample_code/cookbook/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/cookbook/agents/function_calling_agent.py b/langgraph_agent_app_sample_code/cookbook/agents/function_calling_agent.py new file mode 100644 index 0000000..dda16f8 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/agents/function_calling_agent.py @@ -0,0 +1,252 @@ +# In this file, we construct a function-calling Agent with a Retriever tool using MLflow + langgraph. 
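+# High-level flow: load the YAML agent config, build the tools (a Vector Search retriever and/or UC functions), +# create a LangGraph ReAct agent with those tools, and convert its message stream into +# ChatCompletionResponse dicts via wrap_output so MLflow serving can return them.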
+import yaml +from pathlib import Path +import logging +import json +from dataclasses import asdict +from functools import reduce +from typing import Iterator, Dict, List, Optional, Union, Any + +from databricks_langchain import ChatDatabricks +from databricks_langchain import DatabricksVectorSearch +from langchain_core.messages import ( + AIMessage, + MessageLikeRepresentation, +) +from langchain_core.runnables import RunnableGenerator +from langchain_core.runnables.base import RunnableSequence +from langchain_core.tools import tool, Tool +from langgraph.prebuilt import create_react_agent +from mlflow.models import set_model +from mlflow.models.rag_signatures import ( + ChatCompletionResponse, + ChainCompletionChoice, + Message, +) +from langchain_core.messages import ( + AIMessage, + HumanMessage, + ToolMessage, + MessageLikeRepresentation, +) +from mlflow.models.resources import ( + DatabricksResource, + DatabricksServingEndpoint, + DatabricksVectorSearchIndex, + DatabricksFunction, +) +from pydantic import BaseModel +from unitycatalog.ai.core.databricks import DatabricksFunctionClient +from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit +from cookbook.agents.utils.load_config import load_config + +logging.basicConfig(level=logging.INFO) + + +FC_AGENT_DEFAULT_YAML_CONFIG_FILE_NAME = "function_calling_agent_config.yaml" + + +def create_tool(tool_config) -> Tool: + if tool_config.type == "vector_search": + vector_search_as_retriever = DatabricksVectorSearch( + endpoint=tool_config.vector_search_endpoint, + index_name=tool_config.vector_search_index, + columns=[tool_config.vector_search_schema.chunk_text, tool_config.vector_search_schema.id_column, tool_config.vector_search_schema.document_uri], + ).as_retriever(search_kwargs=tool_config.vector_search_parameters) + + @tool + def search_product_docs(question: str): + """Use this tool to search for databricks product documentation.""" + relevant_docs = vector_search_as_retriever.get_relevant_documents(question) + chunk_template = "Passage: {chunk_text}\n" + chunk_contents = [ + chunk_template.format( + chunk_text=doc.page_content, + ) + for doc in relevant_docs + ] + return "".join(chunk_contents) + + return search_product_docs + elif tool_config.type == "uc_function": + client = DatabricksFunctionClient() + toolkit = UCFunctionToolkit( + client=client, function_names=[tool_config.function_name] + ) + return toolkit.tools[-1] + else: + raise ValueError(f"Unknown tool type: {tool_config.type}") + +def create_chat_completion_response(message: Message) -> Dict: + return asdict(ChatCompletionResponse( + choices=[ChainCompletionChoice(message=message)], + )) + +def stringify_tool_call(tool_call: Dict[str, Any]) -> str: + """Convert a raw tool call into a formatted string that the playground UI expects""" + try: + request = json.dumps( + { + "id": tool_call.get("id"), + "name": tool_call.get("name"), + "arguments": str(tool_call.get("args", {})), + }, + indent=2, + ) + return f"{request}" + except: + # for non UC functions, return the string representation of tool calls + # you can modify this to return a different format + return str(tool_call) + + +def stringify_tool_result(tool_msg: ToolMessage) -> str: + """Convert a ToolMessage into a formatted string that the playground UI expects""" + try: + result = json.dumps( + {"id": tool_msg.tool_call_id, "content": tool_msg.content}, indent=2 + ) + return f"{result}" + except: + # for non UC functions, return the string representation of tool message + # you can modify this to return a 
different format + return str(tool_msg) + + +def parse_message(msg) -> Message: + """Parse different message types into their string representations""" + # tool call result + if isinstance(msg, ToolMessage): + return Message(role="tool", content=stringify_tool_result(msg)) + # tool call + elif isinstance(msg, AIMessage) and msg.tool_calls: + tool_call_results = [stringify_tool_call(call) for call in msg.tool_calls] + return Message(role="system", content="".join(tool_call_results)) + # normal HumanMessage or AIMessage (reasoning or final answer) + elif isinstance(msg, AIMessage): + return Message(role="system", content=msg.content) + elif isinstance(msg, HumanMessage): + return Message(role="user", content=msg.content) + else: + print(f"Unexpected message type: {type(msg)}") + return Message(role="unknown", content=str(msg)) + + +def wrap_output(stream: Iterator[MessageLikeRepresentation]) -> Iterator[Dict]: + """ + Process and yield formatted outputs from the message stream. + The invoke and stream langchain functions produce different output formats. + This function handles both cases. + """ + for event in stream: + # the agent was called with invoke() + if "messages" in event: + messages = event["messages"] + yield create_chat_completion_response(parse_message(messages[-1])) + # the agent was called with stream() + else: + for node in event: + for key, messages in event[node].items(): + if isinstance(messages, list): + for msg in messages: + yield create_chat_completion_response(parse_message(msg)) + else: + print( + f"Unexpected value {messages} for key {key}. Expected a list of `MessageLikeRepresentation`'s" + ) + yield create_chat_completion_response(Message(content=str(messages))) + + +def create_resource_dependency(config: BaseModel) -> List[DatabricksResource]: + if config.type == "llm": + return [DatabricksServingEndpoint(endpoint_name=config.llm_endpoint_name)] + elif config.type == "vector_search": + return [ + DatabricksVectorSearchIndex(index_name=config.vector_search_index), + DatabricksServingEndpoint(endpoint_name=config.embedding_endpoint_name), + ] + elif config.type == "uc_function": + return [DatabricksFunction(function_name=config.function_name)] + else: + raise ValueError(f"Unknown config type: {type(config)}") + + +def get_resource_dependencies( + agent_config, +) -> List[DatabricksResource]: + configs = [agent_config.llm_config] + agent_config.tool_configs + dependencies = reduce(lambda x, y: x + y, map(create_resource_dependency, configs)) + return dependencies + + +def create_function_calling_agent( + agent_config +) -> RunnableSequence: + if not agent_config: + raise ValueError( + f"No agent config found. If you are in your local development environment, make sure you either [1] are calling init(agent_config=...) with either a loaded agent config or the full path to a YAML config file or [2] have a YAML config file saved at {{your_project_root_folder}}/configs/{FC_AGENT_DEFAULT_YAML_CONFIG_FILE_NAME}." 
+def create_function_calling_agent(
+    agent_config
+) -> RunnableSequence:
+    if not agent_config:
+        raise ValueError(
+            f"No agent config found. If you are in your local development environment, make sure you either [1] are calling init(agent_config=...) with either a loaded agent config or the full path to a YAML config file or [2] have a YAML config file saved at {{your_project_root_folder}}/configs/{FC_AGENT_DEFAULT_YAML_CONFIG_FILE_NAME}."
+        )
+
+    tool_configs = agent_config.tool_configs
+    tools = list(map(create_tool, tool_configs))
+
+    chat_model = ChatDatabricks(
+        endpoint=agent_config.llm_config.llm_endpoint_name,
+        temperature=agent_config.llm_config.llm_parameters.temperature,
+        max_tokens=agent_config.llm_config.llm_parameters.max_tokens,
+    )
+
+    react_agent = create_react_agent(
+        chat_model,
+        tools,
+        messages_modifier=agent_config.llm_config.llm_system_prompt_template,
+    ) | RunnableGenerator(wrap_output)
+
+    logging.info("Successfully loaded agent config in __init__.")
+
+    return react_agent
+
+
+agent_conf = load_config()
+
+print(agent_conf)
+# tell MLflow logging where to find the agent's code
+set_model(create_function_calling_agent(agent_conf))
+
+# IMPORTANT: set this to False before logging the model to MLflow
+debug = False
+
+if debug:
+    agent = create_function_calling_agent(agent_conf)
+
+    vibe_check_query = {
+        "messages": [
+            # {"role": "user", "content": "what is agent evaluation?"},
+            {
+                "role": "user",
+                "content": "Translate the sku `OLD-abs-1234` to the new format",
+            },
+        ]
+    }
+
+    # The agent chain ends in a generator, so stream it and print each formatted chunk
+    for chunk in agent.stream(vibe_check_query):
+        print(chunk)
\ No newline at end of file
diff --git a/langgraph_agent_app_sample_code/cookbook/agents/utils/__init__.py b/langgraph_agent_app_sample_code/cookbook/agents/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/langgraph_agent_app_sample_code/cookbook/agents/utils/chat.py b/langgraph_agent_app_sample_code/cookbook/agents/utils/chat.py
new file mode 100644
index 0000000..a817c02
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/agents/utils/chat.py
@@ -0,0 +1,145 @@
+import mlflow
+from typing import Dict, List, Union
+from dataclasses import asdict
+import pandas as pd
+from mlflow.models.rag_signatures import ChatCompletionRequest, Message
+
+
+@mlflow.trace(span_type="PARSER")
+def get_messages_array(
+    model_input: Union[ChatCompletionRequest, Dict, pd.DataFrame]
+) -> List[Dict[str, str]]:
+    if isinstance(model_input, ChatCompletionRequest):
+        return model_input.messages
+    elif isinstance(model_input, dict):
+        return model_input.get("messages")
+    elif isinstance(model_input, pd.DataFrame):
+        return model_input.iloc[0].to_dict().get("messages")
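+# Illustrative example (not part of the pipeline): all three input shapes normalize
+# to the same list of message dicts. Values below are placeholders.
+#
+#   get_messages_array({"messages": [{"role": "user", "content": "hi"}]})
+#   # -> [{"role": "user", "content": "hi"}]
+#   get_messages_array(pd.DataFrame([{"messages": [{"role": "user", "content": "hi"}]}]))
+#   # -> [{"role": "user", "content": "hi"}]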
+@mlflow.trace(span_type="PARSER")
+def extract_user_query_string(chat_messages_array: List[Dict[str, str]]) -> str:
+    """
+    Extracts the user query string from the chat messages array.
+
+    Args:
+        chat_messages_array: Array of chat messages.
+
+    Returns:
+        User query string.
+    """
+
+    if isinstance(chat_messages_array, pd.Series):
+        chat_messages_array = chat_messages_array.tolist()
+
+    if isinstance(chat_messages_array[-1], dict):
+        return chat_messages_array[-1]["content"]
+    elif isinstance(chat_messages_array[-1], Message):
+        return chat_messages_array[-1].content
+    else:
+        return chat_messages_array[-1]
+
+
+@mlflow.trace(span_type="PARSER")
+def extract_chat_history(
+    chat_messages_array: List[Dict[str, str]]
+) -> List[Dict[str, str]]:
+    """
+    Extracts the chat history from the chat messages array.
+
+    Args:
+        chat_messages_array: Array of chat messages.
+
+    Returns:
+        The chat history: all messages except the last one.
+    """
+    # Convert a Pandas Series to a list
+    if isinstance(chat_messages_array, pd.Series):
+        chat_messages_array = chat_messages_array.tolist()
+
+    # Dictionaries: return as-is
+    if isinstance(chat_messages_array[0], dict):
+        return chat_messages_array[:-1]  # return all messages except the last one
+    # MLflow Messages: convert to dictionaries
+    elif isinstance(chat_messages_array[0], Message):
+        new_array = []
+        for message in chat_messages_array[:-1]:
+            new_array.append(asdict(message))
+        return new_array
+    else:
+        raise ValueError(
+            "chat_messages_array is not an array of dictionaries, a Pandas Series, or an array of MLflow Messages."
+        )
+
+
+@mlflow.trace(span_type="PARSER")
+def convert_messages_to_open_ai_format(
+    chat_messages_array: List[Dict[str, str]]
+) -> List[Dict[str, str]]:
+    """
+    Converts the chat messages array into the OpenAI messages format (a list of dictionaries).
+
+    Args:
+        chat_messages_array: Array of chat messages.
+
+    Returns:
+        The full message array in OpenAI format.
+    """
+    # Convert a Pandas Series to a list
+    if isinstance(chat_messages_array, pd.Series):
+        chat_messages_array = chat_messages_array.tolist()
+
+    # Dictionaries: return as-is
+    if isinstance(chat_messages_array[0], dict):
+        return chat_messages_array
+    # MLflow Messages: convert to dictionaries
+    elif isinstance(chat_messages_array[0], Message):
+        new_array = []
+        for message in chat_messages_array:
+            new_array.append(asdict(message))
+        return new_array
+    else:
+        raise ValueError(
+            "chat_messages_array is not an array of dictionaries, a Pandas Series, or an array of MLflow Messages."
+        )
+
+
+@mlflow.trace(span_type="PARSER")
+def concat_messages_array_to_string(messages):
+    concatenated_message = "\n".join(
+        [
+            (
+                f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}"
+                if message.get("role") in ("assistant", "user")
+                else ""
+            )
+            for message in messages
+        ]
+    )
+    return concatenated_message
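+# Illustrative example (not part of the pipeline), showing the flattened transcript
+# produced by `concat_messages_array_to_string`:
+#
+#   concat_messages_array_to_string([
+#       {"role": "user", "content": "what is agent evaluation?"},
+#       {"role": "assistant", "content": "Agent evaluation is ..."},
+#   ])
+#   # -> "user: what is agent evaluation?\nassistant: Agent evaluation is ..."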
+def remove_message_keys_with_null_values(message: Dict[str, str]) -> Dict[str, str]:
+    """
+    Remove any keys with None/null values from the message.
+    Having a null value for a key breaks DBX model serving input validation even if
+    that key is marked as optional in the schema, so we remove them.
+    Example: the `refusal` key is set to None by OpenAI.
+    """
+    return {k: v for k, v in message.items() if v is not None}
+
+
+@mlflow.trace(span_type="PARSER")
+def remove_tool_calls_from_messages(
+    messages: List[Dict[str, str]]
+) -> List[Dict[str, str]]:
+    modified_messages = messages.copy()
+    return [
+        msg
+        for msg in modified_messages
+        if not (
+            msg.get("role") == "tool"  # Remove tool messages
+            or (
+                msg.get("role") == "assistant" and "tool_calls" in msg
+            )  # Remove assistant messages with tool_calls
+        )
+    ]
diff --git a/langgraph_agent_app_sample_code/cookbook/agents/utils/execute_function.py b/langgraph_agent_app_sample_code/cookbook/agents/utils/execute_function.py
new file mode 100644
index 0000000..1d0a7df
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/agents/utils/execute_function.py
@@ -0,0 +1,8 @@
+import mlflow
+import json
+
+
+@mlflow.trace(span_type="FUNCTION")
+def execute_function(tool, args):
+    result = tool(**args)
+    return json.dumps(result)
diff --git a/langgraph_agent_app_sample_code/cookbook/agents/utils/load_config.py b/langgraph_agent_app_sample_code/cookbook/agents/utils/load_config.py
new file mode 100644
index 0000000..7758eb8
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/agents/utils/load_config.py
@@ -0,0 +1,30 @@
+import logging
+import yaml
+import mlflow
+from pathlib import Path
+import os.path as path
+from box import Box
+
+
+def load_config():
+    """Load the function calling Agent's configuration, first from the repo's
+    `configs/` folder and, failing that, from mlflow.models.ModelConfig()."""
+    try:
+        fc_agent_config_path = path.abspath(
+            path.join(
+                __file__, "../../../../configs/function_calling_agent_config.yaml"
+            )
+        )
+        logging.info(f"Trying to load config from {fc_agent_config_path}")
+        agent_conf = Box(yaml.safe_load(Path(fc_agent_config_path).read_text()))
+        return agent_conf
+    except FileNotFoundError:
+        return load_config_from_mlflow_model_config()
+
+
+def load_config_from_mlflow_model_config():
+    try:
+        logging.info("Trying to load config from mlflow.models.ModelConfig()")
+        loaded_config = Box(mlflow.models.ModelConfig()._read_config())
+        logging.info(f"Loaded config from mlflow.models.ModelConfig(): {loaded_config}")
+        return loaded_config
+    except FileNotFoundError as e:
+        logging.info(f"Could not load config from mlflow.models.ModelConfig(): {e}")
+        return None
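+# Illustrative sketch (not part of the pipeline) of the YAML shape this loader
+# expects. The exact keys come from your own configs/function_calling_agent_config.yaml,
+# so treat every name and value below as a placeholder:
+#
+#   llm_config:
+#     llm_endpoint_name: my-chat-endpoint
+#     llm_system_prompt_template: "You are a helpful assistant ..."
+#     llm_parameters:
+#       temperature: 0.0
+#       max_tokens: 1500
+#   tool_configs:
+#     - type: vector_search
+#       ...
+#     - type: uc_function
+#       function_name: main.default.my_function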
diff --git a/langgraph_agent_app_sample_code/cookbook/agents/utils/playground_parser.py b/langgraph_agent_app_sample_code/cookbook/agents/utils/playground_parser.py
new file mode 100644
index 0000000..20800cc
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/agents/utils/playground_parser.py
@@ -0,0 +1,98 @@
+import mlflow
+from typing import List, Dict
+import json
+
+##
+# Utility functions for formatting OpenAI tool calls and responses for display in the
+# Databricks playground and review applications. These functions convert the raw message
+# format into a more readable, XML-tagged format suitable for UI rendering.
+##
+
+
+@mlflow.trace(span_type="PARSER")
+def convert_messages_to_playground_tool_display_strings(
+    messages: List[Dict[str, str]]
+) -> str:
+    """Format a list of OpenAI chat messages for display in the Databricks playground/review UI.
+
+    Processes a sequence of OpenAI chat messages, with special handling for tool calls
+    and their responses. Tool-related content is wrapped in XML-like tags for proper
+    UI rendering and readability.
+
+    Args:
+        messages (List[Dict[str, str]]): List of OpenAI message dictionaries containing role
+            (user/assistant/tool), content, and optional tool_calls from the chat completion API.
+
+    Returns:
+        str: UI-friendly string with tool calls wrapped in `<uc_function_call>` tags and
+            tool responses wrapped in `<uc_function_result>` tags.
+    """
+    output = ""
+    for msg in messages:
+        if msg["role"] == "assistant" and msg.get("tool_calls"):  # tool call
+            for tool_call in msg["tool_calls"]:
+                output += stringify_tool_call(tool_call)
+        elif msg["role"] == "tool":  # tool response
+            output += stringify_tool_result(msg)
+        else:
+            output += msg["content"] if msg["content"] is not None else ""
+    return output
+
+
+@mlflow.trace(span_type="PARSER")
+def stringify_tool_call(tool_call) -> str:
+    """Format an OpenAI tool call for display in the Databricks playground/review UI.
+
+    Extracts relevant information from an OpenAI tool call and formats it into a
+    UI-friendly string wrapped in XML-like tags for proper rendering.
+
+    Args:
+        tool_call (dict): OpenAI tool call dictionary containing function details
+            (name, arguments) and call ID from the chat completion API.
+
+    Returns:
+        str: UI-friendly string wrapped in `<uc_function_call>` tags, containing the
+            tool's name, ID, and arguments in a structured format.
+    """
+    try:
+        function = tool_call["function"]
+        args_dict = json.loads(function["arguments"])
+        request = {
+            "id": tool_call["id"],
+            "name": function["name"],
+            "arguments": json.dumps(args_dict),
+        }
+
+        return f"<uc_function_call>{json.dumps(request)}</uc_function_call>"
+
+    except Exception as e:
+        print("Failed to stringify tool call: ", e)
+        return str(tool_call)
+
+
+@mlflow.trace(span_type="PARSER")
+def stringify_tool_result(tool_msg) -> str:
+    """Format an OpenAI tool response for display in the Databricks playground/review UI.
+
+    Processes a tool's response message and formats it into a UI-friendly string
+    wrapped in XML-like tags for proper rendering.
+
+    Args:
+        tool_msg (dict): OpenAI tool response dictionary containing the tool_call_id
+            and response content from the chat completion API.
+
+    Returns:
+        str: UI-friendly string wrapped in `<uc_function_result>` tags, containing the
+            tool's response ID and content.
+    """
+    try:
+        result = json.dumps(
+            {"id": tool_msg["tool_call_id"], "content": tool_msg["content"]}
+        )
+        return f"<uc_function_result>{result}</uc_function_result>"
+    except Exception as e:
+        print("Failed to stringify tool result:", e)
+        return str(tool_msg)
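+# Illustrative example (not part of the pipeline). Note the `<uc_function_call>` /
+# `<uc_function_result>` tag names above were reconstructed from the docstrings, so
+# verify them against what your playground UI renders:
+#
+#   stringify_tool_call({
+#       "id": "call_1",
+#       "function": {"name": "main.default.my_function", "arguments": '{"x": 1}'},
+#   })
+#   # -> '<uc_function_call>{"id": "call_1", "name": "main.default.my_function", ...}</uc_function_call>'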
diff --git a/langgraph_agent_app_sample_code/cookbook/agents/utils/signatures.py b/langgraph_agent_app_sample_code/cookbook/agents/utils/signatures.py
new file mode 100644
index 0000000..4c0f5e7
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/agents/utils/signatures.py
@@ -0,0 +1,49 @@
+from mlflow.types.schema import Array, ColSpec, DataType, Object, Property, Schema
+
+# This is a custom version of the StringResponse class from Databricks Agents
+# that includes the `messages` field.
+# StringResponse: from mlflow.models.rag_signatures import StringResponse
+
+STRING_RESPONSE_WITH_MESSAGES = Schema(
+    [
+        ColSpec(name="content", type=DataType.string),
+        ColSpec(
+            name="messages",
+            type=Array(
+                Object(
+                    [
+                        Property("role", DataType.string),
+                        Property("content", DataType.string, False),
+                        Property("name", DataType.string, False),
+                        Property("refusal", DataType.string, False),
+                        Property(
+                            "tool_calls",
+                            Array(
+                                Object(
+                                    [
+                                        Property("id", DataType.string),
+                                        Property(
+                                            "function",
+                                            Object(
+                                                [
+                                                    Property("name", DataType.string),
+                                                    Property(
+                                                        "arguments", DataType.string
+                                                    ),
+                                                ]
+                                            ),
+                                        ),
+                                        Property("type", DataType.string),
+                                    ]
+                                )
+                            ),
+                            False,
+                        ),
+                        Property("tool_call_id", DataType.string, False),
+                    ]
+                ),
+            ),
+            required=False,
+        ),
+    ]
+)
diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/__init__.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/build_retriever_index.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/build_retriever_index.py
new file mode 100644
index 0000000..e1e80c4
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/data_pipeline/build_retriever_index.py
@@ -0,0 +1,123 @@
+from databricks.sdk.service.vectorsearch import (
+    DeltaSyncVectorIndexSpecRequest,
+    EmbeddingSourceColumn,
+    PipelineType,
+    VectorIndexType,
+)
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors.platform import ResourceDoesNotExist, BadRequest
+import time
+from cookbook.databricks_utils import get_table_url
+
+
+# %md
+# ##### `build_retriever_index`
+
+# `build_retriever_index` builds the vector search index which is used by our RAG to retrieve relevant documents.
+
+# Arguments:
+# - `vector_search_endpoint`: The name of the vector search endpoint to create the index on.
+# - `chunked_docs_table_name`: The chunked documents table. It is expected to have a `content_chunked` column, a `chunk_id` column, and a `doc_uri` column.
+# - `vector_search_index_name`: The name of the index to create or sync.
+# - `embedding_endpoint_name`: The name of the Model Serving endpoint that computes embeddings.
+# - `force_delete_index_before_create`: Setting this to true will delete and rebuild the vector search index.
+# - `primary_key`: The column to use for the vector index primary key.
+# - `embedding_source_column`: The column to compute embeddings for in the vector index.
+
+
+def build_retriever_index(
+    vector_search_endpoint: str,
+    chunked_docs_table_name: str,
+    vector_search_index_name: str,
+    embedding_endpoint_name: str,
+    force_delete_index_before_create=False,
+    primary_key: str = "chunk_id",  # hard coded in apply_chunking_fn
+    embedding_source_column: str = "content_chunked",  # hard coded in apply_chunking_fn
+) -> tuple[bool, str]:
+    # Initialize the workspace client and vector search API
+    w = WorkspaceClient()
+    vsc = w.vector_search_indexes
+
+    def find_index(index_name):
+        try:
+            return vsc.get_index(index_name=index_name)
+        except ResourceDoesNotExist:
+            return None
+
+    def wait_for_index_to_be_ready(index):
+        while not index.status.ready:
+            print(
+                f"Index {vector_search_index_name} exists, but is not ready, waiting 30 seconds..."
+ ) + time.sleep(30) + index = find_index(index_name=vector_search_index_name) + + def wait_for_index_to_be_deleted(index): + while index: + print( + f"Waiting for index {vector_search_index_name} to be deleted, waiting 30 seconds..." + ) + time.sleep(30) + index = find_index(index_name=vector_search_index_name) + + existing_index = find_index(index_name=vector_search_index_name) + if existing_index: + print(f"Found existing index {get_table_url(vector_search_index_name)}...") + if force_delete_index_before_create: + print(f"Deleting index {vector_search_index_name}...") + vsc.delete_index(index_name=vector_search_index_name) + wait_for_index_to_be_deleted(existing_index) + create_index = True + else: + wait_for_index_to_be_ready(existing_index) + create_index = False + print( + f"Starting the sync of index {vector_search_index_name}, this can take 15 minutes or much longer if you have a larger number of documents." + ) + # print(existing_index) + try: + vsc.sync_index(index_name=vector_search_index_name) + msg = f"Kicked off index sync for {vector_search_index_name}." + return (False, msg) + except BadRequest as e: + msg = f"Index sync already in progress, so failed to kick off index sync for {vector_search_index_name}. Please wait for the index to finish syncing and try again." + return (True, msg) + else: + print( + f'Creating new vector search index "{vector_search_index_name}" on endpoint "{vector_search_endpoint}"' + ) + create_index = True + + if create_index: + print( + "Computing document embeddings and Vector Search Index. This can take 15 minutes or much longer if you have a larger number of documents." + ) + try: + # Create delta sync index spec using the proper class + delta_sync_spec = DeltaSyncVectorIndexSpecRequest( + source_table=chunked_docs_table_name, + pipeline_type=PipelineType.TRIGGERED, + embedding_source_columns=[ + EmbeddingSourceColumn( + name=embedding_source_column, + embedding_model_endpoint_name=embedding_endpoint_name, + ) + ], + ) + + vsc.create_index( + name=vector_search_index_name, + endpoint_name=vector_search_endpoint, + primary_key=primary_key, + index_type=VectorIndexType.DELTA_SYNC, + delta_sync_index_spec=delta_sync_spec, + ) + msg = ( + f"Successfully created vector search index {vector_search_index_name}." + ) + print(msg) + return (False, msg) + except Exception as e: + msg = f"Vector search index creation failed. Wait 5 minutes and try running this cell again." 
+ return (True, msg) diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/chunk_docs.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/chunk_docs.py new file mode 100644 index 0000000..793a721 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/data_pipeline/chunk_docs.py @@ -0,0 +1,44 @@ +from typing import Literal, Optional, Any, Callable +from databricks.vector_search.client import VectorSearchClient +from pyspark.sql.functions import explode +import pyspark.sql.functions as func +from typing import Callable +from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType +from pyspark.sql import DataFrame, SparkSession + + +def apply_chunking_fn( + parsed_docs_df: DataFrame, + chunking_fn: Callable[[str], list[str]], + propagate_columns: list[str], + doc_column: str = "content", +) -> DataFrame: + # imports here to avoid requiring these libraries in all notebooks since the data pipeline config imports this package + from langchain_text_splitters import RecursiveCharacterTextSplitter + from transformers import AutoTokenizer + import tiktoken + + print( + f"Applying chunking UDF to {parsed_docs_df.count()} documents using Spark - this may take a long time if you have many documents..." + ) + + parser_udf = func.udf( + chunking_fn, returnType=ArrayType(StringType()), useArrow=True + ) + chunked_array_docs = parsed_docs_df.withColumn( + "content_chunked", parser_udf(doc_column) + ) # .drop(doc_column) + chunked_docs = chunked_array_docs.select( + *propagate_columns, explode("content_chunked").alias("content_chunked") + ) + + # Add a primary key: "chunk_id". + chunks_with_ids = chunked_docs.withColumn( + "chunk_id", func.md5(func.col("content_chunked")) + ) + # Reorder for better display. + chunks_with_ids = chunks_with_ids.select( + "chunk_id", "content_chunked", *propagate_columns + ) + + return chunks_with_ids diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/default_parser.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/default_parser.py new file mode 100644 index 0000000..277fdc1 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/data_pipeline/default_parser.py @@ -0,0 +1,162 @@ +from typing import TypedDict +from datetime import datetime +import warnings +import traceback +import os +from urllib.parse import urlparse + +# PDF libraries +import fitz +import pymupdf4llm + +# HTML libraries +import markdownify +import re + +## DOCX libraries +import pypandoc +import tempfile + +## JSON libraries +import json + + +# Schema of the dict returned by `file_parser(...)` +# This is used to create the output Delta Table's schema. +# Adjust the class if you want to add additional columns from your parser, such as extracting custom metadata. +class ParserReturnValue(TypedDict): + # DO NOT CHANGE THESE NAMES + # Parsed content of the document + content: str # do not change this name + # The status of whether the parser succeeds or fails, used to exclude failed files downstream + parser_status: str # do not change this name + # Unique ID of the document + doc_uri: str # do not change this name + + # OK TO CHANGE THESE NAMES + # Optionally, you can add additional metadata fields here + # example_metadata: str + last_modified: datetime + + +# Parser function. Adjust this function to modify the parsing logic. 
+def file_parser(
+    raw_doc_contents_bytes: bytes,
+    doc_path: str,
+    modification_time: datetime,
+    doc_bytes_length: int,
+) -> ParserReturnValue:
+    """
+    Parses the content of a document into a string.
+
+    This function takes the raw bytes of a document and its path, parses the document with a
+    format-specific library (PyMuPDF for PDFs, markdownify for HTML, pypandoc for DOCX),
+    and returns the parsed content and the status of the parsing operation.
+
+    Parameters:
+    - raw_doc_contents_bytes (bytes): The raw bytes of the document to be parsed (set by Spark when loading the file)
+    - doc_path (str): The DBFS path of the document, used to verify the file extension (set by Spark when loading the file)
+    - modification_time (timestamp): The last modification time of the document (set by Spark when loading the file)
+    - doc_bytes_length (long): The size of the document in bytes (set by Spark when loading the file)
+
+    Returns:
+    - ParserReturnValue: A dictionary containing the parsed document content and the status of the parsing operation.
+      The 'content' key will contain the parsed text as a string, and the 'parser_status' key will indicate
+      whether the parsing was successful or if an error occurred.
+    """
+    try:
+        from markdownify import markdownify as md
+
+        filename, file_extension = os.path.splitext(doc_path)
+
+        if file_extension == ".pdf":
+            pdf_doc = fitz.Document(stream=raw_doc_contents_bytes, filetype="pdf")
+            md_text = pymupdf4llm.to_markdown(pdf_doc)
+
+            parsed_document = {
+                "content": md_text.strip(),
+                "parser_status": "SUCCESS",
+            }
+        elif file_extension == ".html":
+            html_content = raw_doc_contents_bytes.decode("utf-8")
+
+            markdown_contents = md(
+                str(html_content).strip(), heading_style=markdownify.ATX
+            )
+            markdown_stripped = re.sub(r"\n{3,}", "\n\n", markdown_contents.strip())
+
+            parsed_document = {
+                "content": markdown_stripped,
+                "parser_status": "SUCCESS",
+            }
+        elif file_extension == ".docx":
+            with tempfile.NamedTemporaryFile(delete=True) as temp_file:
+                temp_file.write(raw_doc_contents_bytes)
+                temp_file_path = temp_file.name
+                docx_markdown = pypandoc.convert_file(
+                    temp_file_path, "markdown", format="docx"
+                )
+
+                parsed_document = {
+                    "content": docx_markdown.strip(),
+                    "parser_status": "SUCCESS",
+                }
+        elif file_extension in [".txt", ".md"]:
+            parsed_document = {
+                "content": raw_doc_contents_bytes.decode("utf-8").strip(),
+                "parser_status": "SUCCESS",
+            }
+        elif file_extension in [".json", ".jsonl"]:
+            # NOTE: This is a placeholder for a JSON parser. It's not a "real" parser; it just
+            # flattens the raw JSON into XML-like strings that LLMs tend to like.
+            json_data = json.loads(raw_doc_contents_bytes.decode("utf-8"))
+
+            def flatten_json_to_xml(obj, parent_key=""):
+                xml_parts = []
+                if isinstance(obj, dict):
+                    for key, value in obj.items():
+                        if isinstance(value, (dict, list)):
+                            xml_parts.append(flatten_json_to_xml(value, key))
+                        else:
+                            xml_parts.append(f"<{key}>{str(value)}</{key}>")
+                elif isinstance(obj, list):
+                    for i, item in enumerate(obj):
+                        if isinstance(item, (dict, list)):
+                            xml_parts.append(
+                                flatten_json_to_xml(item, f"{parent_key}_{i}")
+                            )
+                        else:
+                            xml_parts.append(
+                                f"<{parent_key}_{i}>{str(item)}</{parent_key}_{i}>"
+                            )
+                else:
+                    xml_parts.append(f"<{parent_key}>{str(obj)}</{parent_key}>")
+                return "\n".join(xml_parts)
+
+            flattened_content = flatten_json_to_xml(json_data)
+            parsed_document = {
+                "content": flattened_content.strip(),
+                "parser_status": "SUCCESS",
+            }
+        else:
+            raise Exception(f"No supported parser for {doc_path}")
+
+        # Extract the required doc_uri
+        # convert from `dbfs:/Volumes/catalog/schema/pdf_docs/filename.pdf` to `/Volumes/catalog/schema/pdf_docs/filename.pdf`
+        modified_path = urlparse(doc_path).path
+        parsed_document["doc_uri"] = modified_path
+
+        # Sample metadata extraction logic
+        # if "test" in parsed_document["content"]:
+        #     parsed_document["example_metadata"] = "test"
+        # else:
+        #     parsed_document["example_metadata"] = "not test"
+
+        # Add the modified time
+        parsed_document["last_modified"] = modification_time
+
+        return parsed_document
+
+    except Exception as e:
+        status = f"An error occurred: {e}\n{traceback.format_exc()}"
+        warnings.warn(status)
+        return {
+            "content": "",
+            "parser_status": f"ERROR: {status}",
+        }
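+# Illustrative example (not part of the pipeline) of the dict returned for a markdown
+# file; the path and timestamp are placeholders:
+#
+#   file_parser(
+#       raw_doc_contents_bytes=b"# Hello",
+#       doc_path="dbfs:/Volumes/main/default/docs/hello.md",
+#       modification_time=datetime.now(),
+#       doc_bytes_length=7,
+#   )
+#   # -> {"content": "# Hello", "parser_status": "SUCCESS",
+#   #     "doc_uri": "/Volumes/main/default/docs/hello.md", "last_modified": ...}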
diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/parse_docs.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/parse_docs.py
new file mode 100644
index 0000000..182de01
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/data_pipeline/parse_docs.py
@@ -0,0 +1,159 @@
+import traceback
+from datetime import datetime
+from typing import Any, Callable, Dict
+import warnings
+import pyspark.sql.functions as func
+from pyspark.sql.types import StructType
+from pyspark.sql import DataFrame, SparkSession
+
+
+def _parse_and_extract(
+    raw_doc_contents_bytes: bytes,
+    modification_time: datetime,
+    doc_bytes_length: int,
+    doc_path: str,
+    parse_file_udf: Callable[..., Dict[str, Any]],
+) -> Dict[str, Any]:
+    """Parses raw bytes & extracts metadata."""
+    try:
+        # Run the parser
+        parser_output_dict = parse_file_udf(
+            raw_doc_contents_bytes=raw_doc_contents_bytes,
+            doc_path=doc_path,
+            modification_time=modification_time,
+            doc_bytes_length=doc_bytes_length,
+        )
+
+        if parser_output_dict.get("parser_status") == "SUCCESS":
+            return parser_output_dict
+        else:
+            raise Exception(parser_output_dict.get("parser_status"))
+
+    except Exception as e:
+        status = f"An error occurred: {e}\n{traceback.format_exc()}"
+        warnings.warn(status)
+        return {
+            "content": "",
+            "doc_uri": doc_path,
+            "parser_status": status,
+        }
+
+
+def _get_parser_udf(
+    parse_file_udf: Callable[..., Dict[str, Any]],
+    spark_dataframe_schema: StructType,
+):
+    """Gets the Spark UDF which will parse the files in parallel.
+
+    Arguments:
+    - parse_file_udf: A function that takes the raw file and returns the parsed text.
+    - spark_dataframe_schema: The resulting schema of the document delta table
+    """
+    # This UDF will load each file, parse the doc, and extract metadata.
+    parser_udf = func.udf(
+        lambda raw_doc_contents_bytes, modification_time, doc_bytes_length, doc_path: _parse_and_extract(
+            raw_doc_contents_bytes,
+            modification_time,
+            doc_bytes_length,
+            doc_path,
+            parse_file_udf,
+        ),
+        returnType=spark_dataframe_schema,
+        useArrow=True,
+    )
+    return parser_udf
+
+
+def load_files_to_df(spark: SparkSession, source_path: str) -> DataFrame:
+    """
+    Load files from a directory into a Spark DataFrame.
+    Each row in the DataFrame will contain the path, length, and content of the file; for more
+    details, see https://spark.apache.org/docs/latest/sql-data-sources-binaryFile.html
+    """
+
+    print(f"Loading the raw files from {source_path}...")
+    # Load the raw files
+    raw_files_df = (
+        spark.read.format("binaryFile")
+        .option("recursiveFileLookup", "true")
+        .load(source_path)
+    )
+
+    # Check that files were present and loaded
+    if raw_files_df.count() == 0:
+        raise Exception(f"`{source_path}` does not contain any files.")
+
+    return raw_files_df
+
+
+def apply_parsing_fn(
+    raw_files_df: DataFrame,
+    parse_file_fn: Callable[..., Dict[str, Any]],
+    parsed_df_schema: StructType,
+) -> DataFrame:
+    """
+    Apply a file-parsing UDF to a DataFrame whose rows correspond to file content/metadata loaded via
+    https://spark.apache.org/docs/latest/sql-data-sources-binaryFile.html
+    Returns a DataFrame with the parsed content and metadata.
+    """
+    print(
+        f"Applying parsing & metadata extraction to {raw_files_df.count()} files using Spark - this may take a long time if you have many documents..."
+    )
+
+    parser_udf = _get_parser_udf(parse_file_fn, parsed_df_schema)
+
+    # Run the parsing
+    parsed_files_staging_df = raw_files_df.withColumn(
+        "parsing", parser_udf("content", "modificationTime", "length", "path")
+    ).drop("content")
+
+    # Keep all rows here; failed files are surfaced by `check_parsed_df_for_errors`
+    parsed_files_df = parsed_files_staging_df
+
+    # Change the schema to the resulting schema
+    resulting_fields = [field.name for field in parsed_df_schema.fields]
+
+    parsed_files_df = parsed_files_df.select(
+        *[func.col(f"parsing.{field}").alias(field) for field in resulting_fields]
+    )
+    return parsed_files_df
+
+
+def check_parsed_df_for_errors(parsed_files_df) -> tuple[bool, str, DataFrame]:
+    # Check for and warn on any errors
+    errors_df = parsed_files_df.filter(func.col("parser_status") != "SUCCESS")
+
+    num_errors = errors_df.count()
+    if num_errors > 0:
+        msg = f"{num_errors} documents ({round(num_errors / parsed_files_df.count() * 100, 2)}%) had parse errors. Please review."
+        return (True, msg, errors_df)
+    else:
+        msg = "All documents were parsed."
+        print(msg)
+        return (False, msg, None)
+
+
+def check_parsed_df_for_empty_parsed_files(parsed_files_df):
+    # Check for and warn on any empty parsing results
+    num_empty_df = parsed_files_df.filter(
+        func.col("parser_status") == "SUCCESS"
+    ).filter(func.col("content") == "")
+
+    num_errors = num_empty_df.count()
+    if num_errors > 0:
+        msg = f"{num_errors} documents ({round(num_errors / parsed_files_df.count() * 100, 2)}%) returned empty parsing results. Please review."
+        return (True, msg, num_empty_df)
+    else:
+        msg = "All documents produced non-empty parsing results."
+        print(msg)
+        return (False, msg, None)
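+# Illustrative usage sketch (not part of the pipeline): a typical driver notebook
+# chains these helpers together; the volume path and schema name are placeholders.
+#
+#   raw_files_df = load_files_to_df(spark, "/Volumes/main/default/source_docs")
+#   parsed_files_df = apply_parsing_fn(raw_files_df, file_parser, parsed_df_schema)
+#   has_errors, msg, errors_df = check_parsed_df_for_errors(parsed_files_df)
+#   if has_errors:
+#       errors_df.display()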
diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/recursive_character_text_splitter.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/recursive_character_text_splitter.py
new file mode 100644
index 0000000..d9f6ed8
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/data_pipeline/recursive_character_text_splitter.py
@@ -0,0 +1,255 @@
+from typing import Callable, Tuple, Optional
+from databricks.sdk import WorkspaceClient
+
+# %md
+# ##### `get_recursive_character_text_splitter`
+
+# `get_recursive_character_text_splitter` creates a new function that, given an embedding endpoint, returns a callable that can chunk text documents. This utility allows you to write the core business logic of the chunker without dealing with the details of text splitting. You can decide to write your own, or edit this code if it does not fit your use case.
+
+# **Arguments:**
+
+# - `model_serving_endpoint`: The name of the Model Serving endpoint with the embedding model.
+# - `embedding_model_name`: The name of the embedding model, e.g., `gte-large-en-v1.5`. If `model_serving_endpoint` is an OpenAI External Model or FMAPI model, this can be left as `None` and will be detected automatically.
+# - `chunk_size_tokens`: An optional size for each chunk in tokens. Defaults to `None`, which uses the model's entire context window.
+# - `chunk_overlap_tokens`: Tokens that should overlap between chunks. Defaults to `0`.
+
+# **Returns:** A callable that takes a document (`str`) and produces a list of chunks (`list[str]`).
+
+# Constants
+HF_CACHE_DIR = "/tmp/hf_cache/"
+
+# Embedding Models Configuration
+EMBEDDING_MODELS = {
+    "gte-large-en-v1.5": {
+        "context_window": 8192,
+        "type": "SENTENCE_TRANSFORMER",
+    },
+    "bge-large-en-v1.5": {
+        "context_window": 512,
+        "type": "SENTENCE_TRANSFORMER",
+    },
+    "bge_large_en_v1_5": {
+        "context_window": 512,
+        "type": "SENTENCE_TRANSFORMER",
+    },
+    "text-embedding-ada-002": {
+        "context_window": 8192,
+        "type": "OPENAI",
+    },
+    "text-embedding-3-small": {
+        "context_window": 8192,
+        "type": "OPENAI",
+    },
+    "text-embedding-3-large": {
+        "context_window": 8192,
+        "type": "OPENAI",
+    },
+}
+
+
+def get_workspace_client() -> WorkspaceClient:
+    """Returns a WorkspaceClient instance."""
+    return WorkspaceClient()
+# NOTE: Tokenizer libraries are imported lazily inside this function. These data
+# pipeline utils are imported by the agent notebook, which won't (and shouldn't need
+# to) have heavyweight libraries like transformers' AutoTokenizer installed.
+def get_embedding_model_tokenizer(endpoint_type: str) -> Optional[Callable]:
+    from transformers import AutoTokenizer
+    import tiktoken
+
+    # Copied here so tokenizer libraries only need to be installed where this is called
+    EMBEDDING_MODELS_W_TOKENIZER = {
+        "gte-large-en-v1.5": {
+            "tokenizer": lambda: AutoTokenizer.from_pretrained(
+                "Alibaba-NLP/gte-large-en-v1.5", cache_dir=HF_CACHE_DIR
+            ),
+            "context_window": 8192,
+            "type": "SENTENCE_TRANSFORMER",
+        },
+        "bge-large-en-v1.5": {
+            "tokenizer": lambda: AutoTokenizer.from_pretrained(
+                "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
+            ),
+            "context_window": 512,
+            "type": "SENTENCE_TRANSFORMER",
+        },
+        "bge_large_en_v1_5": {
+            "tokenizer": lambda: AutoTokenizer.from_pretrained(
+                "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
+            ),
+            "context_window": 512,
+            "type": "SENTENCE_TRANSFORMER",
+        },
+        "text-embedding-ada-002": {
+            "context_window": 8192,
+            "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-ada-002"),
+            "type": "OPENAI",
+        },
+        "text-embedding-3-small": {
+            "context_window": 8192,
+            "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-small"),
+            "type": "OPENAI",
+        },
+        "text-embedding-3-large": {
+            "context_window": 8192,
+            "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-large"),
+            "type": "OPENAI",
+        },
+    }
+    model_config = EMBEDDING_MODELS_W_TOKENIZER.get(endpoint_type)
+    if model_config is None:
+        return None
+    return model_config.get("tokenizer")
+
+
+def get_embedding_model_config(endpoint_type: str) -> Optional[dict]:
+    """
+    Retrieve embedding model configuration by endpoint type.
+    """
+    return EMBEDDING_MODELS.get(endpoint_type)
+
+
+def extract_endpoint_type(llm_endpoint) -> Optional[str]:
+    """
+    Extract the endpoint type from the given llm_endpoint object.
+    """
+    try:
+        return llm_endpoint.config.served_entities[0].external_model.name
+    except AttributeError:
+        try:
+            return llm_endpoint.config.served_entities[0].foundation_model.name
+        except AttributeError:
+            return None
+
+
+def detect_fmapi_embedding_model_type(
+    model_serving_endpoint: str,
+) -> Tuple[Optional[str], Optional[dict]]:
+    """
+    Detects the embedding model type and configuration for the given endpoint.
+    Returns a tuple of (endpoint_type, embedding_config) or (None, None) if not found.
+    """
+    client = get_workspace_client()
+
+    try:
+        llm_endpoint = client.serving_endpoints.get(name=model_serving_endpoint)
+        endpoint_type = extract_endpoint_type(llm_endpoint)
+    except Exception:
+        endpoint_type = None
+
+    embedding_config = (
+        get_embedding_model_config(endpoint_type) if endpoint_type else None
+    )
+    if embedding_config is None:
+        return (None, None)
+
+    embedding_config["tokenizer"] = get_embedding_model_tokenizer(endpoint_type)
+
+    return (endpoint_type, embedding_config)
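+# Illustrative example (not part of the pipeline): for an FMAPI/External Model
+# endpoint backed by one of the models above, detection returns the model name plus
+# its config; the endpoint name below is a placeholder.
+#
+#   name, cfg = detect_fmapi_embedding_model_type("my-embedding-endpoint")
+#   # -> ("text-embedding-3-small",
+#   #     {"context_window": 8192, "type": "OPENAI", "tokenizer": <callable>})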
+def validate_chunk_size(chunk_spec: dict):
+    """
+    Validate the chunk size and overlap settings in chunk_spec.
+    Returns a tuple (is_valid, msg).
+    """
+    if (
+        chunk_spec["chunk_overlap_tokens"] + chunk_spec["chunk_size_tokens"]
+    ) > chunk_spec["context_window"]:
+        msg = (
+            f'Proposed chunk_size of {chunk_spec["chunk_size_tokens"]} + overlap of {chunk_spec["chunk_overlap_tokens"]} '
+            f'is {chunk_spec["chunk_overlap_tokens"] + chunk_spec["chunk_size_tokens"]}, which is greater than the context '
+            f'window of {chunk_spec["context_window"]} tokens.'
+        )
+        return (False, msg)
+    elif chunk_spec["chunk_overlap_tokens"] > chunk_spec["chunk_size_tokens"]:
+        msg = (
+            f'Proposed `chunk_overlap_tokens` of {chunk_spec["chunk_overlap_tokens"]} is greater than the '
+            f'`chunk_size_tokens` of {chunk_spec["chunk_size_tokens"]}. Increase `chunk_size_tokens` or reduce `chunk_overlap_tokens`.'
+        )
+        return (False, msg)
+    else:
+        context_usage = round(
+            (chunk_spec["chunk_size_tokens"] + chunk_spec["chunk_overlap_tokens"])
+            / chunk_spec["context_window"]
+            * 100,
+            2,
+        )
+        msg = f'Chunk size in tokens: {chunk_spec["chunk_size_tokens"]} and chunk overlap in tokens: {chunk_spec["chunk_overlap_tokens"]} are valid. Using {context_usage}% ({chunk_spec["chunk_size_tokens"] + chunk_spec["chunk_overlap_tokens"]} tokens) of the {chunk_spec["context_window"]} token context window.'
+        return (True, msg)
+
+
+def get_recursive_character_text_splitter(
+    model_serving_endpoint: str,
+    embedding_model_name: str = None,
+    chunk_size_tokens: int = None,
+    chunk_overlap_tokens: int = 0,
+) -> Callable[[str], list[str]]:
+    # Imported here to prevent needing this library everywhere this module is imported
+    from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+    try:
+        # Detect the embedding model and its configuration
+        detected_model_name, chunk_spec = detect_fmapi_embedding_model_type(
+            model_serving_endpoint
+        )
+        if detected_model_name is not None:
+            embedding_model_name = detected_model_name
+
+        if chunk_spec is None or embedding_model_name is None:
+            # Fall back to using the provided embedding_model_name
+            chunk_spec = EMBEDDING_MODELS.get(embedding_model_name)
+            if chunk_spec is None:
+                raise KeyError
+            chunk_spec["tokenizer"] = get_embedding_model_tokenizer(
+                embedding_model_name
+            )
+
+        # Update the chunk specification based on the provided parameters
+        chunk_spec["chunk_size_tokens"] = (
+            chunk_size_tokens or chunk_spec["context_window"]
+        )
+        chunk_spec["chunk_overlap_tokens"] = chunk_overlap_tokens
+
+        # Validate the chunk size and overlap
+        is_valid, msg = validate_chunk_size(chunk_spec)
+        if not is_valid:
+            raise ValueError(msg)
+        else:
+            print(msg)
+
+    except KeyError:
+        raise ValueError(
+            f"Embedding model `{embedding_model_name}` not found. Available models: {list(EMBEDDING_MODELS.keys())}"
+        )
+    def _recursive_character_text_splitter(text: str) -> list[str]:
+        tokenizer = chunk_spec["tokenizer"]()
+        if chunk_spec["type"] == "SENTENCE_TRANSFORMER":
+            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+                tokenizer,
+                chunk_size=chunk_spec["chunk_size_tokens"],
+                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
+            )
+        elif chunk_spec["type"] == "OPENAI":
+            splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                tokenizer.name,
+                chunk_size=chunk_spec["chunk_size_tokens"],
+                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
+            )
+        else:
+            raise ValueError(f"Unsupported model type: {chunk_spec['type']}")
+        return splitter.split_text(text)
+
+    return _recursive_character_text_splitter
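+# Illustrative usage sketch (not part of the pipeline); the endpoint name is a
+# placeholder and the chunk sizes are examples only:
+#
+#   chunking_fn = get_recursive_character_text_splitter(
+#       model_serving_endpoint="my-embedding-endpoint",
+#       chunk_size_tokens=1024,
+#       chunk_overlap_tokens=256,
+#   )
+#   chunks = chunking_fn("long document text ...")  # -> list[str]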
diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/utils/__init__.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/langgraph_agent_app_sample_code/cookbook/data_pipeline/utils/typed_dicts_to_spark_schema.py b/langgraph_agent_app_sample_code/cookbook/data_pipeline/utils/typed_dicts_to_spark_schema.py
new file mode 100644
index 0000000..195c16e
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/data_pipeline/utils/typed_dicts_to_spark_schema.py
@@ -0,0 +1,103 @@
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    StringType,
+    IntegerType,
+    DoubleType,
+    BooleanType,
+    ArrayType,
+    TimestampType,
+    DateType,
+)
+from typing import TypedDict, get_type_hints
+from datetime import datetime, date
+
+
+def typed_dict_to_spark_fields(typed_dict: type[TypedDict]) -> list[StructField]:
+    """
+    Converts a TypedDict into a list of Spark StructField objects.
+
+    This function maps Python types defined in a TypedDict to their corresponding
+    Spark SQL data types, facilitating the creation of a Spark DataFrame schema
+    from Python type annotations.
+
+    Parameters:
+    - typed_dict (type[TypedDict]): The TypedDict class to be converted.
+
+    Returns:
+    - list[StructField]: A list of StructField objects representing the Spark schema.
+
+    Raises:
+    - ValueError: If an unsupported type is encountered or if dictionary types are used.
+    """
+
+    # Mapping of Python types to Spark type objects
+    type_mapping = {
+        str: StringType(),
+        int: IntegerType(),
+        float: DoubleType(),
+        bool: BooleanType(),
+        list: ArrayType(StringType()),  # Default to StringType for arrays
+        datetime: TimestampType(),
+        date: DateType(),
+    }
+
+    def get_spark_type(value_type):
+        """
+        Helper function to map a Python type to a Spark SQL data type.
+
+        This function supports basic Python types, lists of a single type, and raises
+        an error for unsupported types or dictionaries.
+
+        Parameters:
+        - value_type: The Python type to be converted.
+
+        Returns:
+        - DataType: The corresponding Spark SQL data type.
+
+        Raises:
+        - ValueError: If the type is unsupported or if dictionary types are used.
+        """
+        if value_type in type_mapping:
+            return type_mapping[value_type]
+        elif hasattr(value_type, "__origin__") and value_type.__origin__ == list:
+            # Handle List[type] types
+            return ArrayType(get_spark_type(value_type.__args__[0]))
+        elif hasattr(value_type, "__origin__") and value_type.__origin__ == dict:
+            # Handle Dict[type, type] types (not fully supported)
+            raise ValueError("Dict types are not fully supported")
+        else:
+            raise ValueError(f"Unsupported type: {value_type}")
+
+    # Get the type hints for the TypedDict
+    type_hints = get_type_hints(typed_dict)
+
+    # Convert the type hints into a list of StructField objects
+    fields = [
+        StructField(key, get_spark_type(value), True)
+        for key, value in type_hints.items()
+    ]
+
+    return fields
+
+
+def typed_dicts_to_spark_schema(*typed_dicts: type[TypedDict]) -> StructType:
+    """
+    Converts multiple TypedDicts into a Spark schema.
+
+    This function allows for the combination of multiple TypedDicts into a single
+    Spark DataFrame schema, enabling the creation of complex data structures.
+
+    Parameters:
+    - *typed_dicts: Variable number of TypedDict classes to be converted.
+
+    Returns:
+    - StructType: A Spark schema represented as a StructType object, which is a collection
+      of StructField objects derived from the provided TypedDicts.
+    """
+    fields = []
+    for typed_dict in typed_dicts:
+        fields.extend(typed_dict_to_spark_fields(typed_dict))
+
+    return StructType(fields)
diff --git a/langgraph_agent_app_sample_code/cookbook/databricks_utils/__init__.py b/langgraph_agent_app_sample_code/cookbook/databricks_utils/__init__.py
new file mode 100644
index 0000000..94fd8fb
--- /dev/null
+++ b/langgraph_agent_app_sample_code/cookbook/databricks_utils/__init__.py
@@ -0,0 +1,225 @@
+# Helper functions for displaying Delta Table and Volume URLs
+
+from typing import Optional
+import json
+import subprocess
+
+from databricks.sdk import WorkspaceClient
+from mlflow.utils import databricks_utils as du
+
+
+def get_databricks_cli_config() -> dict:
+    """Retrieve the Databricks CLI configuration by running the 'databricks auth describe' command.
+
+    Returns:
+        dict: The parsed JSON configuration from the Databricks CLI, or None if an error occurs
+
+    Note:
+        Requires the Databricks CLI to be installed and configured
+    """
+    try:
+        # Run the `databricks auth describe` command and capture its output
+        process = subprocess.run(
+            ["databricks", "auth", "describe", "-o", "json"],
+            capture_output=True,
+            text=True,
+            check=True,  # Raises CalledProcessError if the command fails
+        )
+
+        # Parse the JSON output
+        return json.loads(process.stdout)
+    except subprocess.CalledProcessError as e:
+        print(f"Error running databricks CLI command: {e}")
+        return None
+    except json.JSONDecodeError as e:
+        print(f"Error parsing databricks CLI JSON output: {e}")
+        return None
+    except Exception as e:
+        print(f"Unexpected error getting databricks config from CLI: {e}")
+        return None
+
+
+def get_workspace_hostname() -> str:
+    """Get the Databricks workspace hostname.
+ + Returns: + str: The full workspace hostname (e.g., 'https://my-workspace.cloud.databricks.com') + + Raises: + RuntimeError: If not in a Databricks notebook and unable to get workspace hostname from CLI config + """ + if du.is_in_databricks_notebook(): + return "https://" + du.get_browser_hostname() + else: + cli_config = get_databricks_cli_config() + if cli_config is None: + raise RuntimeError("Could not get Databricks CLI config") + try: + return cli_config["details"]["host"] + except KeyError: + raise RuntimeError( + "Could not find workspace hostname in Databricks CLI config" + ) + + +def get_table_url(table_fqdn: str) -> str: + """Generate the URL for a Unity Catalog table in the Databricks UI. + + Args: + table_fqdn: Fully qualified table name in format 'catalog.schema.table'. + Can optionally include backticks around identifiers. + + Returns: + str: The full URL to view the table in the Databricks UI. + + Example: + >>> get_table_url("main.default.my_table") + 'https://my-workspace.cloud.databricks.com/explore/data/main/default/my_table' + """ + table_fqdn = table_fqdn.replace("`", "") + catalog, schema, table = table_fqdn.split(".") + browser_url = get_workspace_hostname() + url = f"{browser_url}/explore/data/{catalog}/{schema}/{table}" + return url + + +def get_volume_url(volume_fqdn: str) -> str: + """Generate the URL for a Unity Catalog volume in the Databricks UI. + + Args: + volume_fqdn: Fully qualified volume name in format 'catalog.schema.volume'. + Can optionally include backticks around identifiers. + + Returns: + str: The full URL to view the volume in the Databricks UI. + + Example: + >>> get_volume_url("main.default.my_volume") + 'https://my-workspace.cloud.databricks.com/explore/data/volumes/main/default/my_volume' + """ + volume_fqdn = volume_fqdn.replace("`", "") + catalog, schema, volume = volume_fqdn.split(".") + browser_url = get_workspace_hostname() + url = f"{browser_url}/explore/data/volumes/{catalog}/{schema}/{volume}" + return url + + +def get_mlflow_experiment_url(experiment_id: str) -> str: + """Generate the URL for an MLflow experiment in the Databricks UI. + + Args: + experiment_id: The ID of the MLflow experiment + + Returns: + str: The full URL to view the MLflow experiment in the Databricks UI. + + Example: + >>> get_mlflow_experiment_url("") + 'https://my-workspace.cloud.databricks.com/ml/experiments/' + """ + browser_url = get_workspace_hostname() + url = f"{browser_url}/ml/experiments/{experiment_id}" + return url + + +def get_mlflow_experiment_traces_url(experiment_id: str) -> str: + """Generate the URL for the MLflow experiment traces in the Databricks UI.""" + return get_mlflow_experiment_url(experiment_id) + "?compareRunsMode=TRACES" + + +def get_function_url(function_fqdn: str) -> str: + """Generate the URL for a Unity Catalog function in the Databricks UI. + + Args: + function_fqdn: Fully qualified function name in format 'catalog.schema.function'. + Can optionally include backticks around identifiers. + + Returns: + str: The full URL to view the function in the Databricks UI. 
+ + Example: + >>> get_function_url("main.default.my_function") + 'https://my-workspace.cloud.databricks.com/explore/data/functions/main/default/my_function' + """ + function_fqdn = function_fqdn.replace("`", "") + catalog, schema, function = function_fqdn.split(".") + browser_url = get_workspace_hostname() + url = f"{browser_url}/explore/data/functions/{catalog}/{schema}/{function}" + return url + + +def get_cluster_url(cluster_id: str) -> str: + """Generate the URL for a Databricks cluster in the Databricks UI. + + Args: + cluster_id: The ID of the cluster + + Returns: + str: The full URL to view the cluster in the Databricks UI. + + Example: + >>> get_cluster_url("") + 'https://my-workspace.cloud.databricks.com/compute/clusters/' + """ + browser_url = get_workspace_hostname() + url = f"{browser_url}/compute/clusters/{cluster_id}" + return url + + +def get_active_cluster_id_from_databricks_auth() -> Optional[str]: + """Get the active cluster ID from the Databricks CLI authentication configuration. + + Returns: + Optional[str]: The active cluster ID if found, None if not found or if an error occurs + + Note: + This function relies on the Databricks CLI configuration having a cluster_id set + """ + if du.is_in_databricks_notebook(): + raise ValueError( + "Cannot get active cluster ID from the Databricks CLI in a Databricks notebook" + ) + try: + # Get config from the databricks cli + auth_output = get_databricks_cli_config() + + # Safely navigate nested dict + details = auth_output.get("details", {}) + config = details.get("configuration", {}) + cluster = config.get("cluster_id", {}) + cluster_id = cluster.get("value") + + if cluster_id is None: + raise ValueError("Could not find cluster_id in Databricks auth config") + + return cluster_id + + except Exception as e: + print(f"Unexpected error: {e}") + return None + + +def get_active_cluster_id() -> Optional[str]: + """Get the active cluster ID. 
+ + Returns: + Optional[str]: The active cluster ID if found, None if not found or if an error occurs + """ + if du.is_in_databricks_notebook(): + return du.get_active_cluster_id() + else: + return get_active_cluster_id_from_databricks_auth() + + +def get_current_user_info(spark) -> tuple[str, str, str]: + # Get current user's name & email + w = WorkspaceClient() + user_email = w.current_user.me().user_name + user_name = user_email.split("@")[0].replace(".", "_") + + # Get the workspace default UC catalog + default_catalog = spark.sql("select current_catalog() as cur_catalog").collect()[0][ + "cur_catalog" + ] + + return user_email, user_name, default_catalog diff --git a/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_evaluation/__init__.py b/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_evaluation/evaluation_set.py b/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_evaluation/evaluation_set.py new file mode 100644 index 0000000..6cd2e84 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_evaluation/evaluation_set.py @@ -0,0 +1,236 @@ +from typing import List, Mapping, Optional + +import mlflow.entities as mlflow_entities + +from pyspark import sql +from pyspark.sql import functions as F, types as T +from pyspark.sql.window import Window + +from databricks.rag_eval.evaluation import traces + +# Deduplicate the assessment log + +# By default, the assessment log contains one row for every action/click the user does in the Review App. This code translates these logs into a single row for each request. + +_REQUEST_ID = "request_id" +_TIMESTAMP = "timestamp" +_ROW_NUMBER = "row_number" +_SOURCE = "source" +_SOURCE_ID = "source.id" +_STEP_ID = "step_id" +_TEXT_ASSESSMENT = "text_assessment" +_RETRIEVAL_ASSESSMENT = "retrieval_assessment" + + +def _dedup_by_assessment_window( + assessment_log_df: sql.DataFrame, window: Window +) -> sql.DataFrame: + """ + Dedup the assessment logs by taking the first row from each group, defined by the window + :param assessment_log_df: Pyspark DataFrame of the assessment logs + :param window: Pyspark window to group assessments by + :return: Pyspark DataFrame of the deduped assessment logs + """ + return ( + assessment_log_df.withColumn(_ROW_NUMBER, F.row_number().over(window)) + .filter(F.col(_ROW_NUMBER) == 1) + .drop(_ROW_NUMBER) + ) + + +def _dedup_assessment_log(assessment_log_df: sql.DataFrame) -> sql.DataFrame: + """ + Dedup the assessment logs to get the latest assessments. 
+ :param assessment_log_df: Pyspark DataFrame of the assessment logs + :return: Pyspark DataFrame of the deduped assessment logs + """ + # Dedup the text assessments + text_assessment_window = Window.partitionBy(_REQUEST_ID, _SOURCE_ID).orderBy( + F.col(_TIMESTAMP).desc() + ) + deduped_text_assessment_df = _dedup_by_assessment_window( + # Filter rows with null text assessments + assessment_log_df.filter(F.col(_TEXT_ASSESSMENT).isNotNull()), + text_assessment_window, + ) + + # Dedup the retrieval assessments + retrieval_assessment_window = Window.partitionBy( + _REQUEST_ID, + _SOURCE_ID, + f"{_RETRIEVAL_ASSESSMENT}.position", + f"{_RETRIEVAL_ASSESSMENT}.{_STEP_ID}", + ).orderBy(F.col(_TIMESTAMP).desc()) + deduped_retrieval_assessment_df = _dedup_by_assessment_window( + # Filter rows with null retrieval assessments + assessment_log_df.filter(F.col(_RETRIEVAL_ASSESSMENT).isNotNull()), + retrieval_assessment_window, + ) + + # Collect retrieval assessments from the same request/step/source into a single list + nested_retrieval_assessment_df = ( + deduped_retrieval_assessment_df.groupBy(_REQUEST_ID, _SOURCE_ID, _STEP_ID).agg( + F.any_value(_TIMESTAMP).alias(_TIMESTAMP), + F.any_value(_SOURCE).alias(_SOURCE), + F.collect_list(_RETRIEVAL_ASSESSMENT).alias("retrieval_assessments"), + ) + # Drop the old retrieval assessment, source id, and text assessment columns + .drop(_RETRIEVAL_ASSESSMENT, "id", _TEXT_ASSESSMENT) + ) + + # Join the deduped text assessments with the nested deduped retrieval assessments + deduped_assessment_log_df = deduped_text_assessment_df.alias("a").join( + nested_retrieval_assessment_df.alias("b"), + (F.col(f"a.{_REQUEST_ID}") == F.col(f"b.{_REQUEST_ID}")) + & (F.col(f"a.{_SOURCE_ID}") == F.col(f"b.{_SOURCE_ID}")), + "full_outer", + ) + + # Coalesce columns from both dataframes in case a request does not have either assessment + return deduped_assessment_log_df.select( + F.coalesce(F.col(f"a.{_REQUEST_ID}"), F.col(f"b.{_REQUEST_ID}")).alias( + _REQUEST_ID + ), + F.coalesce(F.col(f"a.{_STEP_ID}"), F.col(f"b.{_STEP_ID}")).alias(_STEP_ID), + F.coalesce(F.col(f"a.{_TIMESTAMP}"), F.col(f"b.{_TIMESTAMP}")).alias( + _TIMESTAMP + ), + F.coalesce(F.col(f"a.{_SOURCE}"), F.col(f"b.{_SOURCE}")).alias(_SOURCE), + F.col(f"a.{_TEXT_ASSESSMENT}").alias(_TEXT_ASSESSMENT), + F.col("b.retrieval_assessments").alias(_RETRIEVAL_ASSESSMENT), + # F.col("schema_version") + ) + + ## Attach ground truth + + +def attach_ground_truth(request_log_df, deduped_assessment_log_df): + suggested_output_col = F.col(f"{_TEXT_ASSESSMENT}.suggested_output") + is_correct_col = F.col(f"{_TEXT_ASSESSMENT}.ratings.answer_correct.value") + # Extract out the thumbs up/down rating and the suggested output + rating_log_df = ( + deduped_assessment_log_df.withColumn("is_correct", is_correct_col) + .withColumn( + "suggested_output", + F.when(suggested_output_col == "", None).otherwise(suggested_output_col), + ) + .withColumn("source_user", F.col("source.id")) + .select( + "request_id", + "is_correct", + "suggested_output", + "source_user", + _RETRIEVAL_ASSESSMENT, + ) + ) + # Join the request log with the ratings from above + raw_requests_with_feedback_df = request_log_df.join( + rating_log_df, + request_log_df.databricks_request_id == rating_log_df.request_id, + "left", + ) + + raw_requests_with_feedback_df = raw_requests_with_feedback_df.drop("request_id") + return raw_requests_with_feedback_df + +_EXPECTED_RETRIEVAL_CONTEXT_SCHEMA = T.ArrayType( + T.StructType( + [ + T.StructField("doc_uri", T.StringType()), + 
T.StructField("content", T.StringType()), + ] + ) +) + + +def extract_retrieved_chunks_from_trace(trace_str: str) -> List[Mapping[str, str]]: + """Helper to extract the retrieved chunks from a trace string""" + trace = mlflow_entities.Trace.from_json(trace_str) + chunks = traces.extract_retrieval_context_from_trace(trace) + return [{"doc_uri": chunk.doc_uri, "content": chunk.content} for chunk in chunks] + + +@F.udf(_EXPECTED_RETRIEVAL_CONTEXT_SCHEMA) +def construct_expected_retrieval_context( + trace_str: Optional[str], chunk_at_i_relevance: Optional[List[str]] +) -> Optional[List[Mapping[str, str]]]: + """Helper to construct the expected retrieval context. Any retrieved chunks that are not relevant are dropped.""" + if chunk_at_i_relevance is None or trace_str is None: + return None + retrieved_chunks = extract_retrieved_chunks_from_trace(trace_str) + expected_retrieval_context = [ + chunk + for chunk, rating in zip(retrieved_chunks, chunk_at_i_relevance) + if rating == "true" + ] + return expected_retrieval_context if len(expected_retrieval_context) else None + + +# ================================= + + +def identify_potential_eval_set_records(raw_requests_with_feedback_df): + # For thumbs up, use either the suggested output or the response, in that order + positive_feedback_df = ( + raw_requests_with_feedback_df.where(F.col("is_correct") == F.lit("positive")) + .withColumn( + "expected_response", + F.when( + F.col("suggested_output") != None, F.col("suggested_output") + ).otherwise(F.col("response")), + ) + .withColumn("source_tag", F.lit("thumbs_up")) + ) + + # For thumbs down, use the suggested output if there is one + negative_feedback_df = ( + raw_requests_with_feedback_df.where(F.col("is_correct") == F.lit("negative")) + .withColumn("expected_response", F.col("suggested_output")) + .withColumn("source_tag", F.lit("thumbs_down_edited")) + ) + + # For no feedback or IDK, there is no expected response. 
+ no_or_unknown_feedback_df = ( + raw_requests_with_feedback_df.where( + (F.col("is_correct").isNull()) + | ( + (F.col("is_correct") != F.lit("negative")) + & (F.col("is_correct") != F.lit("positive")) + ) + ) + .withColumn("expected_response", F.lit(None)) + .withColumn("source_tag", F.lit("no_feedback_provided")) + ) + # Join the above feedback tables and select the relevant columns for the eval harness + requests_with_feedback_df = positive_feedback_df.unionByName( + negative_feedback_df + ).unionByName(no_or_unknown_feedback_df) + # Get the thumbs up/down for each retrieved chunk + requests_with_feedback_df = requests_with_feedback_df.withColumn( + "chunk_at_i_relevance", + F.transform( + F.col(_RETRIEVAL_ASSESSMENT), lambda x: x.ratings.answer_correct.value + ), + ).drop(_RETRIEVAL_ASSESSMENT) + + requests_with_feedback_df = requests_with_feedback_df.withColumnRenamed( + "databricks_request_id", "request_id" + ) + + # Add the expected retrieved context column + requests_with_feedback_df = requests_with_feedback_df.withColumn( + "expected_retrieved_context", + construct_expected_retrieval_context( + F.col("trace"), F.col("chunk_at_i_relevance") + ), + ) + return requests_with_feedback_df + +def create_potential_evaluation_set(request_log_df, assessment_log_df): + raw_requests_with_feedback_df = attach_ground_truth( + request_log_df, assessment_log_df + ) + requests_with_feedback_df = identify_potential_eval_set_records( + raw_requests_with_feedback_df + ) + return requests_with_feedback_df \ No newline at end of file diff --git a/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_framework/__init__.py b/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_framework/get_inference_tables.py b/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_framework/get_inference_tables.py new file mode 100644 index 0000000..1d1183c --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/databricks_utils/agent_framework/get_inference_tables.py @@ -0,0 +1,35 @@ +from databricks.sdk import WorkspaceClient +from databricks import agents + +def get_inference_tables(uc_model_fqn): + w = WorkspaceClient() + + deployment = agents.get_deployments(uc_model_fqn) + if len(deployment) == 0: + raise ValueError(f"No deployments found for model {uc_model_fqn}") + endpoint = w.serving_endpoints.get(deployment[0].endpoint_name) + + + try: + endpoint_config = endpoint.config.auto_capture_config + except AttributeError as e: + endpoint_config = endpoint.pending_config.auto_capture_config + + inference_table_name = endpoint_config.state.payload_table.name + inference_table_catalog = endpoint_config.catalog_name + inference_table_schema = endpoint_config.schema_name + + # Cleanly formatted tables + assessment_log_table_name = f"{inference_table_name}_assessment_logs" + request_log_table_name = f"{inference_table_name}_request_logs" + + return { + 'uc_catalog_name': inference_table_catalog, + 'uc_schema_name': inference_table_schema, + 'table_names': { + 'raw_payload_logs': inference_table_name, + 'assessment_logs': assessment_log_table_name, + 'request_logs': request_log_table_name, + } + + } diff --git a/langgraph_agent_app_sample_code/cookbook/databricks_utils/install_cluster_library.py b/langgraph_agent_app_sample_code/cookbook/databricks_utils/install_cluster_library.py new file mode 100644 index 0000000..e7a0074 --- /dev/null +++ 
b/langgraph_agent_app_sample_code/cookbook/databricks_utils/install_cluster_library.py @@ -0,0 +1,107 @@ +from typing import List + +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.compute import ( + Library, + LibraryFullStatus, + LibraryInstallStatus, + PythonPyPiLibrary, +) +import time + + +def parse_requirements(requirements_path: str) -> List[str]: + """Parse requirements.txt file and return list of package specifications.""" + packages = [] + with open(requirements_path) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + packages.append(line) + return packages + + +def wait_for_library_installation( + w: WorkspaceClient, cluster_id: str, timeout_minutes: int = 20 +): + """Wait for all libraries to be installed or fail.""" + start_time = time.time() + timeout_seconds = timeout_minutes * 60 + final_states = { + LibraryInstallStatus.INSTALLED, + LibraryInstallStatus.FAILED, + LibraryInstallStatus.SKIPPED, + } + + while True: + if time.time() - start_time > timeout_seconds: + print( + f"Timeout after {timeout_minutes} minutes waiting for library installation" + ) + break + + status: List[LibraryFullStatus] = w.libraries.cluster_status(cluster_id) + all_finished = True + + for lib in status: + if lib.status not in final_states: + all_finished = False + break + + if all_finished: + break + + print("Installation in progress, waiting 15 seconds...") + time.sleep(15) # Check every 15 seconds + + # Print final status + status = w.libraries.cluster_status(cluster_id) + for lib in status: + if lib.library.pypi: + status_msg = ( + f"Package: {lib.library.pypi.package} - Status: {lib.status.value}" + ) + if lib.messages: + status_msg += f" - Messages: {', '.join(lib.messages)}" + print(status_msg) + + +def install_requirements(cluster_id: str, requirements_path: str): + """Install all packages from requirements.txt into specified cluster.""" + # Initialize workspace client + w = WorkspaceClient() + + # Parse requirements file + packages = parse_requirements(requirements_path) + + # Get current library status + current_status = w.libraries.cluster_status(cluster_id) + existing_packages = { + lib.library.pypi.package: lib.status.value + for lib in current_status + if lib.library.pypi + } + + # Filter out already installed packages + libraries = [] + for package in packages: + if ( + package not in existing_packages + or existing_packages[package] != LibraryInstallStatus.INSTALLED.value + ): + libraries.append(Library(pypi=PythonPyPiLibrary(package=package))) + else: + print(f"Package {package} is already installed, skipping...") + + if not libraries: + print("All packages are already installed.") + return + + # Install libraries + package_names = [lib.pypi.package for lib in libraries] + print(f"Installing {len(libraries)} packages: {', '.join(package_names)}") + w.libraries.install(cluster_id, libraries=libraries) + + # Wait for installation to complete + print("Waiting for installation to complete...") + wait_for_library_installation(w, cluster_id) diff --git a/langgraph_agent_app_sample_code/cookbook/tools/__init__.py b/langgraph_agent_app_sample_code/cookbook/tools/__init__.py new file mode 100644 index 0000000..6fc89bd --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/tools/__init__.py @@ -0,0 +1,45 @@ +from cookbook.config import SerializableConfig +from mlflow.models.resources import DatabricksResource + + +from typing import Any, List + + +class Tool(SerializableConfig): + """Base class for all tools""" + + def 
__call__(self, **kwargs) -> Any: + """Execute the tool with validated inputs""" + raise NotImplementedError( + "__call__ must be implemented by Tool subclasses. This method should execute " + "the tool's functionality with the provided validated inputs and return the result." + ) + + name: str + description: str + + def get_json_schema(self) -> dict: + """Returns an OpenAPI-compatible JSON schema for the tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self._get_parameters_schema(), + }, + } + + def _get_parameters_schema(self) -> dict: + """Returns the JSON schema for the tool's parameters.""" + raise NotImplementedError( + "_get_parameters_schema must be implemented by Tool subclasses. This method should " + "return an OpenAPI-compatible JSON schema dict describing the tool's input parameters. " + "The schema should include parameter names, types, descriptions, and any validation rules." + ) + + def get_resource_dependencies(self) -> List[DatabricksResource]: + """Returns a list of Databricks resources (mlflow.models.resources.* objects) that the tool uses. Used to securely provision credentials for these resources when the tool is deployed to Model Serving.""" + raise NotImplementedError( + "get_resource_dependencies must be implemented by Tool subclasses. This method should " + "return a list of mlflow.models.resources.* objects that the tool depends on." + ) diff --git a/langgraph_agent_app_sample_code/cookbook/tools/local_function.py b/langgraph_agent_app_sample_code/cookbook/tools/local_function.py new file mode 100644 index 0000000..afbc719 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/tools/local_function.py @@ -0,0 +1,165 @@ +from cookbook.tools import Tool + +from mlflow.models.resources import DatabricksResource +from pydantic import BaseModel, Field, create_model +from unitycatalog.ai.core.utils.docstring_utils import parse_docstring +from typing import Optional + +import inspect +from typing import Any, Callable, List, Type, get_type_hints +import importlib +import mlflow + + +class LocalFunctionTool(Tool): + """Tool implementation that wraps a function""" + + # func: Callable + func_path: str + name: str + description: str + _input_schema: Type[BaseModel] + + def _process_function( + self, func: Callable, name: Optional[str], description: Optional[str] + ) -> tuple[str, str, Type[BaseModel]]: + """Process a function to extract name, description and input schema. + + Args: + func: The function to process + name: Optional override for the function name + description: Optional override for the function description + + Returns: + Tuple of (processed_name, processed_description, processed_input_schema) + """ + processed_name = name or func.__name__ + + # Validate function has type annotations + if not all(get_type_hints(func).values()): + raise ValueError( + f"Tool '{processed_name}' must have complete type annotations for all parameters " + "and return value." + ) + + # Parse the docstring and get description + docstring = inspect.getdoc(func) + if not docstring: + raise ValueError( + f"Tool '{processed_name}' must have a docstring with Google-style formatting." + ) + + doc_info = parse_docstring(docstring) + processed_description = description or doc_info.description + + # Ensure we have parameter documentation + if not doc_info.params: + raise ValueError( + f"Tool '{processed_name}' must have documented parameters in Google-style format. 
" + "Example:\n Args:\n param_name: description" + ) + + # Validate all parameters are documented + sig_params = set(inspect.signature(func).parameters.keys()) + doc_params = set(doc_info.params.keys()) + if sig_params != doc_params: + missing = sig_params - doc_params + extra = doc_params - sig_params + raise ValueError( + f"Tool '{processed_name}' parameter documentation mismatch. " + f"Missing docs for: {missing if missing else 'none'}. " + f"Extra docs for: {extra if extra else 'none'}." + ) + + # Create the input schema + processed_input_schema = self._create_schema_from_function( + func, doc_info.params + ) + + return processed_name, processed_description, processed_input_schema + + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + *, + func: Optional[Callable] = None, + func_path: Optional[str] = None, + ): + if func is not None and func_path is not None: + raise ValueError("Only one of func or func_path can be provided") + + if func is not None: + # Process the function to get name, description and input schema + processed_name, processed_description, processed_input_schema = ( + self._process_function(func, name, description) + ) + + # Serialize the function's location + func_path = f"{func.__module__}.{func.__name__}" + + # Now call parent class constructor with processed values + super().__init__( + func_path=func_path, + name=processed_name, + description=processed_description, + ) + + self._input_schema = processed_input_schema + + self._loaded_callable = None + self.load_func() + elif func_path is not None: + + super().__init__( + func_path=func_path, + name=name, + description=description, + # _input_schema=None, + ) + + self._loaded_callable = None + self.load_func() + + _, _, processed_input_schema = self._process_function( + self._loaded_callable, name, description + ) + + self._input_schema = processed_input_schema + + @staticmethod + def _create_schema_from_function( + func: Callable, param_descriptions: dict[str, str] + ) -> Type[BaseModel]: + """Creates a Pydantic model from function signature and parsed docstring""" + sig = inspect.signature(func) + type_hints = get_type_hints(func) + + fields = {} + for name, param in sig.parameters.items(): + fields[name] = ( + type_hints.get(name, Any), + Field(description=param_descriptions.get(name, f"Parameter: {name}")), + ) + + return create_model(f"{func.__name__.title()}Inputs", **fields) + + def load_func(self): + if self._loaded_callable is None: + module_name, func_name = self.func_path.rsplit(".", 1) + module = importlib.import_module(module_name) + self._loaded_callable = getattr(module, func_name) + + @mlflow.trace(span_type="TOOL", name="local_function") + def __call__(self, **kwargs) -> Any: + """Execute the tool's function with validated inputs""" + self.load_func() + validated_inputs = self._input_schema(**kwargs) + return self._loaded_callable(**validated_inputs.model_dump()) + + def _get_parameters_schema(self) -> dict: + """Returns the JSON schema for the tool's parameters.""" + return self._input_schema.model_json_schema() + + def get_resource_dependencies(self) -> List[DatabricksResource]: + return [] diff --git a/langgraph_agent_app_sample_code/cookbook/tools/uc_tool.py b/langgraph_agent_app_sample_code/cookbook/tools/uc_tool.py new file mode 100644 index 0000000..aaba7d3 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/tools/uc_tool.py @@ -0,0 +1,172 @@ +from cookbook.tools import Tool +from cookbook.databricks_utils import get_function_url + + +from 
cookbook.tools.uc_tool_utils import ( + _parse_SparkException_from_tool_execution, + _parse_ParseException_from_tool_execution, +) +import mlflow +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import ResourceDoesNotExist +from mlflow.models.resources import DatabricksFunction, DatabricksResource +from pydantic import Field, model_validator +from pyspark.errors import SparkRuntimeException +from pyspark.errors.exceptions.connect import ParseException +from unitycatalog.ai.core.databricks import DatabricksFunctionClient +from unitycatalog.ai.openai.toolkit import UCFunctionToolkit +from dataclasses import asdict + +import json +from typing import Any, Dict, List, Union + +ERROR_INSTRUCTIONS_KEY = "error_instructions" +ERROR_STATUS_KEY = "error" + + +class UCTool(Tool): + """Configuration for a Unity Catalog function tool. + + This class defines the configuration for a Unity Catalog function that can be used + as a tool in an agent system. + + Args: + uc_function_name: Unity Catalog location of the function in format: catalog.schema.function_name. + Example: my_catalog.my_schema.my_function + + Returns: + UCTool: A configured Unity Catalog function tool object. + """ + + uc_function_name: str + """Unity Catalog location of the function in format: catalog.schema.function_name.""" + + error_prompt: str = ( + f"""The tool call generated an Exception, detailed in `{ERROR_STATUS_KEY}`. Think step-by-step following these instructions to determine your next step.\n""" + "[1] Is the error due to a problem with the input parameters?\n" + "[2] Could it succeed if retried with exactly the same inputs?\n" + "[3] Could it succeed if retried with modified parameters using the input we already have from the user?\n" + "[4] Could it succeed if retried with modified parameters informed by collecting additional input from the user? What specific input would we need from the user?\n" + """Based on your thinking, if the error is due to a problem with the input parameters, either call this tool again in a way that avoids this exception or collect additional information from the user to modify the inputs to avoid this exception.""" + ) + + # Optional b/c we set these automatically in model_post_init from the UC function itself. + # Suggest not overriding these, but rather updating the UC function's metadata directly. + name: str = Field(default=None) # Make it optional in the constructor + description: str = Field(default=None) # Make it optional in the constructor + + @model_validator(mode="after") + def validate_uc_function_name(self) -> "UCTool": + """Validates that the UC function exists and is accessible. + + Checks that the function name is properly formatted and exists in Unity Catalog + with proper permissions. + + Returns: + UCTool: The validated tool instance. + + Raises: + ValueError: If function name is invalid or function is not accessible. + """ + parts = self.uc_function_name.split(".") + if len(parts) != 3: + raise ValueError( + f"uc_function_name must be in format: catalog.schema.function_name; got `{self.uc_function_name}`" + ) + + # Validate that the function exists in Unity Catalog & user has EXECUTE permission on the function + # Docs: https://databricks-sdk-py.readthedocs.io/en/stable/workspace/catalog/functions.html#get + w = WorkspaceClient() + try: + w.functions.get(name=self.uc_function_name) + except ResourceDoesNotExist: + raise ValueError( + f"Function `{self.uc_function_name}` not found in Unity Catalog or you do not have permission to access it. 
Ensure the function exists, and you have EXECUTE permission on the function, USE CATALOG and USE SCHEMA permissions on the catalog and schema. If function exists, you can verify permissions here: {get_function_url(self.uc_function_name)}." + ) + + return self + + def model_post_init(self, __context: Any) -> None: + + # Initialize the UC clients + self._uc_client = DatabricksFunctionClient() + self._toolkit = UCFunctionToolkit( + function_names=[self.uc_function_name], client=self._uc_client + ) + + # OK to use [0] position b/c we know that there is only one function initialized in the toolkit. + self.name = self._toolkit.tools[0]["function"]["name"] + self.description = self._toolkit.tools[0]["function"]["description"] + + def _get_parameters_schema(self) -> dict: + """Gets the parameter schema for the UC function. + + Returns: + dict: JSON schema describing the function's parameters. + """ + # OK to use [0] position b/c we know that there is only one function initialized in the toolkit. + return self._toolkit.tools[0]["function"]["parameters"] + + @mlflow.trace(span_type="TOOL", name="uc_tool") + def __call__(self, **kwargs) -> Dict[str, str]: + # annotate the span with the tool name + span = mlflow.get_current_active_span() + if span: # TODO: Hack, when mlflow tracing is disabled, span == None. + span.set_attributes({"uc_tool_name": self.uc_function_name}) + + # trace the function call + traced_exec_function = mlflow.trace( + span_type="FUNCTION", name="_uc_client.execute_function" + )(self._uc_client.execute_function) + + # convert input args to json + args_json = json.loads(json.dumps(kwargs, default=str)) + + # TODO: Add in Ben's code parser + + # Try to execute the function & return its value as a dict + try: + result = traced_exec_function( + function_name=self.uc_function_name, parameters=args_json + ) + return asdict(result) + + # Parse the error into a format that's easier for the LLM to understand w/ out any of the Spark runtime error noise + except SparkRuntimeException as tool_exception: + return { + ERROR_STATUS_KEY: _parse_SparkException_from_tool_execution( + tool_exception + ), + ERROR_INSTRUCTIONS_KEY: self.error_prompt, + } + except ParseException as tool_exception: + return { + ERROR_STATUS_KEY: _parse_ParseException_from_tool_execution( + tool_exception + ), + ERROR_INSTRUCTIONS_KEY: self.error_prompt, + } + except Exception as tool_exception: + # some other type of error that is unknown, parse into the same format as the Spark exceptions + # will first try to parse using the SparkException parsing code, if that fails, will then try the generic one + return { + ERROR_STATUS_KEY: _parse_SparkException_from_tool_execution( + tool_exception + ), + ERROR_INSTRUCTIONS_KEY: self.error_prompt, + } + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """Override model_dump to exclude name and description fields. + + Returns: + Dict[str, Any]: Dictionary representation of the model excluding name and description. 
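+        `name` and `description` are excluded because model_post_init re-derives them from the UC function's metadata when the tool is re-instantiated.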
+ """ + kwargs["exclude"] = {"name", "description"}.union(kwargs.get("exclude", set())) + return super().model_dump(**kwargs) + + def get_resource_dependencies(self) -> List[DatabricksResource]: + return [DatabricksFunction(function_name=self.uc_function_name)] + + def _remove_udfbody_from_stack_trace(self, stack_trace: str) -> str: + return stack_trace.replace('File "",', "").strip() diff --git a/langgraph_agent_app_sample_code/cookbook/tools/uc_tool_utils.py b/langgraph_agent_app_sample_code/cookbook/tools/uc_tool_utils.py new file mode 100644 index 0000000..c4f7825 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/tools/uc_tool_utils.py @@ -0,0 +1,132 @@ +import mlflow +from pyspark.errors import SparkRuntimeException +from pyspark.errors.exceptions.connect import ParseException +import re + +import logging +from typing import Dict, Union + +ERROR_KEY = "error_message" +STACK_TRACE_KEY = "stack_trace" + + +@mlflow.trace(span_type="PARSER") +def _remove_udfbody_from_pyspark_stack_trace(stack_trace: str) -> str: + return stack_trace.replace('File "",', "").strip() + + +@mlflow.trace(span_type="PARSER") +def _parse_PySpark_exception_dumped_as_string(error_msg: str) -> Dict[str, str]: + # Extract error section between == Error == and == Stacktrace == + error = error_msg.split("== Error ==")[1].split("== Stacktrace ==")[0].strip() + + # Extract stacktrace section after == Stacktrace == and before SQL + stack_trace = error_msg.split("== Stacktrace ==")[1].split("== SQL")[0].strip() + + # Remove SQLSTATE and anything after it from the stack trace + if "SQLSTATE" in stack_trace: + stack_trace = stack_trace.split("SQLSTATE")[0].strip() + + return { + STACK_TRACE_KEY: _remove_udfbody_from_pyspark_stack_trace(stack_trace), + ERROR_KEY: error, + } + + +@mlflow.trace(span_type="PARSER") +def _parse_PySpark_exception_from_known_structure( + tool_exception: SparkRuntimeException, +) -> Dict[str, str]: + raw_stack_trace = tool_exception.getMessageParameters()["stack"] + return { + STACK_TRACE_KEY: _remove_udfbody_from_pyspark_stack_trace(raw_stack_trace), + ERROR_KEY: tool_exception.getMessageParameters()["error"], + } + + +@mlflow.trace(span_type="PARSER") +def _parse_generic_tool_exception(tool_exception: Exception) -> Dict[str, str]: + return { + STACK_TRACE_KEY: None, + ERROR_KEY: str(tool_exception), + } + + +@mlflow.trace(span_type="PARSER") +def _parse_SparkException_from_tool_execution( + tool_exception: Union[SparkRuntimeException, Exception], +) -> Dict[str, str]: + error_info_to_return: Union[Dict, str] = None + + # First attempt: first try to parse from the known structure + try: + logging.info( + f"Trying to parse spark exception {tool_exception} using its provided structured data." + ) + # remove the from the stack trace which the LLM knows nothing about + # raw_stack_trace = tool_exception.getMessageParameters()["stack"] + return _parse_PySpark_exception_from_known_structure(tool_exception) + + except Exception as e: + # 2nd attempt: that failed, let's try to parse the SparkException's raw formatting + logging.info( + f"Error parsing spark exception using its provided structured data: {e}, will now try to parse its string output..." + ) + + logging.info( + f"Trying to parse spark exception {tool_exception} using its raw string output." 
+ ) + try: + raw_error_msg = str(tool_exception) + return _parse_PySpark_exception_dumped_as_string(raw_error_msg) + except Exception as e: + # Last attempt: if that fails, just use the raw error + logging.info( + f"Error parsing spark exception using its raw string formatting: {e}, will just return the raw error message." + ) + + logging.info(f"returning the raw error message: {str(tool_exception)}.") + return _parse_generic_tool_exception(tool_exception) + + +# TODO: this might be over fit to python code execution tool, need to test it more +@mlflow.trace(span_type="PARSER") +def _parse_ParseException_from_tool_execution( + tool_exception: ParseException, +) -> Dict[str, str]: + try: + error_msg = tool_exception.getMessage() + # Extract the main error message (remove SQLSTATE and position info) + error = error_msg.split("SQLSTATE:")[0].strip() + if "[PARSE_SYNTAX_ERROR]" in error: + error = error.split("[PARSE_SYNTAX_ERROR]")[1].strip() + + # Pattern to match "line X, pos Y" + pattern = r"line (\d+), pos (\d+)" + match = re.search(pattern, error_msg) + + if match: + line_num = match.group(1) + pos_num = match.group(2) + line_info = f"(line {line_num}, pos {pos_num})" + error = error + " " + line_info + + # Extract the SQL section with the error pointer + sql_section = ( + error_msg.split("== SQL ==")[1].split("JVM stacktrace:")[0].strip() + if "== SQL ==" in error_msg + else "" + ) + + # Remove the SELECT statement from the error message + select_pattern = r"SELECT\s+`[^`]+`\.`[^`]+`\.`[^`]+`\('" + # error_without_sql_parts = sql_section.replace(select_pattern, "").strip() + error_without_sql_parts = re.sub(select_pattern, "", sql_section).strip() + + return {STACK_TRACE_KEY: error_without_sql_parts, ERROR_KEY: error} + except Exception as e: + logging.info(f"Error parsing ParseException: {e}") + return { + STACK_TRACE_KEY: None, + ERROR_KEY: str(tool_exception), + } diff --git a/langgraph_agent_app_sample_code/cookbook/tools/vector_search.py b/langgraph_agent_app_sample_code/cookbook/tools/vector_search.py new file mode 100644 index 0000000..d2c82c6 --- /dev/null +++ b/langgraph_agent_app_sample_code/cookbook/tools/vector_search.py @@ -0,0 +1,455 @@ +import mlflow +from mlflow.entities import Document +from mlflow.models.resources import ( + DatabricksVectorSearchIndex, + DatabricksServingEndpoint, + DatabricksResource, +) + +import json +from typing import Literal, Any, Dict, List, Union +from pydantic import BaseModel, model_validator +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.vectorsearch import VectorIndexType +from databricks.sdk.errors import ResourceDoesNotExist +from cookbook.tools import Tool +from dataclasses import asdict + +FilterDict = Dict[str, Union[str, int, float, List[Union[str, int, float]]]] + +# Change this to True to use the source table's metadata for the filterable columns. +# This causes deployment to fail since the deployed model doesn't have access to the source table. +USE_SOURCE_TABLE_FOR_FILTERS_METADATA = False + + +class VectorSearchSchema(BaseModel): + """Configuration for the schema used in the retriever's response. + + This class defines the schema configuration for how the vector search retriever + structures and returns results. + + Args: + primary_key: The column name in the retriever's response referred to the unique key. + If using Databricks vector search with delta sync, this should be the column + of the delta table that acts as the primary key. 
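+            If omitted, it is resolved automatically from the index's metadata at validation time.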
+ chunk_text: The column name in the retriever's response that contains the + returned chunk. + document_uri: The column name in the retriever's response that contains the + document URI - used to display chunks from the same document together + in the Agent Evaluation Review App. + additional_metadata_columns: Additional metadata columns to present to the LLM. + + Returns: + VectorSearchSchema: A configured schema object for the vector search retriever. + """ + + _primary_key: str | None = None + """The column name in the retriever's response that refers to the unique key. + If using Databricks vector search with delta sync, this should be the column + of the delta table that acts as the primary key, and will be set by reading the index's metadata.""" + + chunk_text: str + """The column name in the retriever's response that contains the returned chunk.""" + + document_uri: str + """The column name in the retriever's response that contains the document URI - used + to display chunks from the same document together in the Agent Evaluation Review App.""" + + additional_metadata_columns: List[str] = [] + """Additional metadata columns to present to the LLM.""" + + @property + def all_columns(self) -> List[str]: + cols = [ + self.primary_key, + self.chunk_text, + self.document_uri, + ] + self.additional_metadata_columns + # de-duplicate + return list(set(cols)) + + @property + def primary_key(self) -> str: + """The primary key field, which must be set by VectorSearchRetrieverTool""" + if self._primary_key is None: + raise ValueError("primary_key must be set by VectorSearchRetrieverTool") + return self._primary_key + + +class VectorSearchParameters(BaseModel): + """Configuration for the input schema (parameters) used in the retriever. + + This class defines the configuration parameters for how the vector search retriever + performs searches and returns results. + + Args: + num_results: The number of chunks to return for each query. For example, + setting this to 5 will return the top 5 most relevant search results. + query_type: The type of search to use - either 'ann' for semantic similarity + using embeddings only, or 'hybrid' which combines keyword and semantic + similarity search. + + Returns: + VectorSearchParameters: A configured parameters object for the vector search retriever. + """ + + num_results: int = 5 + """The number of chunks to return for each query.""" + + query_type: Literal["ann", "hybrid"] = "ann" + """The type of search to use - either 'ann' for semantic similarity using embeddings only, + or 'hybrid' which combines keyword and semantic similarity search.""" + + +class VectorSearchRetrieverTool(Tool): + """Configuration for a Databricks Vector Search retriever. + + This class defines the configuration for a Vector Search retriever that can be used + either deterministically in a fixed RAG chain or as a tool. + + Args: + vector_search_index: Unity Catalog location of the Vector Search index. + Example: catalog.schema.vector_index. + vector_search_schema: Schema configuration for the retriever. + doc_similarity_threshold: Threshold (0-1) for the retrieved document's similarity score. Used + to exclude dissimilar results. Increase if retriever returns irrelevant content. + vector_search_parameters: Parameters passed to index.similarity_search(...).
+ See https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#query-a-vector-search-endpoint for details. + retriever_query_parameter_prompt: Description of the query parameter for the retriever. + + Returns: + VectorSearchRetrieverTool: A configured retriever tool object. + """ + + vector_search_index: str + """Unity Catalog location of the Vector Search index. + Example: catalog.schema.vector_index.""" + + filterable_columns: List[str] = [] + """List of columns that can be used as filters by the LLM. Columns will be validated against the source table & metadata about each column loaded from the Unity Catalog to improve the LLM's ability to filter.""" + + vector_search_schema: VectorSearchSchema + """Schema configuration for the retriever.""" + + doc_similarity_threshold: float = 0.0 + """Threshold (0-1) for the retrieved document's similarity score. + Used to exclude dissimilar results. Increase if retriever returns irrelevant content.""" + + vector_search_parameters: VectorSearchParameters = VectorSearchParameters() + """Parameters passed to index.similarity_search(...). + See https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#query-a-vector-search-endpoint for details.""" + + retriever_query_parameter_prompt: str = "query to look up in retriever" + retriever_filter_parameter_prompt: str = ( + "optional filters to apply to the search. An array of objects, each specifying a field name and the filters to apply to that field." + ) + + name: str + description: str + + def __init__(self, **data): + """Set the MLflow retriever schema when the tool is initialized.""" + super().__init__(**data) + mlflow.models.set_retriever_schema( + name=self.vector_search_index, + primary_key=self.vector_search_schema.primary_key, + text_column=self.vector_search_schema.chunk_text, + doc_uri=self.vector_search_schema.document_uri, + ) + + def _validate_columns_exist( + self, columns: List[str], source_table: str, table_columns: set, context: str + ) -> None: + """Helper method to validate that columns exist in the source table. + + Args: + columns: List of columns to validate + source_table: Name of the source table + table_columns: Set of available columns in the table + context: Context string for error message (e.g. "filterable columns", "chunk_text") + """ + for col in columns: + if col not in table_columns: + raise ValueError( + f"Column '{col}' specified in {context} not found in source table {source_table}. " + f"Available columns: {', '.join(sorted(table_columns))}" + ) + + def _get_index_info(self): + w = WorkspaceClient() + return w.vector_search_indexes.get_index(self.vector_search_index) + + def _check_if_index_exists(self): + w = WorkspaceClient() + try: + index_info = w.vector_search_indexes.get_index(self.vector_search_index) + return index_info is not None + except ResourceDoesNotExist: + return False + + @property + def filterable_columns_descriptions_for_llm(self) -> str: + """Returns a formatted description of all filterable columns for use in prompts.""" + if USE_SOURCE_TABLE_FOR_FILTERS_METADATA: + # Present the LLM with the source table's metadata for the filterable columns. + # TODO: be able to get this data directly from the index's metadata + # Get source table info + index_info = self._get_index_info() + if index_info.index_type != VectorIndexType.DELTA_SYNC: + raise ValueError( + f"Unsupported index type: {index_info.index_type}. Only DELTA_SYNC is supported."
+ ) + + w = WorkspaceClient() + source_table = index_info.delta_sync_index_spec.source_table + table_info = w.tables.get(source_table) + + # Create mapping of column name to description and type + column_info = { + col.name: (col.type_text, col.comment if col.comment else None) + for col in table_info.columns + } + # print(column_info) + + # Build descriptions list + descriptions = [] + for col in self.filterable_columns: + type_text, desc = column_info.get(col, (None, None)) + formatted_desc = f"(`{col}`, {type_text}" + ( + f", '{desc}'" if desc else "" + ) + ")" + descriptions.append(formatted_desc) + return ", ".join(descriptions) + + else: + # just use the column names as metadata + return ", ".join(str(col) for col in self.filterable_columns) + + @model_validator(mode="after") + def validate_index_and_columns(self): + """Validates the index exists and all columns after the model is fully initialized""" + + # Check that index exists + if not self._check_if_index_exists(): + raise ValueError( + f"Vector search index {self.vector_search_index} does not exist." + ) + + index_info = self._get_index_info() + + # Set primary key from index if not already set + if not self.vector_search_schema._primary_key: + if index_info.primary_key: + self.vector_search_schema._primary_key = index_info.primary_key + else: + raise ValueError( + f"Could not find primary key in index {self.vector_search_index}" + ) + + # TODO: Validate all configured schema columns exist in the index. Currently, this data is not available in the index metadata. + + return self + + @model_validator(mode="after") + def validate_threshold(self): + if not 0 <= self.doc_similarity_threshold <= 1: + raise ValueError("doc_similarity_threshold must be between 0 and 1") + return self + + def _get_parameters_schema(self) -> dict: + schema = { + "type": "object", + "required": ["query"], + "additionalProperties": False, + "properties": { + "query": { + # "default": None, + "description": self.retriever_query_parameter_prompt, + "type": "string", + }, + }, + } + + if self.filterable_columns: + schema["properties"]["filters"] = { + # "default": None, + "description": self.retriever_filter_parameter_prompt, + "type": "array", + "items": { + "type": "object", + "properties": { + "field": { + "type": "string", + "enum": self.filterable_columns, + "description": "The fields to apply the filter to.
Can use any of the following as filters, where each is (`field_name`, field_type, 'field_description'): " + + self.filterable_columns_descriptions_for_llm + + "For string fields, only use LIKE filter; for numeric fields, either provide a number to achieve == or use <, <=, >, >= filters; for array fields, either provide an array of 1+ values to achieve IN or use NOT to exclude.", + }, + "filter": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + { + "type": "array", + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + ] + }, + }, + { + "type": "object", + "properties": { + "<": {"type": "number"}, + "<=": {"type": "number"}, + ">": {"type": "number"}, + ">=": {"type": "number"}, + "LIKE": {"type": "string"}, + "NOT": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + ] + }, + }, + "additionalProperties": False, + "minProperties": 1, + "maxProperties": 1, + }, + ] + }, + }, + "required": ["field", "filter"], + "additionalProperties": False, + }, + } + + return schema + + @mlflow.trace(span_type="RETRIEVER", name="vector_search_retriever") + def __call__(self, query: str, filters: Dict[Any, Any] = None) -> List[Document]: + """ + Performs vector search to retrieve relevant chunks. + + Args: + query: Search query. + filters: Optional filters to apply to the search. Should follow the LLM-generated filter pattern of a list of field/filter pairs that will be converted to Databricks Vector Search filter format. + + Returns: + List of retrieved Documents. + """ + span = mlflow.get_current_active_span() + if span: # TODO: Hack, when mlflow tracing is disabled, span == None. + span.set_attributes({"vector_search_index": self.vector_search_index}) + + w = WorkspaceClient() + + traced_search = mlflow.trace( + w.vector_search_indexes.query_index, + name="_workspace_client.vector_search_indexes.query_index", + span_type="FUNCTION", + ) + + # Parse filters written by the LLM into Vector Search compatible format + vs_filters = json.dumps(self.parse_filters(filters)) if filters else None + + results = traced_search( + index_name=self.vector_search_index, + query_text=query, + filters_json=vs_filters, + columns=self.vector_search_schema.all_columns, + **self.vector_search_parameters.model_dump(exclude_none=True), + ) + + # We turn the config into a dict and pass it here + return self.convert_vector_search_to_documents( + results.as_dict(), self.doc_similarity_threshold + ) + + @mlflow.trace(span_type="PARSER") + def convert_vector_search_to_documents( + self, vs_results, vector_search_threshold + ) -> List[Document]: + column_names = [] + for column in vs_results["manifest"]["columns"]: + column_names.append(column) + + docs = [] + if vs_results["result"]["row_count"] > 0: + for item in vs_results["result"]["data_array"]: + metadata = {} + score = item[-1] + if score >= vector_search_threshold: + metadata["similarity_score"] = score + for i, field in enumerate(item[0:-1]): + metadata[column_names[i]["name"]] = field + # put contents of the chunk into page_content + page_content = metadata[self.vector_search_schema.chunk_text] + del metadata[self.vector_search_schema.chunk_text] + + # put the primary key into id + id = metadata[self.vector_search_schema.primary_key] + del metadata[self.vector_search_schema.primary_key] + + doc = Document(page_content=page_content, metadata=metadata, id=id) + docs.append(asdict(doc)) + + return docs + + @mlflow.trace(span_type="PARSER") + def parse_filters(self, filters: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Parse input 
filters into Vector Search compatible format. + + Args: + filters: List of input filters in the new format. + + Returns: + Filters in Vector Search compatible format. + """ + vs_filters = {} + for filter_item in filters: + suggested_field = filter_item["field"] + suggested_filter = filter_item["filter"] + + if isinstance(suggested_filter, list): + # vs_filters[key] = {"OR": value} + vs_filters[suggested_field] = suggested_filter + elif isinstance(suggested_filter, dict): + operator, operand = next(iter(suggested_filter.items())) + vs_filters[suggested_field + " " + operator] = operand + # if operator in ["<", "<=", ">", ">="]: + # vs_filters[f"{key} {operator}"] = operand + # elif operator.upper() == "LIKE": + # vs_filters[f"{key} LIKE"] = operand + # elif operator.upper() == "NOT": + # vs_filters[f"{key} !="] = operand + else: + vs_filters[suggested_field] = suggested_filter + return vs_filters + + def get_resource_dependencies(self) -> List[DatabricksResource]: + dependencies = [ + DatabricksVectorSearchIndex(index_name=self.vector_search_index) + ] + + # Get the embedding model endpoint + index_info = self._get_index_info() + if index_info.index_type == VectorIndexType.DELTA_SYNC: + # Only DELTA_SYNC indexes have embedding model endpoints + for ( + embedding_source_col + ) in index_info.delta_sync_index_spec.embedding_source_columns: + endpoint_name = embedding_source_col.embedding_model_endpoint_name + if endpoint_name is not None: + dependencies.append( + DatabricksServingEndpoint(endpoint_name=endpoint_name), + ) + else: + print( + f"Could not identify the embedding model endpoint resource for {self.vector_search_index}. Please manually add the embedding model endpoint to `databricks_resources`." + ) + return dependencies diff --git a/langgraph_agent_app_sample_code/environment.yaml b/langgraph_agent_app_sample_code/environment.yaml new file mode 100644 index 0000000..76883b2 --- /dev/null +++ b/langgraph_agent_app_sample_code/environment.yaml @@ -0,0 +1,4 @@ +client: "1" +dependencies: + - --index-url https://pypi.org/simple + - -r requirements.txt \ No newline at end of file diff --git a/langgraph_agent_app_sample_code/pyproject.toml b/langgraph_agent_app_sample_code/pyproject.toml new file mode 100644 index 0000000..fa57fb0 --- /dev/null +++ b/langgraph_agent_app_sample_code/pyproject.toml @@ -0,0 +1,36 @@ +[tool.poetry] +name = "genai-cookbook" +version = "0.1.0" +description = "" +authors = ["Eric Peter "] +readme = "README.md" +packages = [{include = "cookbook"}] + +[tool.poetry.dependencies] +python = "^3.11" +databricks-connect = "15.1.0" +pydantic = "^2.9.2" +pyyaml = "^6.0.2" +databricks-vectorsearch = "^0.42" +databricks-sdk = {extras = ["openai"], version = "^0.36.0"} +mlflow = "^2.18.0" +databricks-agents = "^0.10.0" +pymupdf4llm = "0.0.5" +pymupdf = "1.24.13" +markdownify = "0.12.1" +transformers = "4.41.1" +torch = "2.3.0" +tiktoken = "0.7.0" +langchain-text-splitters = "0.2.0" +ipykernel = "^6.29.5" +hatchling = "^1.25.0" +pypandoc-binary = "1.13" +tabulate = "^0.9.0" +ipywidgets = "^8.1.5" +unitycatalog-ai = {git = "https://github.com/unitycatalog/unitycatalog.git", subdirectory = "ai/core"} +unitycatalog-openai = {git = "https://github.com/unitycatalog/unitycatalog.git", subdirectory = "ai/integrations/openai"} +pytest = "^8.3.3" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/langgraph_agent_app_sample_code/requirements.txt b/langgraph_agent_app_sample_code/requirements.txt new file mode 100644 index 
0000000..ee27b9d --- /dev/null +++ b/langgraph_agent_app_sample_code/requirements.txt @@ -0,0 +1,19 @@ +pydantic>=2.9.2 +databricks-agents +mlflow>=2.18.0 +databricks-sdk[openai] +databricks-vectorsearch +pyyaml +langgraph +langchain_core +databricks-langchain +langchain-community +git+https://github.com/unitycatalog/unitycatalog.git#subdirectory=ai/integrations/langchain +git+https://github.com/unitycatalog/unitycatalog.git#subdirectory=ai/core +git+https://github.com/unitycatalog/unitycatalog.git#subdirectory=ai/integrations/openai +tabulate +pandas +pyspark +databricks-connect==15.1.0 +python-box +pytest diff --git a/langgraph_agent_app_sample_code/requirements_datapipeline.txt b/langgraph_agent_app_sample_code/requirements_datapipeline.txt new file mode 100644 index 0000000..c7d2e90 --- /dev/null +++ b/langgraph_agent_app_sample_code/requirements_datapipeline.txt @@ -0,0 +1,9 @@ +pymupdf4llm==0.0.5 +pymupdf==1.24.13 +markdownify==0.12.1 +transformers==4.41.1 +torch==2.3.0 +tiktoken==0.7.0 +langchain-text-splitters==0.2.0 +pypandoc_binary==1.13 +pyyaml \ No newline at end of file diff --git a/langgraph_agent_app_sample_code/tools/README.md b/langgraph_agent_app_sample_code/tools/README.md new file mode 100644 index 0000000..e7acbf9 --- /dev/null +++ b/langgraph_agent_app_sample_code/tools/README.md @@ -0,0 +1 @@ +Store user-created tools in this directory. \ No newline at end of file diff --git a/langgraph_agent_app_sample_code/tools/__init__.py b/langgraph_agent_app_sample_code/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/langgraph_agent_app_sample_code/tools/code_exec.py b/langgraph_agent_app_sample_code/tools/code_exec.py new file mode 100644 index 0000000..633a34b --- /dev/null +++ b/langgraph_agent_app_sample_code/tools/code_exec.py @@ -0,0 +1,20 @@ +def python_exec(code: str) -> str: + """ + Executes Python code in the sandboxed environment and returns its stdout. The runtime is stateless and you cannot read the output of previous tool executions, i.e. no variables such as "rows" or "observation" are defined. Calling another tool inside the Python code is NOT allowed. + Use only standard python libraries and these python libraries: bleach, chardet, charset-normalizer, defusedxml, googleapis-common-protos, grpcio, grpcio-status, jmespath, joblib, numpy, packaging, pandas, patsy, protobuf, pyarrow, pyparsing, python-dateutil, pytz, scikit-learn, scipy, setuptools, six, threadpoolctl, webencodings, user-agents, cryptography. + + Args: + code (str): Python code to execute. Remember to print the final result to stdout. + + Returns: + str: The output of the executed code. + """ + import sys + from io import StringIO + + sys_stdout = sys.stdout + redirected_output = StringIO() + sys.stdout = redirected_output + try: + exec(code) + finally: + # Always restore stdout, even if the executed code raises. + sys.stdout = sys_stdout + return redirected_output.getvalue() diff --git a/langgraph_agent_app_sample_code/tools/sample_tool.py b/langgraph_agent_app_sample_code/tools/sample_tool.py new file mode 100644 index 0000000..eef313c --- /dev/null +++ b/langgraph_agent_app_sample_code/tools/sample_tool.py @@ -0,0 +1,46 @@ + +def sku_sample_translator(old_sku: str) -> str: + """ + Translates a pre-2024 SKU formatted as "OLD-XXX-YYYY" to the new SKU format "NEW-YYYY-XXX". + + Args: + old_sku (str): The old SKU in the format "OLD-XXX-YYYY". + + Returns: + str: The new SKU in the format "NEW-YYYY-XXX". + + Raises: + ValueError: If the SKU format is invalid, providing specific error details.
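+ + Example: + >>> sku_sample_translator("OLD-ABC-1234") + 'NEW-1234-ABC'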
+ """ + import re + + if not isinstance(old_sku, str): + raise ValueError("SKU must be a string") + + # Normalize input by removing extra whitespace and converting to uppercase + old_sku = old_sku.strip().upper() + + # Define the regex pattern for the old SKU format + pattern = r"^OLD-([A-Z]{3})-(\d{4})$" + + # Match the old SKU against the pattern + match = re.match(pattern, old_sku) + if not match: + if not old_sku.startswith("OLD-"): + raise ValueError("SKU must start with 'OLD-'") + if not re.match(r"^OLD-[A-Z]{3}-\d{4}$", old_sku): + raise ValueError( + "SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit" + ) + raise ValueError("Invalid SKU format") + + # Extract the letter code and numeric part + letter_code, numeric_part = match.groups() + + # Additional validation for numeric part + if not (1 <= int(numeric_part) <= 9999): + raise ValueError("Numeric part must be between 0001 and 9999") + + # Construct the new SKU + new_sku = f"NEW-{numeric_part}-{letter_code}" + return new_sku diff --git a/langgraph_agent_app_sample_code/tools/test_code_exec.py b/langgraph_agent_app_sample_code/tools/test_code_exec.py new file mode 100644 index 0000000..a4c5418 --- /dev/null +++ b/langgraph_agent_app_sample_code/tools/test_code_exec.py @@ -0,0 +1,89 @@ + +import pytest +from .code_exec import python_exec + + +def test_basic_arithmetic(): + code = """result = 2 + 2\nprint(result)""" + assert python_exec(code).strip() == "4" + + +def test_multiple_lines(): + code = "x = 5\n" "y = 3\n" "result = x * y\n" "print(result)" + assert python_exec(code).strip() == "15" + + +def test_multiple_prints(): + code = """print('first')\nprint('second')\nprint('third')\n""" + expected = "first\nsecond\nthird\n" + assert python_exec(code) == expected + + +def test_using_pandas(): + code = ( + "import pandas as pd\n" + "data = {'col1': [1, 2], 'col2': [3, 4]}\n" + "df = pd.DataFrame(data)\n" + "print(df.shape)" + ) + assert python_exec(code).strip() == "(2, 2)" + + +def test_using_numpy(): + code = "import numpy as np\n" "arr = np.array([1, 2, 3])\n" "print(arr.mean())" + assert python_exec(code).strip() == "2.0" + + +def test_syntax_error(): + code = "if True\n" " print('invalid syntax')" + with pytest.raises(SyntaxError): + python_exec(code) + + +def test_runtime_error(): + code = "x = 1 / 0\n" "print(x)" + with pytest.raises(ZeroDivisionError): + python_exec(code) + + +def test_undefined_variable(): + code = "print(undefined_variable)" + with pytest.raises(NameError): + python_exec(code) + + +def test_multiline_string_manipulation(): + code = "text = '''\n" "Hello\n" "World\n" "'''\n" "print(text.strip())" + expected = "Hello\nWorld" + assert python_exec(code).strip() == expected + +# Will not fail locally, but will fail in UC. 
+# def test_unauthorized_flask(): +# code = "from flask import Flask\n" "app = Flask(__name__)\n" "print(app)" +# with pytest.raises(ImportError): +# python_exec(code) + + +def test_no_print_statement(): + code = "x = 42\n" "y = x * 2" + assert python_exec(code) == "" + + +def test_calculation_without_print(): + code = "result = sum([1, 2, 3, 4, 5])\n" "squared = [x**2 for x in range(5)]" + assert python_exec(code) == "" + + +def test_function_definition_without_call(): + code = "def add(a, b):\n" " return a + b\n" "result = add(3, 4)" + assert python_exec(code) == "" + + +def test_class_definition_without_instantiation(): + code = ( + "class Calculator:\n" + " def add(self, a, b):\n" + " return a + b\n" + "calc = Calculator()" + ) + assert python_exec(code) == "" diff --git a/langgraph_agent_app_sample_code/tools/test_code_exec_as_uc_tool.py b/langgraph_agent_app_sample_code/tools/test_code_exec_as_uc_tool.py new file mode 100644 index 0000000..cb3efb0 --- /dev/null +++ b/langgraph_agent_app_sample_code/tools/test_code_exec_as_uc_tool.py @@ -0,0 +1,102 @@ + +import pytest +from cookbook.tools.uc_tool import UCTool + +CATALOG = "shared" # Change me! +SCHEMA = "cookbook_langgraph_udhay" # Change me if you want + + +@pytest.fixture +def python_exec(): + """Fixture to provide the python_exec function from UCTool.""" + python_exec_tool = UCTool(uc_function_name=f"{CATALOG}.{SCHEMA}.python_exec") + return python_exec_tool + + +def test_basic_arithmetic(python_exec): + code = """result = 2 + 2\nprint(result)""" + assert python_exec(code=code)["value"].strip() == "4" + + +def test_multiple_lines(python_exec): + code = "x = 5\n" "y = 3\n" "result = x * y\n" "print(result)" + assert python_exec(code=code)["value"].strip() == "15" + + +def test_multiple_prints(python_exec): + code = """print('first')\nprint('second')\nprint('third')\n""" + expected = "first\nsecond\nthird\n" + assert python_exec(code=code)["value"] == expected + + +def test_using_pandas(python_exec): + code = ( + "import pandas as pd\n" + "data = {'col1': [1, 2], 'col2': [3, 4]}\n" + "df = pd.DataFrame(data)\n" + "print(df.shape)" + ) + assert python_exec(code=code)["value"].strip() == "(2, 2)" + + +def test_using_numpy(python_exec): + code = "import numpy as np\n" "arr = np.array([1, 2, 3])\n" "print(arr.mean())" + assert python_exec(code=code)["value"].strip() == "2.0" + + +def test_syntax_error(python_exec): + code = "if True\n" " print('invalid syntax')" + result = python_exec(code=code) + assert "Syntax error at or near 'invalid'." 
in result["error"]["error_message"] + + +def test_runtime_error(python_exec): + code = "x = 1 / 0\n" "print(x)" + result = python_exec(code=code) + assert "ZeroDivisionError" in result["error"]["error_message"] + + +def test_undefined_variable(python_exec): + code = "print(undefined_variable)" + result = python_exec(code=code) + assert "NameError" in result["error"]["error_message"] + + +def test_multiline_string_manipulation(python_exec): + code = "text = '''\n" "Hello\n" "World\n" "'''\n" "print(text.strip())" + expected = "Hello\nWorld" + assert python_exec(code=code)["value"].strip() == expected + + +def test_unauthorized_flask(python_exec): + code = "from flask import Flask\n" "app = Flask(__name__)\n" "print(app)" + result = python_exec(code=code) + assert ( + "ModuleNotFoundError: No module named 'flask'" + in result["error"]["error_message"] + ) + + +def test_no_print_statement(python_exec): + code = "x = 42\n" "y = x * 2" + assert python_exec(code=code)["value"] == "" + + +def test_calculation_without_print(python_exec): + code = "result = sum([1, 2, 3, 4, 5])\n" "squared = [x**2 for x in range(5)]" + assert python_exec(code=code)["value"] == "" + + +def test_function_definition_without_call(python_exec): + code = "def add(a, b):\n" " return a + b\n" "result = add(3, 4)" + assert python_exec(code=code)["value"] == "" + + +def test_class_definition_without_instantiation(python_exec): + code = ( + "class Calculator:\n" + " def add(self, a, b):\n" + " return a + b\n" + "calc = Calculator()" + ) + assert python_exec(code=code)["value"] == "" diff --git a/langgraph_agent_app_sample_code/tools/test_sample_tool.py b/langgraph_agent_app_sample_code/tools/test_sample_tool.py new file mode 100644 index 0000000..f818d70 --- /dev/null +++ b/langgraph_agent_app_sample_code/tools/test_sample_tool.py @@ -0,0 +1,52 @@ +import pytest +from tools.sample_tool import sku_sample_translator + + + +def test_valid_sku_translation(): + """Test successful SKU translation with valid input.""" + assert sku_sample_translator("OLD-ABC-1234") == "NEW-1234-ABC" + assert sku_sample_translator("OLD-XYZ-0001") == "NEW-0001-XYZ" + assert sku_sample_translator("old-def-5678") == "NEW-5678-DEF" # Test case insensitivity + + +def test_whitespace_handling(): + """Test that the function handles extra whitespace correctly.""" + assert sku_sample_translator(" OLD-ABC-1234 ") == "NEW-1234-ABC" + assert sku_sample_translator("\tOLD-ABC-1234\n") == "NEW-1234-ABC" + + +def test_invalid_input_type(): + """Test that non-string inputs raise ValueError.""" + with pytest.raises(ValueError, match="SKU must be a string"): + sku_sample_translator(123) + with pytest.raises(ValueError, match="SKU must be a string"): + sku_sample_translator(None) + + +def test_invalid_prefix(): + """Test that SKUs not starting with 'OLD-' raise ValueError.""" + with pytest.raises(ValueError, match="SKU must start with 'OLD-'"): + sku_sample_translator("NEW-ABC-1234") + with pytest.raises(ValueError, match="SKU must start with 'OLD-'"): + sku_sample_translator("XXX-ABC-1234") + + +def test_invalid_format(): + """Test various invalid SKU formats.""" + invalid_skus = [ + "OLD-AB-1234", # Too few letters + "OLD-ABCD-1234", # Too many letters + "OLD-123-1234", # Numbers instead of letters + "OLD-ABC-123", # Too few digits + "OLD-ABC-12345", # Too many digits + "OLD-ABC-XXXX", # Letters instead of numbers + "OLD-A1C-1234", # Mixed letters and numbers in middle + ] + + for sku in invalid_skus: + with pytest.raises( + ValueError, + match="SKU format must be 
diff --git a/langgraph_agent_app_sample_code/tools/test_sample_tool.py b/langgraph_agent_app_sample_code/tools/test_sample_tool.py
new file mode 100644
index 0000000..f818d70
--- /dev/null
+++ b/langgraph_agent_app_sample_code/tools/test_sample_tool.py
@@ -0,0 +1,52 @@
+import pytest
+from tools.sample_tool import sku_sample_translator
+
+
+def test_valid_sku_translation():
+    """Test successful SKU translation with valid input."""
+    assert sku_sample_translator("OLD-ABC-1234") == "NEW-1234-ABC"
+    assert sku_sample_translator("OLD-XYZ-0001") == "NEW-0001-XYZ"
+    assert sku_sample_translator("old-def-5678") == "NEW-5678-DEF"  # Test case insensitivity
+
+
+def test_whitespace_handling():
+    """Test that the function handles extra whitespace correctly."""
+    assert sku_sample_translator("  OLD-ABC-1234  ") == "NEW-1234-ABC"
+    assert sku_sample_translator("\tOLD-ABC-1234\n") == "NEW-1234-ABC"
+
+
+def test_invalid_input_type():
+    """Test that non-string inputs raise ValueError."""
+    with pytest.raises(ValueError, match="SKU must be a string"):
+        sku_sample_translator(123)
+    with pytest.raises(ValueError, match="SKU must be a string"):
+        sku_sample_translator(None)
+
+
+def test_invalid_prefix():
+    """Test that SKUs not starting with 'OLD-' raise ValueError."""
+    with pytest.raises(ValueError, match="SKU must start with 'OLD-'"):
+        sku_sample_translator("NEW-ABC-1234")
+    with pytest.raises(ValueError, match="SKU must start with 'OLD-'"):
+        sku_sample_translator("XXX-ABC-1234")
+
+
+def test_invalid_format():
+    """Test various invalid SKU formats."""
+    invalid_skus = [
+        "OLD-AB-1234",  # Too few letters
+        "OLD-ABCD-1234",  # Too many letters
+        "OLD-123-1234",  # Numbers instead of letters
+        "OLD-ABC-123",  # Too few digits
+        "OLD-ABC-12345",  # Too many digits
+        "OLD-ABC-XXXX",  # Letters instead of numbers
+        "OLD-A1C-1234",  # Mixed letters and numbers in middle
+    ]
+
+    for sku in invalid_skus:
+        with pytest.raises(
+            ValueError,
+            match="SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit",
+        ):
+            sku_sample_translator(sku)
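The tests above fully specify sku_sample_translator's contract. For readers without tools/sample_tool.py open, an implementation consistent with these tests looks roughly like the following sketch (illustrative, not the repo's verbatim code):

import re


def sku_sample_translator(old_sku: str) -> str:
    """Translate an 'OLD-XXX-YYYY' SKU into 'NEW-YYYY-XXX' form."""
    if not isinstance(old_sku, str):
        raise ValueError("SKU must be a string")

    sku = old_sku.strip().upper()  # tolerate surrounding whitespace and lowercase input
    if not sku.startswith("OLD-"):
        raise ValueError("SKU must start with 'OLD-'")

    match = re.match(r"^OLD-([A-Z]{3})-(\d{4})$", sku)
    if match is None:
        raise ValueError(
            "SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit"
        )

    letters, digits = match.groups()
    return f"NEW-{digits}-{letters}"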
diff --git a/langgraph_agent_app_sample_code/tools/test_sample_tool_uc.py b/langgraph_agent_app_sample_code/tools/test_sample_tool_uc.py
new file mode 100644
index 0000000..f1c8352
--- /dev/null
+++ b/langgraph_agent_app_sample_code/tools/test_sample_tool_uc.py
@@ -0,0 +1,75 @@
+import pytest
+from cookbook.tools.uc_tool import UCTool
+
+CATALOG = "shared"  # Change me!
+SCHEMA = "cookbook_langgraph_udhay"  # Change me if you want
+
+
+# Load the function via UCTool rather than importing it locally.
+@pytest.fixture
+def uc_tool():
+    """Fixture that exposes the UC function as a locally callable tool."""
+    UC_FUNCTION_NAME = f"{CATALOG}.{SCHEMA}.sku_sample_translator"
+    loaded_tool = UCTool(uc_function_name=UC_FUNCTION_NAME)
+    return loaded_tool
+
+
+# Note: The result is post-processed into the `value` key, so we must check the returned value there.
+def test_valid_sku_translation(uc_tool):
+    """Test successful SKU translation with valid input."""
+    assert uc_tool(old_sku="OLD-ABC-1234")["value"] == "NEW-1234-ABC"
+    assert uc_tool(old_sku="OLD-XYZ-0001")["value"] == "NEW-0001-XYZ"
+    assert (
+        uc_tool(old_sku="old-def-5678")["value"] == "NEW-5678-DEF"
+    )  # Test case insensitivity
+
+
+# Note: The result is post-processed into the `value` key, so we must check the returned value there.
+def test_whitespace_handling(uc_tool):
+    """Test that the function handles extra whitespace correctly."""
+    assert uc_tool(old_sku="  OLD-ABC-1234  ")["value"] == "NEW-1234-ABC"
+    assert uc_tool(old_sku="\tOLD-ABC-1234\n")["value"] == "NEW-1234-ABC"
+
+
+# Note: Input validation happens BEFORE the function is called by Spark, so we never see the
+# function's own exceptions here; instead, Spark returns invalid-parameter errors.
+def test_invalid_input_type(uc_tool):
+    """Test that non-string inputs are rejected by Spark's parameter validation."""
+    assert (
+        uc_tool(old_sku=123)["error"]["error_message"]
+        == """Invalid parameters provided: {'old_sku': "Parameter old_sku should be of type STRING (corresponding python type <class 'str'>), but got <class 'int'>"}."""
+    )
+    assert (
+        uc_tool(old_sku=None)["error"]["error_message"]
+        == """Invalid parameters provided: {'old_sku': "Parameter old_sku should be of type STRING (corresponding python type <class 'str'>), but got <class 'NoneType'>"}."""
+    )
+
+
+# Note: Errors are post-processed into the `error_message` key inside the top-level `error` key,
+# so we must check for exceptions there.
+def test_invalid_prefix(uc_tool):
+    """Test that SKUs not starting with 'OLD-' raise ValueError."""
+    assert (
+        uc_tool(old_sku="NEW-ABC-1234")["error"]["error_message"]
+        == "ValueError: SKU must start with 'OLD-'"
+    )
+    assert (
+        uc_tool(old_sku="XXX-ABC-1234")["error"]["error_message"]
+        == "ValueError: SKU must start with 'OLD-'"
+    )
+
+
+# Note: Errors are post-processed into the `error_message` key inside the top-level `error` key,
+# so we must check for exceptions there.
+def test_invalid_format(uc_tool):
+    """Test various invalid SKU formats."""
+    invalid_skus = [
+        "OLD-AB-1234",  # Too few letters
+        "OLD-ABCD-1234",  # Too many letters
+        "OLD-123-1234",  # Numbers instead of letters
+        "OLD-ABC-123",  # Too few digits
+        "OLD-ABC-12345",  # Too many digits
+        "OLD-ABC-XXXX",  # Letters instead of numbers
+        "OLD-A1C-1234",  # Mixed letters and numbers in middle
+    ]
+
+    expected_error = "ValueError: SKU format must be 'OLD-XXX-YYYY' where X is a letter and Y is a digit"
+    for sku in invalid_skus:
+        assert uc_tool(old_sku=sku)["error"]["error_message"] == expected_error
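Both UC-backed test files assume python_exec and sku_sample_translator are already deployed to the configured catalog and schema. If they are not, one way to register a local function is the unitycatalog-ai Databricks client; the sketch below assumes that package is installed and workspace auth is configured, and the catalog/schema names are the placeholders from the tests:

from unitycatalog.ai.core.databricks import DatabricksFunctionClient

from tools.sample_tool import sku_sample_translator

client = DatabricksFunctionClient()

# Registers the local Python function as CATALOG.SCHEMA.sku_sample_translator in Unity Catalog.
client.create_python_function(
    func=sku_sample_translator,
    catalog="shared",  # must match CATALOG in the test files
    schema="cookbook_langgraph_udhay",  # must match SCHEMA in the test files
    replace=True,
)

After deployment, running pytest against the tools directory should exercise both the local and the UC-backed variants.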