From 3e374e85a9b02fc2f77129e9f649ad563223d9ef Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Mon, 2 Oct 2023 13:52:06 -0700 Subject: [PATCH] pylint --- MaxText/monitoring_api.py | 4 ++-- MaxText/pyconfig.py | 5 ++--- MaxText/tests/cloud_monitoring_test.py | 5 ++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/MaxText/monitoring_api.py b/MaxText/monitoring_api.py index d2c2785df..e244ac934 100644 --- a/MaxText/monitoring_api.py +++ b/MaxText/monitoring_api.py @@ -77,7 +77,7 @@ def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1): project_id = get_project() if not monitoring_enabled: - return + return [] client = get_metrics_service_client() project_name = f"projects/{project_id}" @@ -166,7 +166,7 @@ def get_time_series_step_data(metric_name): metric_name=metric_name, worker_id=instance_name ), }) - + result = client.query_time_series(request) return result.time_series_data diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 7232168f9..9659d95e2 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -19,7 +19,6 @@ import math import os -import subprocess import sys import yaml @@ -90,8 +89,8 @@ def user_init(raw_keys): raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default run_name = raw_keys["run_name"] assert run_name, "Erroring out, need a real run_name" - assert ((raw_keys['cloud_zone']!="" or not raw_keys['enable_cloud_monitoring']), - "You must provide cloud_zone if cloud monitoring is enabled") + assert ((raw_keys['cloud_zone']!="" or not raw_keys['enable_cloud_monitoring'])),\ + "You must provide cloud_zone if cloud monitoring is enabled" base_output_directory = raw_keys["base_output_directory"] validate_gcs_bucket_name(base_output_directory, "base_output_directory") dataset_path = raw_keys["dataset_path"] diff --git a/MaxText/tests/cloud_monitoring_test.py b/MaxText/tests/cloud_monitoring_test.py index 4e2318ab6..42e532f72 100644 --- a/MaxText/tests/cloud_monitoring_test.py +++ b/MaxText/tests/cloud_monitoring_test.py @@ -27,7 +27,10 @@ class CloudMonitoringTests(unittest.TestCase): """Test for writing time series step using monitoring_api.py""" def test_write_time_series_step(self): - pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', cloud_zone='us-central2-b') + pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], + logical_axis_rules = [['batch', 'data']], + data_sharding = ['data'], + cloud_zone='us-central2-b') monitoring_api.create_custom_metric('test_metric', "This is an example metric") create_time_series_result = monitoring_api.write_time_series_step('test_metric', True, pyconfig, 1) query_time_series_result = monitoring_api.get_time_series_step_data('test_metric')