Squashed commit of the following:
commit 3168a18db21bbf34b908931ca2036478782126a3
Merge: 5010284 c93a9c2
Author: anubbhav-malhotra <[email protected]>
Date:   Thu Jan 16 15:34:35 2025 -0600

    Merge branch 'main' of https://github.com/GoogleCloudPlatform/vertex-ai-samples into GoogleCloudPlatform-main

commit c93a9c2
Author: Vertex MG Team <[email protected]>
Date:   Wed Jan 15 12:45:28 2025 -0800

    Create Cloud Translation and evaluation demo notebook

    PiperOrigin-RevId: 715907272

commit 8ece5ef
Author: Vertex MG Team <[email protected]>
Date:   Wed Jan 15 03:08:03 2025 -0800

    Segment Anything Model (SAM) Serving on Vertex AI Notebook

    PiperOrigin-RevId: 715723816

commit dee72af
Author: Vertex MG Team <[email protected]>
Date:   Tue Jan 14 08:37:51 2025 -0800

    Enable dedicated endpoint for Prompt Guard deployment

    PiperOrigin-RevId: 715395374

commit ea45e3d
Author: Vertex MG Team <[email protected]>
Date:   Mon Jan 13 21:41:01 2025 -0800

    Add usage tracking labels to all the finetuning notebooks

    PiperOrigin-RevId: 715229459

commit 9ed8896
Author: Vertex MG Team <[email protected]>
Date:   Mon Jan 13 20:29:57 2025 -0800

    vLLM supports GPU HBM + host memory prefix KV caching

    PiperOrigin-RevId: 715213919

commit 2709fcd
Author: Vertex MG Team <[email protected]>
Date:   Mon Jan 13 16:54:10 2025 -0800

    Fix the error when the result is a list; make it work for both dictionaries and lists

    PiperOrigin-RevId: 715156151

commit a7c1b9a
Author: Vertex MG Team <[email protected]>
Date:   Mon Jan 13 14:01:08 2025 -0800

    Add a publisherdb API call response check in case the call fails

    PiperOrigin-RevId: 715099165

commit 952af22
Author: Vertex MG Team <[email protected]>
Date:   Mon Jan 13 11:16:26 2025 -0800

    Update the license headers of the notebooks to 2025

    PiperOrigin-RevId: 715041456

commit d87b6b7
Author: Bhaskar Goyal <[email protected]>
Date:   Mon Jan 13 09:05:46 2025 -0800

    feat: Add Codestral (25.01) model to mistral docs. (GoogleCloudPlatform#3779)

commit 25b3364
Author: Dustin Luong <[email protected]>
Date:   Thu Jan 9 09:12:56 2025 -0800

    No public description

    PiperOrigin-RevId: 713694601

commit 9f80c3c
Author: Vertex MG Team <[email protected]>
Date:   Wed Jan 8 19:46:48 2025 -0800

    Hex-LLM supports prefix caching as a GA feature

    PiperOrigin-RevId: 713503181

commit 9f6ad84
Author: Mend Renovate <[email protected]>
Date:   Wed Jan 8 21:31:26 2025 +0100

    Update dependency pyupgrade to v3.19.1 (GoogleCloudPlatform#3757)

commit 64e9a4a
Author: Aiden010200 <[email protected]>
Date:   Thu Jan 9 04:30:52 2025 +0800

    Upload a ResNet predictor example (GoogleCloudPlatform#3765)

    This example uses the aiplatform and torch libraries to provide a ResNet predictor.

commit 883e1e5
Author: Dustin Luong <[email protected]>
Date:   Tue Jan 7 14:19:49 2025 -0800

    Set system_labels in notebooks

    PiperOrigin-RevId: 713040032

commit 847a49f
Author: Vertex MG Team <[email protected]>
Date:   Fri Jan 3 06:08:58 2025 -0800

    Update docker images to avoid 'tags' KeyError while loading HF dataset

    PiperOrigin-RevId: 711729872

commit 0b9a450
Author: Vertex MG Team <[email protected]>
Date:   Fri Jan 3 00:24:52 2025 -0800

    MediaPipe Text Classification notebook

    PiperOrigin-RevId: 711656574

commit ed72474
Author: Vertex MG Team <[email protected]>
Date:   Thu Jan 2 00:52:37 2025 -0800

    PyTorch Image Models (timm) notebook

    PiperOrigin-RevId: 711343028

commit 6b20a30
Author: Vertex MG Team <[email protected]>
Date:   Tue Dec 31 04:20:10 2024 -0800

    TFVision Image segmentation notebook

    PiperOrigin-RevId: 710942626

commit f50f5f7
Author: Vertex MG Team <[email protected]>
Date:   Mon Dec 30 20:29:52 2024 -0800

    MediaPipe face stylizer notebook

    PiperOrigin-RevId: 710861510

commit 9621877
Author: alicechang0909 <[email protected]>
Date:   Mon Dec 30 10:59:13 2024 -0800

    Created using Colab (GoogleCloudPlatform#3763)

    * Created using Colab

    * Created using Colab

    * Add Featurestore Monitoring functionalities - Fix lint error in import

    * Fix lint with commands

    * Address comments

    * Update restart section to fix lint error.

    * fix: fix lint errors

    * Fix: Try to fix lint errors

    * Fix: try fix lint with python commands

    * fix: remove self link

    * fix: try submit from workbench

commit 425a5c9
Author: Vertex MG Team <[email protected]>
Date:   Mon Dec 30 06:09:08 2024 -0800

    Fix PIL issue `FreeTypeFont object has no attribute getsize`

    PiperOrigin-RevId: 710698919

commit 0b8b4b8
Author: Vertex MG Team <[email protected]>
Date:   Thu Dec 26 22:46:02 2024 -0800

    MoViNet action recognition notebook

    PiperOrigin-RevId: 709969002

commit f641e5d
Author: Vertex MG Team <[email protected]>
Date:   Wed Dec 18 16:06:45 2024 -0800

    Enable dedicated endpoint for PyTorch Llama 3 deployment

    PiperOrigin-RevId: 707697043

commit d89b29f
Author: Vertex MG Team <[email protected]>
Date:   Tue Dec 17 22:37:20 2024 -0800

    Add usage labels to finetuning notebook

    PiperOrigin-RevId: 707402297

commit df49e18
Author: Vertex MG Team <[email protected]>
Date:   Tue Dec 17 18:19:12 2024 -0800

    Hex-LLM supports disaggregated serving as an experimental feature

    PiperOrigin-RevId: 707332453

commit dae9e79
Author: Changyu Zhu <[email protected]>
Date:   Tue Dec 17 18:01:57 2024 -0800

    Fix missing import in Llama 3 finetuning notebook

    PiperOrigin-RevId: 707326361

commit 16237de
Author: Changyu Zhu <[email protected]>
Date:   Tue Dec 17 15:32:19 2024 -0800

    Add fast deployment section to Llama 3.2 deployment notebook

    PiperOrigin-RevId: 707274649

commit f5e0394
Author: Vertex MG Team <[email protected]>
Date:   Mon Dec 16 17:27:41 2024 -0800

    Add H100 80 GB config for Llama 3

    PiperOrigin-RevId: 706888765

commit 775aa37
Author: Vertex MG Team <[email protected]>
Date:   Mon Dec 16 10:12:26 2024 -0800

    Update the yolov8 model to use the model and endpoint dictionaries.

    PiperOrigin-RevId: 706750375

commit b43418b
Author: Vertex MG Team <[email protected]>
Date:   Mon Dec 16 05:26:05 2024 -0800

    MediaPipe object detection notebook bug fix and reformatting

    PiperOrigin-RevId: 706673143

commit 20138d9
Author: Changyu Zhu <[email protected]>
Date:   Fri Dec 13 10:43:42 2024 -0800

    Add fast deployment section to Llama 3.1 deployment notebook

    PiperOrigin-RevId: 705931336

commit fadcb1e
Author: sageof6path <[email protected]>
Date:   Fri Dec 13 18:38:34 2024 +0530

    Peft docker fix (GoogleCloudPlatform#3751)

    * Updated util files to fix peft docker

    * fix imports

    * fix imports in fileutils.py

commit f0ec2ac
Author: Vertex MG Team <[email protected]>
Date:   Thu Dec 12 15:52:49 2024 -0800

    Update Hex-LLM container URI.

    PiperOrigin-RevId: 705656670
anubbhav-malhotra committed Jan 16, 2025
1 parent 5010284 commit eda8a08
Showing 120 changed files with 4,343 additions and 2,257 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linter/requirements.txt
@@ -3,7 +3,7 @@
ipython
jupyter
nbconvert
black==24.10.0
pyupgrade==3.19.0
pyupgrade==3.19.1
isort==5.13.2
flake8==7.1.1
nbqa==1.9.1
34 changes: 34 additions & 0 deletions community-content/vertex_cpr_samples/torch/predictor_resnet.py
@@ -0,0 +1,34 @@
import os
import torch

from google.cloud.aiplatform.utils import prediction_utils
from google.cloud.aiplatform.prediction.predictor import Predictor
from torchvision.models import detection, resnet50, ResNet50_Weights
from typing import Dict


class ResNetPredictor(Predictor):

    def load(self, artifacts_uri: str) -> None:
        """Downloads artifacts and loads either a custom checkpoint or default weights."""
        prediction_utils.download_model_artifacts(artifacts_uri)
        if os.path.exists("model.pth.tar"):
            # A checkpoint was provided: load it into a Faster R-CNN model
            # with a ResNet-50 backbone.
            self.model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
            state_dict = torch.load("model.pth.tar")
            self.model.load_state_dict(state_dict["state_dict"])
        else:
            # No checkpoint found: fall back to pretrained ResNet-50 weights.
            weights = ResNet50_Weights.DEFAULT
            self.model = resnet50(weights=weights)
        self.model.eval()

    def preprocess(self, prediction_input: dict) -> torch.Tensor:
        instances = prediction_input["instances"]
        return torch.Tensor(instances)

    @torch.inference_mode()
    def predict(self, instances: torch.Tensor):
        # The original referenced self._model, which is never assigned;
        # load() stores the model as self.model.
        return self.model(instances)

    def postprocess(self, prediction_results) -> Dict:
        return {"predictions": prediction_results}
notebooks/community/model_garden/model_garden_codegemma_deployment_on_vertex.ipynb
@@ -9,7 +9,7 @@
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -305,6 +305,9 @@
" is_for_training=False,\n",
")\n",
"\n",
"# @markdown Set enable_prefix_cache_hbm to False if you don't want to use [prefix caching](https://cloud.google.com/vertex-ai/generative-ai/docs/open-models/use-hex-llm#prefix-caching).\n",
"enable_prefix_cache_hbm = True # @param {type:\"boolean\"}\n",
"\n",
"# Server parameters.\n",
"hbm_utilization_factor = 0.6 # A larger value improves throughput but gives higher risk of TPU out-of-memory errors with long prompts.\n",
"max_running_seqs = 256\n",
@@ -323,9 +326,11 @@
" tensor_parallel_size: int = 1,\n",
" machine_type: str = \"ct5lp-hightpu-1t\",\n",
" tpu_topology: str = \"1x1\",\n",
" disagg_topology: str = None,\n",
" hbm_utilization_factor: float = 0.6,\n",
" max_running_seqs: int = 256,\n",
" max_model_len: int = 4096,\n",
" enable_prefix_cache_hbm: bool = False,\n",
" endpoint_id: str = \"\",\n",
" min_replica_count: int = 1,\n",
" max_replica_count: int = 1,\n",
@@ -364,6 +369,10 @@
" f\"--max_running_seqs={max_running_seqs}\",\n",
" f\"--max_model_len={max_model_len}\",\n",
" ]\n",
" if disagg_topology:\n",
" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
" if enable_prefix_cache_hbm and not disagg_topology:\n",
" hexllm_args.append(\"--enable_prefix_cache_hbm\")\n",
"\n",
" env_vars = {\n",
" \"MODEL_ID\": base_model_id,\n",
@@ -400,6 +409,9 @@
" service_account=service_account,\n",
" min_replica_count=min_replica_count,\n",
" max_replica_count=max_replica_count,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_codegemma_deployment_on_vertex.ipynb\",\n",
" },\n",
" )\n",
" return model, endpoint\n",
"\n",
@@ -414,6 +426,7 @@
" tensor_parallel_size=tensor_parallel_size,\n",
" hbm_utilization_factor=hbm_utilization_factor,\n",
" max_running_seqs=max_running_seqs,\n",
" enable_prefix_cache_hbm=enable_prefix_cache_hbm,\n",
" min_replica_count=min_replica_count,\n",
" max_replica_count=max_replica_count,\n",
" use_dedicated_endpoint=use_dedicated_endpoint,\n",
@@ -557,6 +570,8 @@
" enforce_eager: bool = False,\n",
" enable_lora: bool = False,\n",
" enable_chunked_prefill: bool = False,\n",
" enable_prefix_cache: bool = False,\n",
" host_prefix_kv_cache_utilization_target: float = 0.0,\n",
" max_loras: int = 1,\n",
" max_cpu_loras: int = 8,\n",
" use_dedicated_endpoint: bool = False,\n",
@@ -603,6 +618,14 @@
" if enable_chunked_prefill:\n",
" vllm_args.append(\"--enable-chunked-prefill\")\n",
"\n",
" if enable_prefix_cache:\n",
" vllm_args.append(\"--enable-prefix-caching\")\n",
"\n",
" if 0 < host_prefix_kv_cache_utilization_target < 1:\n",
" vllm_args.append(\n",
" f\"--host-prefix-kv-cache-utilization-target={host_prefix_kv_cache_utilization_target}\"\n",
" )\n",
"\n",
" if model_type:\n",
" vllm_args.append(f\"--model-type={model_type}\")\n",
"\n",
@@ -639,6 +662,9 @@
" accelerator_count=accelerator_count,\n",
" deploy_request_timeout=1800,\n",
" service_account=service_account,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_codegemma_deployment_on_vertex.ipynb\",\n",
" },\n",
" )\n",
" print(\"endpoint_name:\", endpoint.name)\n",
"\n",
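The vLLM changes above gate the two new prefix-caching options before building the server command line. A standalone sketch of that gating, lifted from the diff (the helper name is mine, for illustration):

def vllm_prefix_cache_flags(
    enable_prefix_cache: bool = False,
    host_prefix_kv_cache_utilization_target: float = 0.0,
):
    """Mirrors the diff: GPU-HBM prefix caching is a plain boolean flag, while
    host-memory spillover is requested only for a target strictly between 0 and 1."""
    args = []
    if enable_prefix_cache:
        args.append("--enable-prefix-caching")
    if 0 < host_prefix_kv_cache_utilization_target < 1:
        args.append(
            f"--host-prefix-kv-cache-utilization-target={host_prefix_kv_cache_utilization_target}"
        )
    return args

# vllm_prefix_cache_flags(True, 0.5)
# -> ['--enable-prefix-caching', '--host-prefix-kv-cache-utilization-target=0.5']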
3 changes: 3 additions & 0 deletions notebooks/community/model_garden/model_garden_e5.ipynb
@@ -220,6 +220,9 @@
" accelerator_count=accelerator_count,\n",
" deploy_request_timeout=1800,\n",
" service_account=service_account,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_e5.ipynb\"\n",
" },\n",
" )\n",
" return model, endpoint"
]
notebooks/community/model_garden/model_garden_gemma2_deployment_on_vertex.ipynb
@@ -10,7 +10,7 @@
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -290,9 +290,11 @@
" tensor_parallel_size: int = 1,\n",
" machine_type: str = \"ct5lp-hightpu-1t\",\n",
" tpu_topology: str = \"1x1\",\n",
" disagg_topology: str = None,\n",
" hbm_utilization_factor: float = 0.6,\n",
" max_running_seqs: int = 256,\n",
" max_model_len: int = 4096,\n",
" enable_prefix_cache_hbm: bool = False,\n",
" endpoint_id: str = \"\",\n",
" min_replica_count: int = 1,\n",
" max_replica_count: int = 1,\n",
@@ -331,6 +333,10 @@
" f\"--max_running_seqs={max_running_seqs}\",\n",
" f\"--max_model_len={max_model_len}\",\n",
" ]\n",
" if disagg_topology:\n",
" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
" if enable_prefix_cache_hbm and not disagg_topology:\n",
" hexllm_args.append(\"--enable_prefix_cache_hbm\")\n",
"\n",
" env_vars = {\n",
" \"MODEL_ID\": base_model_id,\n",
@@ -367,6 +373,9 @@
" service_account=service_account,\n",
" min_replica_count=min_replica_count,\n",
" max_replica_count=max_replica_count,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_gemma2_deployment_on_vertex.ipynb\",\n",
" },\n",
" )\n",
" return model, endpoint\n",
"\n",
@@ -582,6 +591,9 @@
" accelerator_count=accelerator_count,\n",
" deploy_request_timeout=1800,\n",
" service_account=service_account,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_gemma2_deployment_on_vertex.ipynb\",\n",
" },\n",
" )\n",
" return model, endpoint\n",
"\n",
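The Hex-LLM changes above treat disaggregated serving and HBM prefix caching as mutually exclusive: the caching flag is emitted only when no disaggregation topology is set. A standalone sketch of that gating, lifted from the diff (the helper name is mine, for illustration):

def hexllm_topology_flags(
    disagg_topology: str = None,
    enable_prefix_cache_hbm: bool = False,
):
    """Mirrors the diff: prefix caching in HBM is suppressed whenever
    disaggregated serving is enabled."""
    args = []
    if disagg_topology:
        args.append(f"--disagg_topo={disagg_topology}")
    if enable_prefix_cache_hbm and not disagg_topology:
        args.append("--enable_prefix_cache_hbm")
    return args

# hexllm_topology_flags("1x1", True) -> ['--disagg_topo=1x1']  (caching suppressed)
# hexllm_topology_flags(None, True)  -> ['--enable_prefix_cache_hbm']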
(file name truncated in the original view)
@@ -9,7 +9,7 @@
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb
@@ -9,7 +9,7 @@
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -290,6 +290,9 @@
" is_for_training=False,\n",
")\n",
"\n",
"# @markdown Set enable_prefix_cache_hbm to False if you don't want to use [prefix caching](https://cloud.google.com/vertex-ai/generative-ai/docs/open-models/use-hex-llm#prefix-caching).\n",
"enable_prefix_cache_hbm = True # @param {type:\"boolean\"}\n",
"\n",
"# Server parameters.\n",
"hbm_utilization_factor = 0.6 # A larger value improves throughput but gives higher risk of TPU out-of-memory errors with long prompts.\n",
"max_running_seqs = 256\n",
@@ -312,9 +315,11 @@
" tensor_parallel_size: int = 1,\n",
" machine_type: str = \"ct5lp-hightpu-1t\",\n",
" tpu_topology: str = \"1x1\",\n",
" disagg_topology: str = None,\n",
" hbm_utilization_factor: float = 0.6,\n",
" max_running_seqs: int = 256,\n",
" max_model_len: int = 4096,\n",
" enable_prefix_cache_hbm: bool = False,\n",
" endpoint_id: str = \"\",\n",
" min_replica_count: int = 1,\n",
" max_replica_count: int = 1,\n",
@@ -353,6 +358,10 @@
" f\"--max_running_seqs={max_running_seqs}\",\n",
" f\"--max_model_len={max_model_len}\",\n",
" ]\n",
" if disagg_topology:\n",
" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
" if enable_prefix_cache_hbm and not disagg_topology:\n",
" hexllm_args.append(\"--enable_prefix_cache_hbm\")\n",
"\n",
" env_vars = {\n",
" \"MODEL_ID\": base_model_id,\n",
@@ -389,6 +398,9 @@
" service_account=service_account,\n",
" min_replica_count=min_replica_count,\n",
" max_replica_count=max_replica_count,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_gemma_deployment_on_vertex.ipynb\",\n",
" },\n",
" )\n",
" return model, endpoint\n",
"\n",
@@ -401,6 +413,7 @@
" machine_type=machine_type,\n",
" hbm_utilization_factor=hbm_utilization_factor,\n",
" max_running_seqs=max_running_seqs,\n",
" enable_prefix_cache_hbm=enable_prefix_cache_hbm,\n",
" min_replica_count=min_replica_count,\n",
" max_replica_count=max_replica_count,\n",
" use_dedicated_endpoint=use_dedicated_endpoint,\n",
@@ -654,6 +667,8 @@
" enforce_eager: bool = False,\n",
" enable_lora: bool = False,\n",
" enable_chunked_prefill: bool = False,\n",
" enable_prefix_cache: bool = False,\n",
" host_prefix_kv_cache_utilization_target: float = 0.0,\n",
" max_loras: int = 1,\n",
" max_cpu_loras: int = 8,\n",
" use_dedicated_endpoint: bool = False,\n",
@@ -700,6 +715,14 @@
" if enable_chunked_prefill:\n",
" vllm_args.append(\"--enable-chunked-prefill\")\n",
"\n",
" if enable_prefix_cache:\n",
" vllm_args.append(\"--enable-prefix-caching\")\n",
"\n",
" if 0 < host_prefix_kv_cache_utilization_target < 1:\n",
" vllm_args.append(\n",
" f\"--host-prefix-kv-cache-utilization-target={host_prefix_kv_cache_utilization_target}\"\n",
" )\n",
"\n",
" if model_type:\n",
" vllm_args.append(f\"--model-type={model_type}\")\n",
"\n",
@@ -736,6 +759,9 @@
" accelerator_count=accelerator_count,\n",
" deploy_request_timeout=1800,\n",
" service_account=service_account,\n",
" system_labels={\n",
" \"NOTEBOOK_NAME\": \"model_garden_gemma_deployment_on_vertex.ipynb\",\n",
" },\n",
" )\n",
" print(\"endpoint_name:\", endpoint.name)\n",
"\n",
(file name truncated in the original view)
@@ -10,7 +10,7 @@
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -180,10 +180,7 @@
"\n",
"! gcloud config set project $PROJECT_ID\n",
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/storage.admin\"\n",
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/aiplatform.user\"\n",
"\n",
"# The evaluation docker image.\n",
"EVAL_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-lm-evaluation-harness:20240320_0655_RC00\""
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/aiplatform.user\""
]
},
{
@@ -268,14 +265,18 @@
"if peft_output_dir_gcsfuse:\n",
" eval_command += [\n",
" \"--model_args\",\n",
" f\"pretrained={base_model_id},peft={peft_output_dir_gcsfuse},trust_remote_code=True,parallelize=True,device_map_option=auto\",\n",
" f\"pretrained={base_model_id},peft={peft_output_dir_gcsfuse},trust_remote_code=True,parallelize=True\",\n",
" ]\n",
"else:\n",
" eval_command += [\n",
" \"--model_args\",\n",
" f\"pretrained={base_model_id},trust_remote_code=True,parallelize=True,device_map_option=auto\",\n",
" f\"pretrained={base_model_id},trust_remote_code=True,parallelize=True\",\n",
" ]\n",
"\n",
"\n",
"# The evaluation docker image.\n",
"EVAL_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-lm-evaluation-harness:20241016_0934_RC00\"\n",
"\n",
"# Pass evaluation arguments and launch job.\n",
"worker_pool_specs = [\n",
" {\n",
Expand Down Expand Up @@ -324,15 +325,26 @@
"source": [
"# @title Fetch and print evaluation results\n",
"import json\n",
"import re\n",
"\n",
"from google.cloud import storage\n",
"\n",
"# Fetch evaluation results.\n",
"storage_client = storage.Client()\n",
"BUCKET_NAME = BUCKET_URI.split(\"gs://\")[1]\n",
"bucket = storage_client.get_bucket(BUCKET_NAME)\n",
"RESULT_FILE_PATH = eval_output_dir[len(BUCKET_URI) + 1 :] + \"/results.json\"\n",
"blob = bucket.blob(RESULT_FILE_PATH)\n",
"\n",
"blobs = [b.name for b in bucket.list_blobs()]\n",
"\n",
"result_file_path = None\n",
"for file_path in filter(re.compile(\".*/*.json\").match, blobs):\n",
" result_file_path = file_path\n",
" print(f\"Found result file: {file_path}\")\n",
"\n",
"if result_file_path is None:\n",
" raise ValueError(\"No result file found.\")\n",
"\n",
"blob = bucket.blob(result_file_path)\n",
"raw_result = blob.download_as_string()\n",
"\n",
"# Print evaluation results.\n",
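The worker_pool_specs assembled earlier in this notebook are presumably handed to a Vertex AI custom job in a cell folded out of the diff. A minimal sketch of that launch step, assuming the standard SDK call (the display name is illustrative):

from google.cloud import aiplatform

eval_job = aiplatform.CustomJob(
    display_name="lm-eval-harness",  # illustrative
    worker_pool_specs=worker_pool_specs,
    staging_bucket=BUCKET_URI,
)
eval_job.run(service_account=SERVICE_ACCOUNT)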
(The remaining changed files are not shown.)
