Skip to content

Commit

Permalink
Replace tf.io.gfile with epath.Path.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686570636
  • Loading branch information
cliveverghese authored and copybara-github committed Oct 22, 2024
1 parent 6a1c3a6 commit 7e0f1a0
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions plugin/tensorboard_plugin_profile/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,15 @@
import re
import threading

from etils import epath
import six
import tensorflow.compat.v2 as tf
from werkzeug import wrappers

from tensorboard.backend.event_processing import plugin_asset_util
from tensorboard.context import RequestContext
from tensorboard.plugins import base_plugin
from tensorflow.python.profiler import profiler_client # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.profiler import profiler_v2 as profiler # pylint: disable=g-direct-tensorflow-import
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert

tf.enable_v2_behavior()

logger = logging.getLogger('tensorboard')

Expand Down Expand Up @@ -206,7 +203,7 @@ def _get_tools(filenames, profile_run_dir):
found = set()
xplane_filenames = []
for name in filenames:
_, tool = _parse_filename(name)
_, tool = _parse_filename(os.fspath(name))
if tool == 'xplane':
xplane_filenames.append(os.path.join(profile_run_dir, name))
continue
Expand Down Expand Up @@ -518,10 +515,11 @@ def _run_host_impl(self, run, run_dir, tool):
tool_pattern = make_filename('*', tool)
filenames = []
try:
filenames = tf.io.gfile.glob(os.path.join(run_dir, tool_pattern))
except tf.errors.OpError as e:
path = epath.Path(run_dir)
filenames = path.glob(tool_pattern)
except RuntimeError as e:
logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)
filenames = [os.path.basename(f) for f in filenames]
filenames = [os.fspath(os.path.basename(f)) for f in filenames]

return filenames_to_hosts(filenames, tool)

Expand Down Expand Up @@ -623,8 +621,9 @@ def data_impl(self, request):
if host == ALL_HOSTS:
file_pattern = make_filename('*', 'xplane')
try:
asset_paths = tf.io.gfile.glob(os.path.join(run_dir, file_pattern))
except tf.errors.OpError as e:
path = epath.Path(run_dir)
asset_paths = path.glob(file_pattern)
except RuntimeError as e:
logger.warning('Cannot read asset directory: %s, OpError %s', run_dir,
e)
raise IOError(
Expand All @@ -648,11 +647,9 @@ def data_impl(self, request):

raw_data = None
try:
with tf.io.gfile.GFile(asset_path, 'rb') as f:
raw_data = f.read()
except tf.errors.NotFoundError:
logger.warning('Asset path %s not found', asset_path)
except tf.errors.OpError as e:
path = epath.Path(asset_path)
raw_data = path.read_bytes()
except RuntimeError as e:
logger.warning("Couldn't read asset path: %s, OpError %s", asset_path, e)

if raw_data is None:
Expand Down Expand Up @@ -683,6 +680,9 @@ def capture_route(self, request):

def capture_route_impl(self, request):
"""Runs the client trace for capturing profiling information."""
from tensorflow.python.profiler import profiler_client # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
from tensorflow.python.profiler import profiler_v2 as profiler # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top

service_addr = request.args.get('service_addr')
duration = int(request.args.get('duration', '1000'))
is_tpu_name = request.args.get('is_tpu_name') == 'true'
Expand All @@ -703,11 +703,15 @@ def capture_route_impl(self, request):

if is_tpu_name:
try:
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
service_addr)
import tensorflow.compat.v2 as tf # pylint: disable=g-import-not-at-top

tf.enable_v2_behavior()
tpu_cluster_resolver = (
tf.distribute.cluster_resolver.TPUClusterResolver(service_addr)
)
master_grpc_addr = tpu_cluster_resolver.get_master()
except (ImportError, RuntimeError) as err:
return respond({'error': err.message}, 'application/json', code=200)
return respond({'error': err}, 'application/json', code=200)
except (ValueError, TypeError):
return respond(
{'error': 'no TPUs with the specified names exist.'},
Expand Down Expand Up @@ -743,12 +747,6 @@ def capture_route_impl(self, request):
{'result': 'Capture profile successfully. Please refresh.'},
'application/json',
)
except tf.errors.UnavailableError:
return respond(
{'error': 'empty trace result.'},
'application/json',
code=200,
)
except Exception as e: # pylint: disable=broad-except
return respond(
{'error': str(e)},
Expand Down Expand Up @@ -798,7 +796,7 @@ def _run_dir(self, run):
if not tb_run_name:
tb_run_name = '.'
tb_run_directory = _tb_run_directory(self.logdir, tb_run_name)
if not tf.io.gfile.isdir(tb_run_directory):
if not epath.Path(tb_run_directory).is_dir():
raise RuntimeError('No matching run directory for run %s' % run)

plugin_directory = plugin_asset_util.PluginDirectory(
Expand Down Expand Up @@ -861,7 +859,7 @@ def generate_runs(self):
# backwards compatible with previously profile plugin behavior. Note that we
# check if logdir is a directory to handle case where it's actually a
# multipart directory spec, which this plugin does not support.
if '.' not in tb_runs and tf.io.gfile.isdir(self.logdir):
if '.' not in tb_runs and epath.Path(self.logdir).is_dir():
tb_runs.append('.')
tb_run_names_to_dirs = {
run: _tb_run_directory(self.logdir, run) for run in tb_runs
Expand All @@ -880,17 +878,17 @@ def generate_runs(self):
else:
frontend_run = os.path.join(tb_run_name, profile_run)
profile_run_dir = os.path.join(tb_plugin_dir, profile_run)
if tf.io.gfile.isdir(profile_run_dir):
if epath.Path(profile_run_dir).is_dir():
self._run_to_profile_run_dir[frontend_run] = profile_run_dir
yield frontend_run

def generate_tools_of_run(self, run):
"""Generate a list of tools given a certain run."""
profile_run_dir = self._run_to_profile_run_dir[run]
if tf.io.gfile.isdir(profile_run_dir):
if epath.Path(profile_run_dir).is_dir():
try:
filenames = tf.io.gfile.listdir(profile_run_dir)
except tf.errors.NotFoundError as e:
filenames = epath.Path(profile_run_dir).iterdir()
except RuntimeError as e:
logger.warning('Cannot read asset directory: %s, NotFoundError %s',
profile_run_dir, e)
filenames = []
Expand Down

0 comments on commit 7e0f1a0

Please sign in to comment.