Continuous batching for single GPU LLM inference (#2628)
* Make test/pytest/test_handler.py run stand-alone

* Refactor if else statement

* First working poc for streaming inference with continuous batching

* WIP: stopping criteria + caching

* FE: add continuousBatching

* fmt

* Fmt

* Fix continuous batching PoC; remove batchDelay if jobs are being processed

* Add model_config.yaml

* Added ipynb for generate_next_token

* Update notebook and move to right subfolder

* Fix buffer underruns; wait until enough bytes are in for reading

* Add bandaid for bug in our otf

* Added test for otf protocol with context

* Fix buffer underrun; handle batch quota of zero correctly

* Initial implementation of prefill + decode without kv caching for now

* adds missing __init__.py files

* WIP kv caching

* Fixed kv cache; missing tuple;

* Cleaned up streaming handler code

* Added cache cleaning

* clean up aggregator jobs for control cmd

* fmt

* fix streaming handler test

* Rename streaming test into continuous batching test

* fmt

* Enable gpu usage in continuous batching unit test

* Add llama to stream notebook

* skip pull mgmt job if jobs is not empty

* set pollMgmtJobStatus init value as false

* fmt

* only take describe request if jobs repo is empty

* init job

* Remove cont batching job if connection to client gets closed

* Fix and reenable cached request logic

* Fix linter error

* fmt

* remove llama2-13b stream_handler.py

* revert otf

* update maxDelay logic

* replace size checking with isEmpty

* Use handler section

* Fix linter errors

* Fix linter error in otf message handler

* Fix linter error in test_otf_codec_protocol.py

---------

Co-authored-by: lxning <[email protected]>
mreso and lxning committed Oct 4, 2023
1 parent 28d9d99 commit 8d12993
Showing 26 changed files with 2,366 additions and 128 deletions.
@@ -55,6 +55,9 @@ public class ModelConfig {
*/
private boolean useJobTicket;

/** continuousBatching is a flag to enable continuous batching. */
private boolean continuousBatching;

public static ModelConfig build(Map<String, Object> yamlMap) {
ModelConfig modelConfig = new ModelConfig();
yamlMap.forEach(
@@ -158,6 +161,15 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
logger.warn("Invalid useJobTicket: {}, should be true or false", v);
}
break;
case "continuousBatching":
if (v instanceof Boolean) {
modelConfig.setContinuousBatching((boolean) v);
} else {
logger.warn(
"Invalid continuousBatching: {}, should be true or false",
v);
}
break;
default:
break;
}
@@ -313,6 +325,14 @@ public void setUseJobTicket(boolean useJobTicket) {
this.useJobTicket = useJobTicket;
}

public boolean isContinuousBatching() {
return continuousBatching;
}

public void setContinuousBatching(boolean continuousBatching) {
this.continuousBatching = continuousBatching;
}

public enum ParallelType {
NONE(""),
PP("pp"),
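The new continuousBatching flag above is parsed in ModelConfig.build and exposed via isContinuousBatching(). For illustration, a minimal sketch of how the flag would be consumed (the yamlMap ordinarily comes from parsing a model's model_config.yaml; here it is built by hand, and the java.util imports are assumed):

    Map<String, Object> yamlMap = new HashMap<>();
    yamlMap.put("continuousBatching", true); // same effect as a continuousBatching: true entry in model_config.yaml
    ModelConfig modelConfig = ModelConfig.build(yamlMap);
    assert modelConfig.isContinuousBatching();
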
@@ -146,4 +146,10 @@ public void sendError(int status, String error) {
.asRuntimeException());
}
}

@Override
public boolean isOpen() {
return !((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
.isCancelled();
}
}
13 changes: 10 additions & 3 deletions frontend/server/src/main/java/org/pytorch/serve/job/Job.java
@@ -44,9 +44,14 @@ public WorkerCommands getCmd() {
}

public boolean isControlCmd() {
return !WorkerCommands.PREDICT.equals(cmd)
&& !WorkerCommands.STREAMPREDICT.equals(cmd)
&& !WorkerCommands.DESCRIBE.equals(cmd);
switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case DESCRIBE:
return false;
default:
return true;
}
}

public RequestInput getPayload() {
@@ -73,4 +78,6 @@ public abstract void response(
Map<String, String> responseHeaders);

public abstract void sendError(int status, String error);

public abstract boolean isOpen();
}
@@ -4,6 +4,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
@@ -258,4 +259,10 @@ public CompletableFuture<byte[]> getResponsePromise() {
public void setResponsePromise(CompletableFuture<byte[]> responsePromise) {
this.responsePromise = responsePromise;
}

@Override
public boolean isOpen() {
Channel c = ctx.channel();
return c.isOpen();
}
}
@@ -1,7 +1,8 @@
package org.pytorch.serve.util.codec;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
@@ -10,24 +11,27 @@ public final class CodecUtils {

public static final int END = -1;
public static final int BUFFER_UNDER_RUN = -3;
public static final long TIMEOUT_IN_MILLIS = 100;

private CodecUtils() {}

public static int readLength(ByteBuf byteBuf, int maxLength) {
int size = byteBuf.readableBytes();

if (size < 4) {
return BUFFER_UNDER_RUN;
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}

int len = byteBuf.readInt();
if (len > maxLength) {
throw new CorruptedFrameException(
throw new TooLongFrameException(
"Message size exceed limit: "
+ len
+ "\nConsider increasing the 'max_response_size' in 'config.properties' to fix.");
}

if (len > byteBuf.readableBytes()) {
return BUFFER_UNDER_RUN;
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}
return len;
}
@@ -38,7 +42,7 @@ public static String readString(ByteBuf byteBuf, int len) {

public static byte[] read(ByteBuf in, int len) {
if (len < 0) {
throw new CorruptedFrameException("Invalid message size: " + len);
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}

byte[] buf = new byte[len];
@@ -49,9 +53,19 @@ public static byte[] read(ByteBuf in, int len) {
public static Map<String, String> readMap(ByteBuf in, int len) {
HashMap<String, String> ret = new HashMap<>();
for (; len > 0; len--) {
int l = readLength(in, in.readableBytes());
int l =
readLength(
in,
6500000); // We replace len here with 6500000 as a workaround before we
// can fix the whole otf. Basically, we're mixing up bytes
// (expected by readLength) and number of entries (given to
// readMap). If we only have a small number of entries our
// values in the map are not allowed to be very big as we
// compare the given number of entries with the byte size
// we're expecting after reading the length of the next
// message.
String key = readString(in, l);
l = readLength(in, in.readableBytes());
l = readLength(in, 6500000);
String val = readString(in, l);
ret.put(key, val);
}
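A brief usage sketch of the revised readLength/readString contract (Netty's Unpooled, ByteBuf, and java.nio's StandardCharsets are assumed to be imported; the values are illustrative):

    ByteBuf buf = Unpooled.buffer();
    buf.writeInt(5);                                          // length prefix
    buf.writeBytes("hello".getBytes(StandardCharsets.UTF_8)); // payload
    int len = CodecUtils.readLength(buf, 6500000);            // returns 5
    String value = CodecUtils.readString(buf, len);           // "hello"
    // Fewer than 4 readable bytes, or a truncated payload, now raises NotEnoughDataDecoderException
    // instead of returning BUFFER_UNDER_RUN; a length above maxLength raises TooLongFrameException.
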
@@ -76,6 +76,12 @@ private void encodeRequest(RequestInput req, ByteBuf out) {
out.writeInt(buf.length);
out.writeBytes(buf);

if (req.isCached()) {
out.writeInt(-1); // End of List
out.writeInt(-1); // End of List
return;
}

for (Map.Entry<String, String> entry : req.getHeaders().entrySet()) {
encodeField(entry.getKey(), out);
encodeField(entry.getValue(), out);
@@ -86,6 +92,7 @@
encodeParameter(input, out);
}
out.writeInt(-1); // End of List
req.setCached(true);
}

private void encodeParameter(InputParameter parameter, ByteBuf out) {
@@ -3,6 +3,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.util.ArrayList;
import java.util.List;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
@@ -82,6 +83,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
resp.setPredictions(predictions);
out.add(resp);
completed = true;
} catch (NotEnoughDataDecoderException e) {
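// Not enough bytes buffered yet: completed stays false, so the finally block resets the reader index and decoding is retried once more data arrives.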
} finally {
if (!completed) {
in.resetReaderIndex();
@@ -13,6 +13,7 @@ public class RequestInput {
private Map<String, String> headers;
private List<InputParameter> parameters;
private long clientExpireTS;
private boolean cached;

public RequestInput(String requestId) {
this.requestId = requestId;
@@ -71,4 +72,12 @@ public void setClientExpireTS(long clientTimeoutInMills) {
this.clientExpireTS = System.currentTimeMillis() + clientTimeoutInMills;
}
}

public boolean isCached() {
return cached;
}

public void setCached(boolean cached) {
this.cached = cached;
}
}
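Combined with the ModelRequestEncoder change above, the new cached flag gives roughly the following per-request framing (a sketch inferred from the encoder hunk; the bytes written before the isCached() check are assumed to be the request id):

    // First send of a RequestInput (isCached() == false):
    //   [4-byte id length][request id][header key/value fields ...][-1][input parameters ...][-1]
    //   after which the encoder calls req.setCached(true).
    // Later sends of the same request during continuous batching (isCached() == true):
    //   [4-byte id length][request id][-1][-1]
    // so headers and input parameters cross the wire only once per request.
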
@@ -17,8 +17,10 @@ public class BatchAggregator {

private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);

private Model model;
private Map<String, Job> jobs;
protected Model model;
protected Map<String, Job> jobs;

public BatchAggregator() {}

public BatchAggregator(Model model) {
this.model = model;
@@ -171,4 +173,10 @@ public void sendError(BaseModelRequest message, String error, int status) {
}
jobs.clear();
}

public void cleanJobs() {
if (jobs != null) {
jobs.clear();
}
}
}
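The fields become protected and the no-arg constructor plus cleanJobs() are added so that ContinuousBatching below can subclass BatchAggregator. A hypothetical sketch of how a worker could choose the implementation based on the new model-config flag (the actual selection code lives in a file not shown here; model and modelConfig are assumed to be in scope):

    BatchAggregator aggregator =
            modelConfig.isContinuousBatching()
                    ? new ContinuousBatching(model)
                    : new BatchAggregator(model);
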
@@ -0,0 +1,149 @@
package org.pytorch.serve.wlm;

import java.util.Map;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelLoadModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.Predictions;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ContinuousBatching extends BatchAggregator {
private static final Logger logger = LoggerFactory.getLogger(ContinuousBatching.class);

public ContinuousBatching(Model model) {
super(model);
}

public BaseModelRequest getRequest(String threadName, WorkerState state)
throws InterruptedException {
int batchQuota = model.getBatchSize() - jobs.size();

ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());

pollBatch(threadName, state, batchQuota);

if (model.isUseJobTicket() && jobs.isEmpty()) {
model.decNumJobTickets();
return req;
}

for (Job j : jobs.values()) {
if (j.isControlCmd()) {
if (jobs.size() > 1) {
throw new IllegalStateException(
"Received more than 1 control command. "
+ "Control messages should be processed/retrieved one at a time.");
}
RequestInput input = j.getPayload();
int gpuId = -1;
String gpu = input.getStringParameter("gpu");
if (gpu != null) {
gpuId = Integer.parseInt(gpu);
}
return new ModelLoadModelRequest(model, gpuId);
} else {
if (j.getCmd() == WorkerCommands.STREAMPREDICT) {
req.setCommand(WorkerCommands.STREAMPREDICT);
}
j.setScheduled();
req.addRequest(j.getPayload());
}
}
return req;
}

/**
* @param message a response to a batch of inference requests
* @return true if either a non-stream response or the last stream response was sent; false if an
*     intermediate stream response (i.e. not the last one) was sent
*/
public boolean sendResponse(ModelWorkerResponse message) {
// TODO: Handle prediction level code
if (message.getCode() == 200) {
if (message.getPredictions().isEmpty()) {
// The jobs size is always 1 in the case of a control command
for (Map.Entry<String, Job> j : jobs.entrySet()) {
Job job = j.getValue();
if (job.isControlCmd()) {
jobs.clear();
return true;
}
}
}
for (Predictions prediction : message.getPredictions()) {
String jobId = prediction.getRequestId();
Job job = jobs.get(jobId);

if (job == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with 200 status code: " + jobId);
}

if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
job.response(
prediction.getResp(),
prediction.getContentType(),
prediction.getStatusCode(),
prediction.getReasonPhrase(),
prediction.getHeaders());
} else {
logger.warn(
"Drop response for inference request {} due to client timeout",
job.getPayload().getRequestId());
}
String streamNext =
prediction
.getHeaders()
.get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT);
if (streamNext != null && streamNext.equals("false")) {
jobs.remove(jobId);
} else if (!job.isOpen()) {
jobs.remove(job.getJobId());
logger.info(
"Connection to client got closed; Removing job: {}",
job.getPayload().getRequestId());
}
}
} else {
for (Map.Entry<String, Job> j : jobs.entrySet()) {
if (j.getValue() == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with non 200 status code: "
+ j.getKey());
}
Job job = j.getValue();
if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
job.sendError(message.getCode(), message.getMessage());
} else {
logger.warn(
"Drop error response for inference request {} due to client timeout",
job.getPayload().getRequestId());
}
}
jobs.clear();
}

return true;
}

private void pollBatch(String threadName, WorkerState state, int batchSize)
throws InterruptedException {
boolean pollMgmtJobStatus = false;
if (jobs.isEmpty()) {
pollMgmtJobStatus =
model.pollMgmtJob(
threadName,
(state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE,
jobs);
}

if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
model.pollInferJob(jobs, batchSize);
}
}
}
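For orientation, a hypothetical sketch of the cycle a worker would drive with this aggregator (the real loop lives in the worker thread, which is not among the files shown; isRunning, threadName, state, and backendClient are stand-ins):

    ContinuousBatching aggregator = new ContinuousBatching(model);
    while (isRunning) {
        // getRequest() tops the batch back up: the quota is model.getBatchSize() minus jobs still in flight.
        BaseModelRequest request = aggregator.getRequest(threadName, state);
        // Stand-in for the round trip to the backend worker process.
        ModelWorkerResponse reply = backendClient.predict(request);
        // Jobs whose stream finished (TS_STREAM_NEXT header == "false") or whose client disconnected
        // are removed from the jobs map, freeing batch slots; all other jobs stay in the next batch.
        aggregator.sendResponse(reply);
    }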