Continuous batching for single GPU LLM inference (#2628)
* Make test/pytest/test_handler.py run stand-alone
* Refactor if/else statement
* First working PoC for streaming inference with continuous batching
* WIP: stopping criteria + caching
* FE: add continuous batching
* fmt
* Fmt
* Fix continuous batching PoC; remove batchDelay if jobs are being processed
* Add model_config.yaml
* Added ipynb for generate_next_token
* Update notebook and move to right subfolder
* Fix buffer underruns; wait until enough bytes are in for reading
* Add bandaid for bug in our otf
* Added test for otf protocol with context
* Fix buffer underrun; handle batch quota of zero correctly
* Initial implementation of prefill + decode, without KV caching for now
* Add missing __init__.py files
* WIP: KV caching
* Fixed KV cache; missing tuple
* Cleaned up streaming handler code
* Added cache cleaning
* Clean up aggregator jobs for control cmd
* fmt
* Fix streaming handler test
* Rename streaming test into continuous batching test
* fmt
* Enable GPU usage in continuous batching unit test
* Add Llama to stream notebook
* Skip pulling a mgmt job if jobs is not empty
* Set pollMgmtJobStatus init value to false
* fmt
* Only take a describe request if the jobs repo is empty
* Init job
* Remove continuous batching job if connection to client gets closed
* Fix and re-enable cached request logic
* Fix linter error
* fmt
* Remove llama2-13b stream_handler.py
* Revert otf
* Update maxDelay logic
* Replace size checking with isEmpty
* Use handler section
* Fix linter errors
* Fix linter error in otf message handler
* Fix linter error in test_otf_codec_protocol.py

---------
Co-authored-by: lxning <[email protected]>
Showing 26 changed files with 2,366 additions and 128 deletions.
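From a client's point of view, the continuous batching added here rides on TorchServe's streaming predictions path: generated tokens arrive as successive chunks of a single HTTP response while the worker keeps refilling its batch. A minimal consumer sketch follows; the model name llm, the JSON prompt shape, and the default inference port 8080 are illustrative assumptions, not part of this commit.

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

public class StreamingClientDemo {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request =
                HttpRequest.newBuilder()
                        .uri(URI.create("http://localhost:8080/predictions/llm"))
                        .POST(HttpRequest.BodyPublishers.ofString("{\"prompt\": \"Hello\"}"))
                        .build();
        // Tokens are flushed via chunked transfer encoding; read and print them as they arrive.
        HttpResponse<InputStream> response =
                client.send(request, HttpResponse.BodyHandlers.ofInputStream());
        try (BufferedReader reader =
                new BufferedReader(
                        new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
            int c;
            while ((c = reader.read()) != -1) {
                System.out.print((char) c);
            }
        }
    }
}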
frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java (149 additions, 0 deletions)
@@ -0,0 +1,149 @@
package org.pytorch.serve.wlm;

import java.util.Map;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelLoadModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.Predictions;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ContinuousBatching extends BatchAggregator {
    private static final Logger logger = LoggerFactory.getLogger(ContinuousBatching.class);

    public ContinuousBatching(Model model) {
        super(model);
    }

    @Override
    public BaseModelRequest getRequest(String threadName, WorkerState state)
            throws InterruptedException {
        // Fill only the free slots: jobs still streaming keep their place in the batch.
        int batchQuota = model.getBatchSize() - jobs.size();

        ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());

        pollBatch(threadName, state, batchQuota);

        if (model.isUseJobTicket() && jobs.isEmpty()) {
            model.decNumJobTickets();
            return req;
        }

        for (Job j : jobs.values()) {
            if (j.isControlCmd()) {
                if (jobs.size() > 1) {
                    throw new IllegalStateException(
                            "Received more than 1 control command. "
                                    + "Control messages should be processed/retrieved one at a time.");
                }
                RequestInput input = j.getPayload();
                int gpuId = -1;
                String gpu = input.getStringParameter("gpu");
                if (gpu != null) {
                    gpuId = Integer.parseInt(gpu);
                }
                return new ModelLoadModelRequest(model, gpuId);
            } else {
                if (j.getCmd() == WorkerCommands.STREAMPREDICT) {
                    req.setCommand(WorkerCommands.STREAMPREDICT);
                }
                j.setScheduled();
                req.addRequest(j.getPayload());
            }
        }
        return req;
    }

    /**
     * @param message a response to a batch of inference requests
     * @return true if either a non-stream response or the last stream response was sent; false if
     *     an intermediate stream response (not the last one) was sent
     */
    @Override
    public boolean sendResponse(ModelWorkerResponse message) {
        // TODO: Handle prediction level code
        if (message.getCode() == 200) {
            if (message.getPredictions().isEmpty()) {
                // The jobs size is always 1 in the case of a control command.
                for (Map.Entry<String, Job> j : jobs.entrySet()) {
                    Job job = j.getValue();
                    if (job.isControlCmd()) {
                        jobs.clear();
                        return true;
                    }
                }
            }
            for (Predictions prediction : message.getPredictions()) {
                String jobId = prediction.getRequestId();
                Job job = jobs.get(jobId);

                if (job == null) {
                    throw new IllegalStateException(
                            "Unexpected job in sendResponse() with 200 status code: " + jobId);
                }

                if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
                    job.response(
                            prediction.getResp(),
                            prediction.getContentType(),
                            prediction.getStatusCode(),
                            prediction.getReasonPhrase(),
                            prediction.getHeaders());
                } else {
                    logger.warn(
                            "Drop response for inference request {} due to client timeout",
                            job.getPayload().getRequestId());
                }
                String streamNext =
                        prediction
                                .getHeaders()
                                .get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT);
                if (streamNext != null && streamNext.equals("false")) {
                    // Last response of this stream: evict the job and free its batch slot.
                    jobs.remove(jobId);
                } else if (!job.isOpen()) {
                    // Client went away mid-stream: stop tracking this job.
                    jobs.remove(job.getJobId());
                    logger.info(
                            "Connection to client got closed; removing job: {}",
                            job.getPayload().getRequestId());
                }
            }
        } else {
            for (Map.Entry<String, Job> j : jobs.entrySet()) {
                if (j.getValue() == null) {
                    throw new IllegalStateException(
                            "Unexpected job in sendResponse() with non 200 status code: "
                                    + j.getKey());
                }
                Job job = j.getValue();
                if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
                    job.sendError(message.getCode(), message.getMessage());
                } else {
                    logger.warn(
                            "Drop error response for inference request {} due to client timeout",
                            job.getPayload().getRequestId());
                }
            }
            jobs.clear();
        }

        return true;
    }

    private void pollBatch(String threadName, WorkerState state, int batchSize)
            throws InterruptedException {
        boolean pollMgmtJobStatus = false;
        if (jobs.isEmpty()) {
            // Before the model is loaded, block indefinitely for the load control command;
            // once loaded, only peek at the management queue (timeout 0) so decoding never stalls.
            pollMgmtJobStatus =
                    model.pollMgmtJob(
                            threadName,
                            (state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE,
                            jobs);
        }

        if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
            model.pollInferJob(jobs, batchSize);
        }
    }
}
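The eviction in sendResponse() above keys off the stream-next header the backend attaches to each prediction of a streaming job: intermediate responses carry "true" and the final one "false", at which point the job's slot is freed for a waiting request. A self-contained sketch of just that decision logic follows; the class and method names are hypothetical, and it assumes the RequestInput.TS_STREAM_NEXT constant resolves to the literal "ts_stream_next".

import java.util.HashMap;
import java.util.Map;

public class StreamNextDemo {
    // Mirrors ContinuousBatching.sendResponse(): only an explicit "false" evicts the job.
    static boolean keepJobInBatch(Map<String, String> headers) {
        String streamNext = headers.get("ts_stream_next");
        return streamNext == null || !streamNext.equals("false");
    }

    public static void main(String[] args) {
        Map<String, String> intermediate = new HashMap<>();
        intermediate.put("ts_stream_next", "true");
        Map<String, String> last = new HashMap<>();
        last.put("ts_stream_next", "false");

        System.out.println(keepJobInBatch(intermediate)); // true -> job keeps its batch slot
        System.out.println(keepJobInBatch(last)); // false -> slot freed for the next queued request
    }
}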