diff --git a/applications/rag/tests/test_frontend.py b/applications/rag/tests/test_frontend.py index 32cc2fe8c..13cc05e5c 100644 --- a/applications/rag/tests/test_frontend.py +++ b/applications/rag/tests/test_frontend.py @@ -1,8 +1,20 @@ import sys import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry def test_frontend_up(rag_frontend_url): - r = requests.get(rag_frontend_url) + retry_strategy = Retry( + total=5, # Total number of retries + backoff_factor=1, # Waits 1 second between retries, then 2s, 4s, 8s... + status_forcelist=[429, 500, 502, 503, 504], # Status codes to retry on + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + + r = session.get(rag_frontend_url) r.raise_for_status() print("Rag frontend is up.") diff --git a/applications/rag/tests/test_rag.py b/applications/rag/tests/test_rag.py index d7da3a0e2..d24876e67 100644 --- a/applications/rag/tests/test_rag.py +++ b/applications/rag/tests/test_rag.py @@ -1,6 +1,8 @@ import json import sys import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry def test_prompts(prompt_url): testcases = [ @@ -46,21 +48,42 @@ def test_prompts(prompt_url): json_payload = json.dumps(data) headers = {'Content-Type': 'application/json'} - response = requests.post(prompt_url, data=json_payload, headers=headers) - response.raise_for_status() - - response = response.json() - context = response['response']['context'] - text = response['response']['text'] - user_prompt = response['response']['user_prompt'] - - print(f"Reply: {text}") - - assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" - assert context == expected_context, f"unexpected context: {context} != {expected_context}" - - for substring in expected_substrings: - assert substring in text, f"substring {substring} not in response:\n {text}" + # Define a retry strategy + retry_strategy = Retry( + total=5, # Total number of retries + backoff_factor=1, # Waits 1 second between retries, then 2s, 4s, 8s... + status_forcelist=[429, 500, 502, 503, 504], # Status codes to retry on + ) + + # Mount the retry strategy to the session + adapter = HTTPAdapter(max_retries=retry_strategy) + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + + try: + response = session.post(prompt_url, data=json_payload, headers=headers) + response.raise_for_status() + + response = response.json() + context = response['response']['context'] + text = response['response']['text'] + user_prompt = response['response']['user_prompt'] + + print(f"Reply: {text}") + + assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" + assert context == expected_context, f"unexpected context: {context} != {expected_context}" + + for substring in expected_substrings: + assert substring in text, f"substring {substring} not in response:\n {text}" + + except requests.exceptions.ConnectionError as e: + print(f"Error connecting to the server: {e}") + except requests.exceptions.HTTPError as e: + print(f"HTTP error occurred: {e}") + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") def test_prompts_nlp(prompt_url): testcases = [ @@ -101,21 +124,42 @@ def test_prompts_nlp(prompt_url): json_payload = json.dumps(data) headers = {'Content-Type': 'application/json'} - response = requests.post(prompt_url, data=json_payload, headers=headers) - response.raise_for_status() - - response = response.json() - context = response['response']['context'] - text = response['response']['text'] - user_prompt = response['response']['user_prompt'] - - print(f"Reply: {text}") - - assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" - assert context == expected_context, f"unexpected context: {context} != {expected_context}" - - for substring in expected_substrings: - assert substring in text, f"substring {substring} not in response:\n {text}" + # Define a retry strategy + retry_strategy = Retry( + total=5, # Total number of retries + backoff_factor=1, # Waits 1 second between retries, then 2s, 4s, 8s... + status_forcelist=[429, 500, 502, 503, 504], # Status codes to retry on + ) + + # Mount the retry strategy to the session + adapter = HTTPAdapter(max_retries=retry_strategy) + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + + try: + response = session.post(prompt_url, data=json_payload, headers=headers) + response.raise_for_status() + + response = response.json() + context = response['response']['context'] + text = response['response']['text'] + user_prompt = response['response']['user_prompt'] + + print(f"Reply: {text}") + + assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" + assert context == expected_context, f"unexpected context: {context} != {expected_context}" + + for substring in expected_substrings: + assert substring in text, f"substring {substring} not in response:\n {text}" + + except requests.exceptions.ConnectionError as e: + print(f"Error connecting to the server: {e}") + except requests.exceptions.HTTPError as e: + print(f"HTTP error occurred: {e}") + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") def test_prompts_dlp(prompt_url): testcases = [ @@ -139,22 +183,44 @@ def test_prompts_dlp(prompt_url): data = {"prompt": prompt, "inspectTemplate": inspectTemplate, "deidentifyTemplate": deidentifyTemplate} json_payload = json.dumps(data) - headers = {'Content-Type': 'application/json'} - response = requests.post(prompt_url, data=json_payload, headers=headers) - response.raise_for_status() - - response = response.json() - context = response['response']['context'] - text = response['response']['text'] - user_prompt = response['response']['user_prompt'] - - print(f"Reply: {text}") + # Define a retry strategy + retry_strategy = Retry( + total=5, # Total number of retries + backoff_factor=1, # Waits 1 second between retries, then 2s, 4s, 8s... + status_forcelist=[429, 500, 502, 503, 504], # Status codes to retry on + ) - assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" - assert context == expected_context, f"unexpected context: {context} != {expected_context}" + # Mount the retry strategy to the session + adapter = HTTPAdapter(max_retries=retry_strategy) + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) - for substring in expected_substrings: - assert substring in text, f"substring {substring} not in response:\n {text}" + headers = {'Content-Type': 'application/json'} + + try: + response = session.post(prompt_url, data=json_payload, headers=headers) + response.raise_for_status() + + response = response.json() + context = response['response']['context'] + text = response['response']['text'] + user_prompt = response['response']['user_prompt'] + + print(f"Reply: {text}") + + assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" + assert context == expected_context, f"unexpected context: {context} != {expected_context}" + + for substring in expected_substrings: + assert substring in text, f"substring {substring} not in response:\n {text}" + + except requests.exceptions.ConnectionError as e: + print(f"Error connecting to the server: {e}") + except requests.exceptions.HTTPError as e: + print(f"HTTP error occurred: {e}") + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") prompt_url = sys.argv[1] test_prompts(prompt_url)