From 53a276ac8b6a77b9d75b4f0a67ee7cba938af0c3 Mon Sep 17 00:00:00 2001 From: hang lv Date: Tue, 30 May 2023 04:47:34 +0800 Subject: [PATCH] feat: test examples Signed-off-by: hang lv --- examples/client.py | 48 ++++++++++++++++++ examples/jax_single_layer.py | 4 +- examples/jax_single_layer_cli.py | 2 + examples/type_validation/client.py | 1 + tests/test_examples.py | 80 ++++++++++++++++++++++++++++++ 5 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 examples/client.py create mode 100644 tests/test_examples.py diff --git a/examples/client.py b/examples/client.py new file mode 100644 index 00000000..8b3f04f8 --- /dev/null +++ b/examples/client.py @@ -0,0 +1,48 @@ +# Copyright 2023 MOSEC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example: Common Client for Test""" + +import json +import sys +from http import HTTPStatus + +import httpx + +req = { + "echo": {"time": 1.5}, + "plasma_shm_ipc": {"size": 100}, + "redis_shm_ipc": {"size": 100}, +} + + +def post(data): + """Post request to server""" + resp = httpx.post("http://127.0.0.1:8000/inference", content=data) + if resp.status_code == HTTPStatus.OK: + print(f"OK: {resp.json()}") + else: + print(f"err[{resp.status_code}] {resp.text}") + sys.exit(1) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Please specify a shm to run: plasma or redis") + sys.exit(1) + + k = sys.argv[1] + content = req[k] + if k not in ["msgpack"]: + content = json.dumps(content) + post(content) diff --git a/examples/jax_single_layer.py b/examples/jax_single_layer.py index 4c284c6c..02e91a33 100644 --- a/examples/jax_single_layer.py +++ b/examples/jax_single_layer.py @@ -56,7 +56,7 @@ def __init__(self): else: self.batch_forward = self._batch_forward - def _forward(self, x_single: jnp.ndarray) -> jnp.ndarray: + def _forward(self, x_single: jnp.ndarray) -> jnp.ndarray: # type: ignore chex.assert_rank([x_single], [1]) h_1 = jnp.dot(self._layer1_w.T, x_single) + self._layer1_b a_1 = jax.nn.relu(h_1) @@ -64,7 +64,7 @@ def _forward(self, x_single: jnp.ndarray) -> jnp.ndarray: o_2 = jax.nn.softmax(h_2) return jnp.argmax(o_2, axis=-1) - def _batch_forward(self, x_batch: jnp.ndarray) -> jnp.ndarray: + def _batch_forward(self, x_batch: jnp.ndarray) -> jnp.ndarray: # type: ignore chex.assert_rank([x_batch], [2]) return jax.vmap(self._forward)(x_batch) diff --git a/examples/jax_single_layer_cli.py b/examples/jax_single_layer_cli.py index 6139f0a1..484f8a06 100644 --- a/examples/jax_single_layer_cli.py +++ b/examples/jax_single_layer_cli.py @@ -14,6 +14,7 @@ """Example: Client of the Jax server.""" import random +import sys import httpx @@ -28,3 +29,4 @@ print(prediction.json()) else: print(prediction.status_code, prediction.json()) + sys.exit(1) diff --git a/examples/type_validation/client.py b/examples/type_validation/client.py index 9be4261e..1bf7b6db 100644 --- a/examples/type_validation/client.py +++ b/examples/type_validation/client.py @@ -27,3 +27,4 @@ print(f"OK: {msgpack.unpackb(resp.content)}") else: print(f"err[{resp.status_code}] {resp.text}") + exit(1) diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 00000000..83d9f2fe --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,80 @@ +# Copyright 2023 MOSEC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shlex +import subprocess +import time + +import httpx +import pytest + +from tests.utils import wait_for_port_open + +TIMEOUT_SERVICE_PORT = 8000 + + +@pytest.fixture +def example_server(request): + name = request.param + filepath = os.path.join("examples", name) + service = subprocess.Popen( + shlex.split(f"python -u {filepath}.py --port {TIMEOUT_SERVICE_PORT}") + ) + assert wait_for_port_open( + port=TIMEOUT_SERVICE_PORT, timeout=5 + ), "service failed to start" + yield name + service.terminate() + time.sleep(2) # wait for service to stop + + +@pytest.mark.parametrize( + "example_server,example_client", + [ + pytest.param( + "type_validation/server", + "type_validation/client", + id="type_validation", + ), + pytest.param( + "echo", + "client", + id="echo", + ), + pytest.param( + "redis_shm_ipc", + "client", + id="redis_shm_ipc", + ), + pytest.param( + "plasma_shm_ipc", + "client", + id="plasma_shm_ipc", + ), + pytest.param( + "jax_single_layer", + "jax_single_layer_cli", + id="jax_single_layer", + ), + ], + indirect=["example_server"], +) +def test_forward_timeout(example_server, example_client: str): + filepath = os.path.join("examples", example_client) + service = subprocess.Popen(shlex.split(f"python -u {filepath}.py {example_server}")) + stdout, stderr = service.communicate() + code = service.returncode + assert code == 0, (code, stdout, stderr)