Skip to content

Commit

Permalink
Merge pull request #565 from yahoo/leewyang_release_port_tests
Browse files Browse the repository at this point in the history
unit tests for release_port argument
  • Loading branch information
leewyang authored May 13, 2021
2 parents 1a3c08e + 6e44771 commit 6fade96
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/test_TFCluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,36 @@ def _map_fun(args, ctx):
cluster.inference(rdd).count()
cluster.shutdown(grace_secs=5) # note: grace_secs must be larger than the time needed for post-feed actions

def test_port_released(self):
"""Test that temporary socket/port is released prior to invoking user map_fun."""
def _map_fun(args, ctx):
assert ctx.tmp_socket is None

cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
cluster.shutdown()

def test_port_unreleased(self):
"""Test that temporary socket/port is unreleased prior to invoking user map_fun."""
def _map_fun(args, ctx):
import socket
assert ctx.tmp_socket is not None
reserved_port = ctx.tmp_socket.getsockname()[1]

# socket bind to tmp port should fail
try:
my_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
my_sock.bind(('0.0.0.0', reserved_port))
assert False, "should never hit this assert statement"
except socket.error as e:
print(e)
assert True, "should raise an exception"

ctx.release_port()
assert ctx.tmp_socket is None

cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', release_port=False)
cluster.shutdown()


if __name__ == '__main__':
unittest.main()

0 comments on commit 6fade96

Please sign in to comment.