Skip to content

Commit

Permalink
add stress test
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsmechtel committed Nov 12, 2024
1 parent 8ba2b3b commit e6c2d8f
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions test/stress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,30 @@

SERVER_URL = "https://hypha.aicell.io"
WORKSPACE_NAME = "bioimageio-colab"
CLIENT_ID = "kubernetes"
SERVICE_ID = "sam"
SID = f"{WORKSPACE_NAME}/{CLIENT_ID}:{SERVICE_ID}"
SERVICE_ID = "microsam"
SID = f"{WORKSPACE_NAME}/{SERVICE_ID}"


async def run_client(client_id: int, image: np.ndarray, point_coordinates: list, point_labels: list):
print(f"Client {client_id} started")
client = await connect_to_server({"server_url": SERVER_URL, "method_timeout": 10})
segment_svc = await client.get_service(SID)
await segment_svc.segment(model_name="vit_b", image=image, point_coordinates=point_coordinates, point_labels=point_labels)
async def run_client(client_id: int, image: np.ndarray):
print(f"Client {client_id} started", flush=True)
client = await connect_to_server({"server_url": SERVER_URL, "method_timeout": 30})
segment_svc = await client.get_service(SID, {"mode": "random"})
await segment_svc.segment(model_name="vit_b", image=image, point_coordinates=[[128, 128]], point_labels=[1])
await asyncio.sleep(1)
await segment_svc.segment(model_name="vit_b", image=image, point_coordinates=[[20, 50]], point_labels=[1])
await asyncio.sleep(1)
await segment_svc.segment(model_name="vit_b", image=image, point_coordinates=[[180, 10]], point_labels=[1])
print(f"Client {client_id} finished", flush=True)


async def stress_test(num_clients: int):
image=np.random.rand(256, 256)
point_coordinates=[[128, 128]]
point_labels=[1]
tasks = [
run_client(client_id, image, point_coordinates, point_labels)
for client_id in range(num_clients)
]
tasks = []
for client_id in range(num_clients):
await asyncio.sleep(0.1)
tasks.append(run_client(client_id, image))
await asyncio.gather(*tasks)
print("All clients finished")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down

0 comments on commit e6c2d8f

Please sign in to comment.