Skip to content

Commit

Permalink
add chatgpt-api-response-timeout-secs flag, set this to 20 mins in test
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jul 20, 2024
1 parent 7dd7cca commit e49924e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
31 changes: 24 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ jobs:
- name: Run chatgpt api integration test
run: |
# Start first instance
DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 1200 > output1.log 2>&1 &
PID1=$!
# Start second instance
DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 1200 > output2.log 2>&1 &
PID2=$!
# Wait for discovery
Expand All @@ -96,22 +96,39 @@ jobs:
"messages": [{"role": "user", "content": "Placeholder to load model..."}],
"temperature": 0.7
}'
curl -s http://localhost:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3-8b",
"messages": [{"role": "user", "content": "Placeholder to load model..."}],
"temperature": 0.7
}'
response_1=$(curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3-8b",
"messages": [{"role": "user", "content": "Who was the king of pop?"}],
"temperature": 0.7
}')
echo "Response 1: $response_1"
response=$(curl -s http://localhost:8000/v1/chat/completions \
response_2=$(curl -s http://localhost:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3-8b",
"messages": [{"role": "user", "content": "Who was the king of pop?"}],
"temperature": 0.7
}')
echo "Response: $response"
echo "Response 2: $response_2"
if ! echo "$response" | grep -q "Michael Jackson"; then
if ! echo "$response_1" | grep -q "Michael Jackson" || ! echo "$response_2" | grep -q "Michael Jackson"; then
echo "Test failed: Response does not contain 'Michael Jackson'"
echo "Response: $response"
echo "Response 1: $response_1"
echo "Response 2: $response_2"
exit 1
else
echo "Test passed: Response contains 'Michael Jackson'"
echo "Test passed: Response from both nodes contains 'Michael Jackson'"
fi
# Stop both instances
Expand Down
4 changes: 2 additions & 2 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def build_prompt(tokenizer, messages: List[Message]):


class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout_secs = 90
self.response_timeout_secs = response_timeout_secs
self.app = web.Application()
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
args = parser.parse_args()

Expand Down Expand Up @@ -57,7 +58,7 @@
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node, inference_engine.__class__.__name__)
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)

node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))

Expand Down

0 comments on commit e49924e

Please sign in to comment.