Skip to content

Commit

Permalink
changes based on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
priyanka-ganesha committed Oct 2, 2023
1 parent 54caff6 commit c585140
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 26 deletions.
4 changes: 3 additions & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,7 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds
# Use iota operator in Embed
use_iota_embed: False

#Monitoring parameters
# Monitoring parameters - export in-workload metrics to Cloud Monitoring
enable_cloud_monitoring: True
cloud_monitoring_dashboard: "https://pantheon.corp.google.com/monitoring/dashboards?project="
cloud_zone: "" # zone name for cloud jobs - used for cloud metrics emitting
4 changes: 4 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import functools

import max_logging
import monitoring_api

import numpy as np
import jax
Expand Down Expand Up @@ -227,6 +228,9 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
state = unbox_logicallypartioned_trainstate(state)
return state, state_mesh_annotations

def register_train_metrics(metric_name, metric_description):
  """Creates a custom metric by delegating to monitoring_api.create_custom_metric.

  Args:
    metric_name: name of the custom metric to create.
    metric_description: description passed along with the metric.
  """
  monitoring_api.create_custom_metric(
      metric_name=metric_name,
      description=metric_description,
  )


# Learning Rate Schedule
# -----------------------------------------------------------------------------
Expand Down
62 changes: 47 additions & 15 deletions MaxText/monitoring_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import time
import os

import max_logging

def get_metadata(project_id, zone, instance_id):
"""
Fetches metadata
Expand Down Expand Up @@ -61,7 +63,7 @@ def create_custom_metric(metric_name, description):
return response


def write_time_series_step(metric_name, monitoring_enabled, step=1):
def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1):
"""
Writes a time series object for a specified custom metric
Expand All @@ -71,7 +73,7 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1):
step
"""

zone = get_zone()
zone = pyconfig.config.cloud_zone
project_id = get_project()

if not monitoring_enabled:
Expand All @@ -96,7 +98,7 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1):
event_time = time.strftime(
"%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc)
)
print(
max_logging.log(
"Emitting metric ",
metric_name,
" for step = ",
Expand Down Expand Up @@ -125,14 +127,49 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1):
]

client.create_time_series(name=project_name, time_series=[series], metadata=get_metadata(project_id, zone, instance_id))
print(
dashboard_link = pyconfig.config.cloud_monitoring_dashboard+project_name
max_logging.log(
"Time series added for step",
step,
"and instance_id ",
instance_id,
" and zone ",
zone,
"\nView dashboards or use metrics: ",
dashboard_link,
)
return [series]

def get_time_series_step_data(metric_name):
  """
  Queries the time series stored for a custom metric, filtered to this host.

  Args:
    metric_name: custom metric name (without the 'custom.googleapis.com/' prefix)
  Returns:
    The `time_series_data` attribute of the query_time_series result.
  """
  parent = f"projects/{get_project()}"
  # The node name of the current machine is substituted for the worker filter.
  worker_id = os.uname().nodename

  query_template = """
fetch gce_instance
| metric 'custom.googleapis.com/{metric_name}'
| filter (metric.worker == '{worker_id}')
| every 1m
| within -1d, 1d # one day, starting 1 day ago
"""

  query_client = get_query_service_client()
  query = query_template.format(metric_name=metric_name, worker_id=worker_id)
  request = monitoring_v3.QueryTimeSeriesRequest({"name": parent, "query": query})

  return query_client.query_time_series(request).time_series_data


def get_instance_id(project_id, zone):
"""
Expand All @@ -157,17 +194,6 @@ def get_project():
sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project <project>'")
return project_outputs[-1]

def get_zone():
  """
  Fetches the compute zone currently configured in gcloud.

  Returns:
    The last line printed by `gcloud config get compute/zone`.
  Raises:
    SystemExit: if gcloud reports no configured zone.
    subprocess.CalledProcessError: if the gcloud command exits with a non-zero status.
  """
  # The zone is deliberately NOT overwritten here. The previous implementation
  # tried to force a hard-coded zone by passing a single command string to
  # subprocess.run without shell=True, which cannot be executed as written and
  # would have made the "no zone configured" check below meaningless.
  completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True)
  zone_outputs = completed_command.stdout.decode().strip().split('\n')
  # str.split always yields at least one element, so only the last entry needs checking.
  zone = zone_outputs[-1]
  if not zone:
    sys.exit("You must specify the zone in the ZONE flag or set it with 'gcloud config set compute/zone <zone>'")
  return zone

def get_compute_instances_client():
"""
Fetches cloud compute instances client
Expand All @@ -179,3 +205,9 @@ def get_metrics_service_client():
Fetches cloud monitoring API client
"""
return monitoring_v3.MetricServiceClient()

def get_query_service_client():
  """
  Builds the client used to query Cloud Monitoring time series.

  Returns:
    A newly constructed monitoring_v3.QueryServiceClient.
  """
  query_client = monitoring_v3.QueryServiceClient()
  return query_client
3 changes: 3 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import math
import os
import subprocess
import sys
import yaml

Expand Down Expand Up @@ -89,6 +90,8 @@ def user_init(raw_keys):
raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default
run_name = raw_keys["run_name"]
assert run_name, "Erroring out, need a real run_name"
# cloud_zone is mandatory whenever cloud monitoring is enabled. The condition and
# the message must not be wrapped together in one pair of parentheses: that forms
# a non-empty tuple, which is always truthy, so the assertion could never fail.
assert raw_keys['cloud_zone'] != "" or not raw_keys['enable_cloud_monitoring'], "You must provide cloud_zone if cloud monitoring is enabled"
base_output_directory = raw_keys["base_output_directory"]
validate_gcs_bucket_name(base_output_directory, "base_output_directory")
dataset_path = raw_keys["dataset_path"]
Expand Down
39 changes: 39 additions & 0 deletions MaxText/tests/cloud_monitoring_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Tests for Cloud Monitoring API """
import sys
import jax
import unittest

import monitoring_api
import pyconfig

jax.config.update('jax_platform_name', 'cpu')

class CloudMonitoringTests(unittest.TestCase):
  """Writes a time series step through monitoring_api.py and reads it back"""

  def test_write_time_series_step(self):
    """Creates a metric, writes step 1 for it, then queries the metric again."""
    metric = 'test_metric'
    argv = sys.argv + ['configs/base.yml']
    pyconfig.initialize(argv, per_device_batch_size=1, run_name='test', cloud_zone='us-central2-b')

    monitoring_api.create_custom_metric(metric, "This is an example metric")
    written = monitoring_api.write_time_series_step(metric, True, pyconfig, 1)
    queried = monitoring_api.get_time_series_step_data(metric)

    # NOTE(review): the write call returns a list holding the series it sent, while
    # the query returns the response's time_series_data - confirm these two values
    # are expected to compare equal.
    self.assertEqual(written, queried)


if __name__ == '__main__':
unittest.main()

16 changes: 8 additions & 8 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ def train_loop(config, state=None):
monitoring_enabled = config.enable_cloud_monitoring

if monitoring_enabled:
monitoring_api.create_custom_metric('checkpointing_init_start', "Checkpointing Initialization Start")
monitoring_api.create_custom_metric('checkpointing_init_end', "Checkpointing Initialization End")
monitoring_api.create_custom_metric('checkpoint_test_run_start', "Checkpointing Test Run Start")
monitoring_api.create_custom_metric('checkpoint_test_run_end', "Checkpointing Test Run End")
# Register every custom metric before it is written. The names must match the ones
# passed to monitoring_api.write_time_series_step exactly; the first one was
# previously misspelled 'checkpointint_init_start', so the metric that is actually
# written ('checkpointing_init_start') was never registered.
max_utils.register_train_metrics('checkpointing_init_start', "Checkpointing Initialization Start")
max_utils.register_train_metrics('checkpointing_init_end', "Checkpointing Initialization End")
max_utils.register_train_metrics('checkpoint_test_run_start', "Checkpointing Test Run Start")
max_utils.register_train_metrics('checkpoint_test_run_end', "Checkpointing Test Run End")

monitoring_api.write_time_series_step('checkpoint_test_run_start', 0, monitoring_enabled)
monitoring_api.write_time_series_step('checkpoint_test_run_start', monitoring_enabled, pyconfig, 0)

writer = SummaryWriter(config.tensorboard_dir)

monitoring_api.write_time_series_step('checkpointing_init_start', 1, monitoring_enabled)
monitoring_api.write_time_series_step('checkpointing_init_start', monitoring_enabled, pyconfig, 1)

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
config.checkpoint_dir,
Expand All @@ -234,7 +234,7 @@ def train_loop(config, state=None):
config.save_period,
)

monitoring_api.write_time_series_step('checkpointing_init_end', 1, monitoring_enabled)
monitoring_api.write_time_series_step('checkpointing_init_end', monitoring_enabled, pyconfig, 1)

# Initial PRNG Keys
init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2)
Expand Down Expand Up @@ -318,7 +318,7 @@ def train_loop(config, state=None):
if step == 0:
max_utils.activate_profiler(config)

monitoring_api.write_time_series_step('checkpoint_test_run_end', config.steps, monitoring_enabled)
monitoring_api.write_time_series_step('checkpoint_test_run_end', monitoring_enabled, pyconfig, config.steps)
max_utils.deactivate_profiler(config)
writer.close()
return state
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ absl-py
argparse
cloud-tpu-diagnostics
datetime
google-cloud-compute==1.6.1
google-cloud-monitoring==2.11.3
google-cloud-compute
google-cloud-monitoring
google-cloud-storage
flax
ml-collections
Expand Down

0 comments on commit c585140

Please sign in to comment.