diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java
index c80235fc74..79cbbd34a4 100644
--- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java
+++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java
@@ -55,6 +55,9 @@ public class ModelConfig {
      */
     private boolean useJobTicket;
 
+    /** continuousBatching is a flag to enable continuous batching. */
+    private boolean continuousBatching;
+
     public static ModelConfig build(Map yamlMap) {
         ModelConfig modelConfig = new ModelConfig();
         yamlMap.forEach(
@@ -158,6 +161,15 @@ public static ModelConfig build(Map yamlMap) {
                                 logger.warn("Invalid useJobTicket: {}, should be true or false", v);
                             }
                             break;
+                        case "continuousBatching":
+                            if (v instanceof Boolean) {
+                                modelConfig.setContinuousBatching((boolean) v);
+                            } else {
+                                logger.warn(
+                                        "Invalid continuousBatching: {}, should be true or false",
+                                        v);
+                            }
+                            break;
                         default:
                             break;
                     }
@@ -313,6 +325,14 @@ public void setUseJobTicket(boolean useJobTicket) {
         this.useJobTicket = useJobTicket;
     }
 
+    public boolean isContinuousBatching() {
+        return continuousBatching;
+    }
+
+    public void setContinuousBatching(boolean continuousBatching) {
+        this.continuousBatching = continuousBatching;
+    }
+
     public enum ParallelType {
         NONE(""),
         PP("pp"),
diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
index 9c4b0d9e56..4cb9b25b2b 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
@@ -146,4 +146,10 @@ public void sendError(int status, String error) {
                             .asRuntimeException());
         }
     }
+
+    @Override
+    public boolean isOpen() {
+        return !((ServerCallStreamObserver) predictionResponseObserver)
+                .isCancelled();
+    }
 }
diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/Job.java b/frontend/server/src/main/java/org/pytorch/serve/job/Job.java
index a17ebff8ba..b7f559bb01 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/job/Job.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/job/Job.java
@@ -44,9 +44,14 @@ public WorkerCommands getCmd() {
     }
 
     public boolean isControlCmd() {
-        return !WorkerCommands.PREDICT.equals(cmd)
-                && !WorkerCommands.STREAMPREDICT.equals(cmd)
-                && !WorkerCommands.DESCRIBE.equals(cmd);
+        switch (cmd) {
+            case PREDICT:
+            case STREAMPREDICT:
+            case DESCRIBE:
+                return false;
+            default:
+                return true;
+        }
     }
 
     public RequestInput getPayload() {
@@ -73,4 +78,6 @@ public abstract void response(
             Map responseHeaders);
 
     public abstract void sendError(int status, String error);
+
+    public abstract boolean isOpen();
 }
diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java b/frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java
index 18e3f11caa..8bc1c1aeb4 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java
@@ -4,6 +4,7 @@
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.codec.http.DefaultFullHttpResponse;
 import io.netty.handler.codec.http.DefaultHttpContent;
@@ -258,4 +259,10 @@ public CompletableFuture getResponsePromise() {
     public void setResponsePromise(CompletableFuture responsePromise) {
        this.responsePromise = responsePromise;
    }
+
+    @Override
+    public boolean isOpen() {
+        Channel c = ctx.channel();
+        return c.isOpen();
+    }
 }
diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/codec/CodecUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/codec/CodecUtils.java
index cfcdadaf3f..491ba704d1 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/util/codec/CodecUtils.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/util/codec/CodecUtils.java
@@ -1,7 +1,8 @@
 package org.pytorch.serve.util.codec;
 
 import io.netty.buffer.ByteBuf;
-import io.netty.handler.codec.CorruptedFrameException;
+import io.netty.handler.codec.TooLongFrameException;
+import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.Map;
@@ -10,24 +11,27 @@ public final class CodecUtils {
 
     public static final int END = -1;
     public static final int BUFFER_UNDER_RUN = -3;
+    public static final long TIMEOUT_IN_MILLIS = 100;
 
     private CodecUtils() {}
 
     public static int readLength(ByteBuf byteBuf, int maxLength) {
         int size = byteBuf.readableBytes();
+
         if (size < 4) {
-            return BUFFER_UNDER_RUN;
+            throw new NotEnoughDataDecoderException("Did not receive enough data.");
         }
         int len = byteBuf.readInt();
         if (len > maxLength) {
-            throw new CorruptedFrameException(
+            throw new TooLongFrameException(
                     "Message size exceed limit: "
                             + len
                             + "\nConsider increasing the 'max_response_size' in 'config.properties' to fix.");
         }
+
         if (len > byteBuf.readableBytes()) {
-            return BUFFER_UNDER_RUN;
+            throw new NotEnoughDataDecoderException("Did not receive enough data.");
         }
         return len;
     }
@@ -38,7 +42,7 @@ public static String readString(ByteBuf byteBuf, int len) {
 
     public static byte[] read(ByteBuf in, int len) {
         if (len < 0) {
-            throw new CorruptedFrameException("Invalid message size: " + len);
+            throw new NotEnoughDataDecoderException("Did not receive enough data.");
         }
 
         byte[] buf = new byte[len];
@@ -49,9 +53,19 @@ public static byte[] read(ByteBuf in, int len) {
     public static Map readMap(ByteBuf in, int len) {
         HashMap ret = new HashMap<>();
         for (; len > 0; len--) {
-            int l = readLength(in, in.readableBytes());
+            int l =
+                    readLength(
+                            in,
+                            6500000); // We pass 6500000 here as a workaround until the whole OTF
+            // protocol is fixed. We are currently mixing up bytes (expected
+            // by readLength) and the number of entries (given to readMap),
+            // so with only a small number of entries the map values would
+            // not be allowed to be very large, because the entry count would
+            // be compared against the byte length read for the next message.
String key = readString(in, l); - l = readLength(in, in.readableBytes()); + l = readLength(in, 6500000); String val = readString(in, l); ret.put(key, val); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java index a69bce79cc..57348de638 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java @@ -76,6 +76,12 @@ private void encodeRequest(RequestInput req, ByteBuf out) { out.writeInt(buf.length); out.writeBytes(buf); + if (req.isCached()) { + out.writeInt(-1); // End of List + out.writeInt(-1); // End of List + return; + } + for (Map.Entry entry : req.getHeaders().entrySet()) { encodeField(entry.getKey(), out); encodeField(entry.getValue(), out); @@ -86,6 +92,7 @@ private void encodeRequest(RequestInput req, ByteBuf out) { encodeParameter(input, out); } out.writeInt(-1); // End of List + req.setCached(true); } private void encodeParameter(InputParameter parameter, ByteBuf out) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelResponseDecoder.java b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelResponseDecoder.java index 5897ede492..ac11b8448c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelResponseDecoder.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelResponseDecoder.java @@ -3,6 +3,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException; import java.util.ArrayList; import java.util.List; import org.pytorch.serve.util.messages.ModelWorkerResponse; @@ -82,6 +83,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { resp.setPredictions(predictions); out.add(resp); completed = true; + } catch (NotEnoughDataDecoderException e) { } finally { if (!completed) { in.resetReaderIndex(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java b/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java index af5dc0f54a..5717908f0f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java @@ -13,6 +13,7 @@ public class RequestInput { private Map headers; private List parameters; private long clientExpireTS; + private boolean cached; public RequestInput(String requestId) { this.requestId = requestId; @@ -71,4 +72,12 @@ public void setClientExpireTS(long clientTimeoutInMills) { this.clientExpireTS = System.currentTimeMillis() + clientTimeoutInMills; } } + + public boolean isCached() { + return cached; + } + + public void setCached(boolean cached) { + this.cached = cached; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java index 0d8d050462..857bd7a8ff 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java @@ -17,8 +17,10 @@ public class BatchAggregator { private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class); - private 
Model model; - private Map jobs; + protected Model model; + protected Map jobs; + + public BatchAggregator() {} public BatchAggregator(Model model) { this.model = model; @@ -171,4 +173,10 @@ public void sendError(BaseModelRequest message, String error, int status) { } jobs.clear(); } + + public void cleanJobs() { + if (jobs != null) { + jobs.clear(); + } + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java new file mode 100644 index 0000000000..1b04521f79 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java @@ -0,0 +1,149 @@ +package org.pytorch.serve.wlm; + +import java.util.Map; +import org.pytorch.serve.job.Job; +import org.pytorch.serve.util.messages.BaseModelRequest; +import org.pytorch.serve.util.messages.ModelInferenceRequest; +import org.pytorch.serve.util.messages.ModelLoadModelRequest; +import org.pytorch.serve.util.messages.ModelWorkerResponse; +import org.pytorch.serve.util.messages.Predictions; +import org.pytorch.serve.util.messages.RequestInput; +import org.pytorch.serve.util.messages.WorkerCommands; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ContinuousBatching extends BatchAggregator { + private static final Logger logger = LoggerFactory.getLogger(ContinuousBatching.class); + + public ContinuousBatching(Model model) { + super(model); + } + + public BaseModelRequest getRequest(String threadName, WorkerState state) + throws InterruptedException { + int batchQuota = model.getBatchSize() - jobs.size(); + + ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName()); + + pollBatch(threadName, state, batchQuota); + + if (model.isUseJobTicket() && jobs.isEmpty()) { + model.decNumJobTickets(); + return req; + } + + for (Job j : jobs.values()) { + if (j.isControlCmd()) { + if (jobs.size() > 1) { + throw new IllegalStateException( + "Received more than 1 control command. 
" + + "Control messages should be processed/retrieved one at a time."); + } + RequestInput input = j.getPayload(); + int gpuId = -1; + String gpu = input.getStringParameter("gpu"); + if (gpu != null) { + gpuId = Integer.parseInt(gpu); + } + return new ModelLoadModelRequest(model, gpuId); + } else { + if (j.getCmd() == WorkerCommands.STREAMPREDICT) { + req.setCommand(WorkerCommands.STREAMPREDICT); + } + j.setScheduled(); + req.addRequest(j.getPayload()); + } + } + return req; + } + + /** + * @param message: a response of a batch inference requests + * @return - true: either a non-stream response or last stream response is sent - false: a + * stream response (not include the last stream) is sent + */ + public boolean sendResponse(ModelWorkerResponse message) { + // TODO: Handle prediction level code + if (message.getCode() == 200) { + if (message.getPredictions().isEmpty()) { + // The jobs size is always 1 in the case control command + for (Map.Entry j : jobs.entrySet()) { + Job job = j.getValue(); + if (job.isControlCmd()) { + jobs.clear(); + return true; + } + } + } + for (Predictions prediction : message.getPredictions()) { + String jobId = prediction.getRequestId(); + Job job = jobs.get(jobId); + + if (job == null) { + throw new IllegalStateException( + "Unexpected job in sendResponse() with 200 status code: " + jobId); + } + + if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) { + job.response( + prediction.getResp(), + prediction.getContentType(), + prediction.getStatusCode(), + prediction.getReasonPhrase(), + prediction.getHeaders()); + } else { + logger.warn( + "Drop response for inference request {} due to client timeout", + job.getPayload().getRequestId()); + } + String streamNext = + prediction + .getHeaders() + .get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT); + if (streamNext != null && streamNext.equals("false")) { + jobs.remove(jobId); + } else if (!job.isOpen()) { + jobs.remove(job.getJobId()); + logger.info( + "Connection to client got closed; Removing job: {}", + job.getPayload().getRequestId()); + } + } + } else { + for (Map.Entry j : jobs.entrySet()) { + if (j.getValue() == null) { + throw new IllegalStateException( + "Unexpected job in sendResponse() with non 200 status code: " + + j.getKey()); + } + Job job = j.getValue(); + if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) { + job.sendError(message.getCode(), message.getMessage()); + } else { + logger.warn( + "Drop error response for inference request {} due to client timeout", + job.getPayload().getRequestId()); + } + } + jobs.clear(); + } + + return true; + } + + private void pollBatch(String threadName, WorkerState state, int batchSize) + throws InterruptedException { + boolean pollMgmtJobStatus = false; + if (jobs.isEmpty()) { + pollMgmtJobStatus = + model.pollMgmtJob( + threadName, + (state == WorkerState.WORKER_MODEL_LOADED) ? 
0 : Long.MAX_VALUE,
+                            jobs);
+        }
+
+        if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
+            model.pollInferJob(jobs, batchSize);
+        }
+    }
+}
diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java
index b8e5fc414b..d7143e7f86 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java
@@ -65,10 +65,12 @@ public class Model {
     private boolean useJobTicket;
     private AtomicInteger numJobTickets;
+    private boolean continuousBatching;
 
     public Model(ModelArchive modelArchive, int queueSize) {
         this.modelArchive = modelArchive;
         if (modelArchive != null && modelArchive.getModelConfig() != null) {
+            continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
             if (modelArchive.getModelConfig().getParallelLevel() > 1
                     && modelArchive.getModelConfig().getParallelType()
                             != ModelConfig.ParallelType.NONE) {
@@ -245,6 +247,95 @@ public void addFirst(Job job) {
         jobsDb.get(DEFAULT_DATA_QUEUE).addFirst(job);
     }
 
+    public boolean pollMgmtJob(String threadId, long waitTime, Map jobsRepo)
+            throws InterruptedException {
+        if (jobsRepo == null || threadId == null || threadId.isEmpty()) {
+            throw new IllegalArgumentException("Invalid input provided");
+        }
+
+        if (!jobsRepo.isEmpty()) {
+            throw new IllegalArgumentException(
+                    "The jobs repo provided contains stale jobs. Clear them!");
+        }
+
+        LinkedBlockingDeque jobsQueue = jobsDb.get(threadId);
+        if (jobsQueue != null && !jobsQueue.isEmpty()) {
+            Job j = jobsQueue.poll(waitTime, TimeUnit.MILLISECONDS);
+            if (j != null) {
+                jobsRepo.put(j.getJobId(), j);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    public void pollInferJob(Map jobsRepo, int batchSize) throws InterruptedException {
+        LinkedBlockingDeque jobsQueue;
+        try {
+            if (isUseJobTicket()) {
+                incNumJobTickets();
+            }
+            lock.lockInterruptibly();
+            long maxDelay = maxBatchDelay;
+            boolean pollNoWait = jobsRepo.isEmpty() ?
false : true; + jobsQueue = jobsDb.get(DEFAULT_DATA_QUEUE); + + Job j = null; + if (jobsRepo.isEmpty()) { + j = jobsQueue.poll(Long.MAX_VALUE, TimeUnit.MILLISECONDS); + logger.trace("get first job: {}", Objects.requireNonNull(j).getJobId()); + + jobsRepo.put(j.getJobId(), j); + // batch size always is 1 for describe request job + if (j.getCmd() == WorkerCommands.DESCRIBE) { + if (jobsRepo.isEmpty()) { + jobsRepo.put(j.getJobId(), j); + return; + } else { + jobsQueue.addFirst(j); + return; + } + } + } + + long begin = System.currentTimeMillis(); + for (int i = 0; i < batchSize - 1; ++i) { + if (pollNoWait) { + j = jobsQueue.poll(); + } else { + j = jobsQueue.poll(maxDelay, TimeUnit.MILLISECONDS); + } + if (j == null) { + break; + } + long end = System.currentTimeMillis(); + // job batch size always is 1 when request is describe prediction + if (j.getCmd() == WorkerCommands.DESCRIBE) { + // Add the job back into the jobsQueue + jobsQueue.addFirst(j); + break; + } + maxDelay -= end - begin; + begin = end; + if (j.getPayload().getClientExpireTS() > System.currentTimeMillis()) { + jobsRepo.put(j.getJobId(), j); + } else { + logger.warn( + "Drop inference request {} due to client timeout", + j.getPayload().getRequestId()); + } + if (maxDelay <= 0) { + break; + } + } + logger.trace("sending jobs, size: {}", jobsRepo.size()); + } finally { + if (lock.isHeldByCurrentThread()) { + lock.unlock(); + } + } + } + public void pollBatch(String threadId, long waitTime, Map jobsRepo) throws InterruptedException { if (jobsRepo == null || threadId == null || threadId.isEmpty()) { @@ -420,4 +511,8 @@ public int getPendingRequestsInJobQueue() { return 0; } + + public boolean isContinuousBatching() { + return continuousBatching; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index d944e9592d..a64f4ad5b3 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -228,7 +228,12 @@ private void addThreads( } } - BatchAggregator aggregator = new BatchAggregator(model); + BatchAggregator aggregator; + if (model.isContinuousBatching()) { + aggregator = new ContinuousBatching(model); + } else { + aggregator = new BatchAggregator(model); + } int currentPort = model.getParallelLevel() > 1 ? configManager.isDebug() diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index 27af027c99..8eff8346ab 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -182,6 +182,8 @@ public void run() { currentThread.set(thread); BaseModelRequest req = null; int status = HttpURLConnection.HTTP_INTERNAL_ERROR; + // in case of retry + aggregator.cleanJobs(); try { connect(); @@ -205,8 +207,6 @@ public void run() { backendChannel.get(i).writeAndFlush(req).sync(); } - boolean isStreaming = - req.getCommand() == WorkerCommands.STREAMPREDICT ? 
true : false; ModelWorkerResponse reply = null; boolean jobDone = false; diff --git a/test/pytest/conftest.py b/test/pytest/conftest.py index 6b16b5a6e8..99abf212ed 100644 --- a/test/pytest/conftest.py +++ b/test/pytest/conftest.py @@ -54,11 +54,11 @@ def model_store(tmp_path_factory): def torchserve(model_store): test_utils.torchserve_cleanup() - test_utils.start_torchserve( + pipe = test_utils.start_torchserve( model_store=model_store, no_config_snapshots=True, gen_mar=False ) - yield + yield pipe test_utils.torchserve_cleanup() diff --git a/test/pytest/test_continuous_batching.py b/test/pytest/test_continuous_batching.py new file mode 100644 index 0000000000..2d6974510f --- /dev/null +++ b/test/pytest/test_continuous_batching.py @@ -0,0 +1,250 @@ +import json +import shutil +from argparse import Namespace +from pathlib import Path +from queue import Empty +from unittest.mock import MagicMock, patch +from zipfile import ZIP_STORED, ZipFile + +import pytest +import requests +import test_utils +import torch +from test_data.streaming.stream_handler import StreamingHandler + +from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext + +CURR_FILE_PATH = Path(__file__).parent + + +@pytest.fixture(scope="module") +def model_name(): + yield "streaming_handler" + + +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return tmp_path_factory.mktemp(model_name) + + +@pytest.fixture(scope="module", name="mar_file_path") +def create_mar_file(work_dir, model_archiver, model_name): + mar_file_path = Path(work_dir).joinpath(model_name + ".mar") + + args = Namespace( + model_name=model_name, + version="1.0", + model_file=CURR_FILE_PATH.joinpath( + "test_data", "streaming", "fake_streaming_model.py" + ).as_posix(), + handler=CURR_FILE_PATH.joinpath( + "test_data", "streaming", "stream_handler.py" + ).as_posix(), + serialized_file=None, + export_path=work_dir, + requirements_file=None, + runtime="python", + force=False, + archive_format="default", + config_file=CURR_FILE_PATH.joinpath( + "test_data", "streaming", "model_config.yaml" + ).as_posix(), + extra_files=None, + ) + + mock = MagicMock() + mock.parse_args = MagicMock(return_value=args) + with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs + with patch( + "model_archiver.model_packaging_utils.zipfile.ZipFile", + lambda x, y, _: ZipFile(x, y, ZIP_STORED), + ): + model_archiver.generate_model_archive() + + assert mar_file_path.exists() + + yield mar_file_path.as_posix() + + # Clean up files + # mar_file_path.unlink(missing_ok=True) + + +@pytest.fixture(scope="module", name="model_name_and_stdout") +def register_model(mar_file_path, model_store, torchserve): + """ + Register the model in torchserve + """ + shutil.copy(mar_file_path, model_store) + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + params = ( + ("model_name", model_name), + ("url", file_name), + ("initial_workers", "1"), + ("synchronous", "true"), + ("batch_size", "2"), + ) + + test_utils.reg_resp = test_utils.register_model_with_params(params) + + yield model_name, torchserve + + test_utils.unregister_model(model_name) + + +def test_echo_stream_inference(model_name_and_stdout): + model_name, _ = model_name_and_stdout + responses = [] + data = [ + { + "prompt": "The capital of France", + "max_new_tokens": 5, + }, + { + "prompt": "Europe is", + "max_new_tokens": 10, + }, + { + "prompt": "The US are", 
+ "max_new_tokens": 15, + }, + { + "prompt": "When travelling to NYC", + "max_new_tokens": 5, + }, + ] + for d in data: + res = requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + data=json.dumps(d), + stream=True, + ) + + responses.append(res) + assert all(r.headers["Transfer-Encoding"] == "chunked" for r in responses) + + all_predictions = [] + for idx, d in enumerate(data): + prediction = [] + for chunk in responses[idx].iter_content(chunk_size=None): + if chunk: + prediction.append(chunk.decode("utf-8")) + + all_predictions.append("".join(json.loads(p)["text"] for p in prediction)) + + assert all_predictions[0] == "The capital of France, Paris, is home" + assert ( + all_predictions[1] == "Europe is a country of immigrants, and it is a country" + ) + assert ( + all_predictions[2] + == "The US are not going to be able to do that. They're going to have to" + ) + assert all_predictions[3] == "When travelling to NYC, I was able to" + + +def test_decoding_stage(monkeypatch): + monkeypatch.syspath_prepend((CURR_FILE_PATH / "test_data" / "streaming")) + + handler = StreamingHandler() + ctx = MockContext( + model_pt_file=None, + model_dir=(CURR_FILE_PATH / "test_data" / "streaming").as_posix(), + model_file="fake_streaming_model.py", + ) + ctx.model_yaml_config["handler"] = {"modelId": "gpt2"} + + torch.manual_seed(42 * 42) + handler.initialize(ctx) + + handler.context = ctx + + device = next(iter(handler.model.parameters())).device + + ctx.cache = { + "id1": { + "encoded": { + "input_ids": torch.randint(42, (1, 5), device=device), + "attention_mask": torch.ones((1, 5), dtype=int, device=device), + "past_key_values": None, + }, + }, + "id2": { + "encoded": { + "input_ids": torch.randint(42, (1, 8), device=device), + "attention_mask": torch.ones((1, 8), dtype=int, device=device), + "past_key_values": None, + } + }, + } + ctx.cache["id1"]["encoded"]["attention_mask"][0, :2] = 0 + + res = handler._run_prefill("id1") + res = handler._run_prefill("id2") + + res = handler._run_decode(["id1"]) + + assert len(res["id1"]["ids"]) == len(res["id1"]["text"]) == 1 + # assert res["id1"]["ids"][0] == 62 + + assert ctx.cache["id1"]["encoded"]["input_ids"].size()[-1] == 5 + assert ctx.cache["id1"]["encoded"]["attention_mask"].size()[-1] == 5 + + res = handler._run_decode(["id1", "id2"]) + assert ctx.cache["id1"]["encoded"]["input_ids"].size()[-1] == 10 + assert ctx.cache["id1"]["encoded"]["attention_mask"].size()[-1] == 10 + + assert ctx.cache["id2"]["encoded"]["input_ids"].size()[-1] == 10 + assert ctx.cache["id2"]["encoded"]["attention_mask"].size()[-1] == 10 + + res = handler._run_decode(["id1"]) + assert ctx.cache["id1"]["encoded"]["input_ids"].size()[-1] == 7 + assert ctx.cache["id1"]["encoded"]["attention_mask"].size()[-1] == 7 + + res = handler._run_decode(["id1", "id2"]) + assert ctx.cache["id1"]["encoded"]["input_ids"].size()[-1] == 11 + assert ctx.cache["id1"]["encoded"]["attention_mask"].size()[-1] == 11 + + assert ctx.cache["id2"]["encoded"]["input_ids"].size()[-1] == 11 + assert ctx.cache["id2"]["encoded"]["attention_mask"].size()[-1] == 11 + + +def test_closed_connection(model_name_and_stdout): + model_name, stdout = model_name_and_stdout + + # Empty queue + while not stdout.empty(): + stdout.get_nowait() + + data = { + "prompt": "The capital of France", + "max_new_tokens": 500, + } + + with requests.Session() as s: + res = s.post( + url=f"http://localhost:8080/predictions/{model_name}", + data=json.dumps(data), + stream=True, + ) + + for chunk in 
res.iter_content(chunk_size=None): + # Close connection after the first id has been received + break + + lines = [] + while True: + try: + lines.append(stdout.get(timeout=5)) + except Empty: + assert 0, "Queue timed out" + + if "Connection to client got closed; Removing job:" in lines[-1]: + break + + # We expect the model to only run two times at most due to the closed connection + assert len(list(filter(lambda x: "Backend received inference at" in x, lines))) <= 2 diff --git a/test/pytest/test_data/__init__.py b/test/pytest/test_data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/pytest/test_data/streaming/__init__.py b/test/pytest/test_data/streaming/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/pytest/test_data/streaming/fake_streaming_model.py b/test/pytest/test_data/streaming/fake_streaming_model.py new file mode 100644 index 0000000000..676a0c44e4 --- /dev/null +++ b/test/pytest/test_data/streaming/fake_streaming_model.py @@ -0,0 +1,9 @@ +import torch + + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x diff --git a/test/pytest/test_data/streaming/model_config.yaml b/test/pytest/test_data/streaming/model_config.yaml new file mode 100644 index 0000000000..cec39fb88e --- /dev/null +++ b/test/pytest/test_data/streaming/model_config.yaml @@ -0,0 +1,3 @@ +continuousBatching: true +handler: + modelId: gpt2 diff --git a/test/pytest/test_data/streaming/stream.ipynb b/test/pytest/test_data/streaming/stream.ipynb new file mode 100644 index 0000000000..4532aa7e48 --- /dev/null +++ b/test/pytest/test_data/streaming/stream.ipynb @@ -0,0 +1,1220 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/miniconda3/envs/serve/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from transformers import GPT2LMHeadModel, AutoTokenizer\n", + "\n", + "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n", + "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inference time: 1.781\n", + "tensor([ 464, 3139, 286, 4881, 220, 1849, 271, 262, 3139, 286, 262, 4141,\n", + " 2066, 13, 383, 4141, 2066, 318, 257, 1181, 286, 262, 1242, 2422,\n", + " 290, 3034, 1080, 13, 383, 4141, 2066, 318, 257, 1181, 286, 262,\n", + " 1242, 2422, 290, 3034, 1080, 13, 383, 4141, 2066, 318, 257, 1181,\n", + " 286, 262, 1242, 2422, 290, 3034, 1080])\n" + ] + } + ], + "source": [ + "encoded = tokenizer(\"The capital of France \", return_tensors=\"pt\")\n", + "import time\n", + "st = time.perf_counter()\n", + "generate_output = model.generate(**encoded, use_cache=True, return_dict_in_generate=True, max_new_tokens=50)\n", + "print(f\"Inference time: {time.perf_counter()-st:.3f}\")\n", + "print(generate_output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 464, 3139, 286, 4881, 220]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}\n", + "tensor([ 464, 3139, 286, 4881, 220, 1849])\n" + ] + } + ], + "source": [ + "model_config={\n", + " \"use_cache\":True,\n", + " \"return_dict_in_generate\":True,\n", + " \"max_new_tokens\":1,\n", + "}\n", + "print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "old_update = model._update_model_kwargs_for_generation\n", + "extracted = {}\n", + "import types\n", + "def new_func(self,*args, **kwargs):\n", + " extracted[\"past_key_values\"] = args[0][\"past_key_values\"]\n", + " return old_update(*args, **kwargs)\n", + "\n", + "model._update_model_kwargs_for_generation = types.MethodType(new_func, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 5, 64]\n", + "tensor([ 464, 3139, 286, 4881, 220, 1849])\n" + ] + } + ], + "source": [ + "\n", + "output = model.generate(**encoded, **model_config)\n", + "print([len(extracted[\"past_key_values\"]), len(extracted[\"past_key_values\"][0])] + list(extracted[\"past_key_values\"][0][0].size()))\n", + "print(output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 
for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 464, 3139, 286, 4881, 220, 1849, 271])\n" + ] + } + ], + "source": [ + "import torch\n", + "encoded = {\n", + " \"input_ids\": output.sequences,\n", + " \"attention_mask\": torch.concat((encoded[\"attention_mask\"], torch.ones((1,1), dtype=torch.int64)), dim=1),\n", + " \"past_key_values\": extracted[\"past_key_values\"],\n", + "}\n", + "# print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 5, 64]\n", + "[12, 2, 1, 12, 6, 64]\n", + "[12, 2, 1, 12, 7, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 8, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 9, 64]\n", + "[12, 2, 1, 12, 10, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 11, 64]\n", + "[12, 2, 1, 12, 12, 64]\n", + "[12, 2, 1, 12, 13, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 14, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 15, 64]\n", + "[12, 2, 1, 12, 16, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 17, 64]\n", + "[12, 2, 1, 12, 18, 64]\n", + "[12, 2, 1, 12, 19, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "[12, 2, 1, 12, 20, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 21, 64]\n", + "[12, 2, 1, 12, 22, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 23, 64]\n", + "[12, 2, 1, 12, 24, 64]\n", + "[12, 2, 1, 12, 25, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 26, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 27, 64]\n", + "[12, 2, 1, 12, 28, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 29, 64]\n", + "[12, 2, 1, 12, 30, 64]\n", + "[12, 2, 1, 12, 31, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 32, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 33, 64]\n", + "[12, 2, 1, 12, 34, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 35, 64]\n", + "[12, 2, 1, 12, 36, 64]\n", + "[12, 2, 1, 12, 37, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 38, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "[12, 2, 1, 12, 39, 64]\n", + "[12, 2, 1, 12, 40, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 41, 64]\n", + "[12, 2, 1, 12, 42, 64]\n", + "[12, 2, 1, 12, 43, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 44, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 45, 64]\n", + "[12, 2, 1, 12, 46, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 47, 64]\n", + "[12, 2, 1, 12, 48, 64]\n", + "[12, 2, 1, 12, 49, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 50, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 51, 64]\n", + "[12, 2, 1, 12, 52, 64]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 53, 64]\n", + "[12, 2, 1, 12, 54, 64]\n", + "Inference time: 1.860\n", + "tensor([ 464, 3139, 286, 4881, 220, 1849, 271, 262, 3139, 286, 262, 4141,\n", + " 2066, 13, 383, 4141, 2066, 318, 257, 1181, 286, 262, 1242, 2422,\n", + " 290, 3034, 1080, 13, 383, 4141, 2066, 318, 257, 1181, 286, 262,\n", + " 1242, 2422, 290, 3034, 1080, 13, 383, 4141, 2066, 318, 257, 1181,\n", + " 286, 262, 1242, 2422, 290, 3034, 1080])\n" + ] + } + ], + "source": [ + "encoded = tokenizer(\"The capital of France \", return_tensors=\"pt\")\n", + "st = time.perf_counter()\n", + "for _ in range(50):\n", + " output = model.generate(**encoded, **model_config)\n", + " encoded = {\n", + " \"input_ids\": output.sequences,\n", + " \"attention_mask\": torch.concat((encoded[\"attention_mask\"], torch.ones((1,1), dtype=torch.int64)), dim=1),\n", + " \"past_key_values\": extracted[\"past_key_values\"],\n", + " }\n", + " print([len(extracted[\"past_key_values\"]), len(extracted[\"past_key_values\"][0])] + list(extracted[\"past_key_values\"][0][0].size()))\n", + "print(f\"Inference 
time: {time.perf_counter()-st:.3f}\")\n", + "print(output.sequences[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "assert all(generate_output.sequences[0] == output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 464, 3139, 286, 4881, 318, 220],\n", + " [50256, 32423, 49696, 457, 38863, 18042]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1],\n", + " [0, 1, 1, 1, 1, 1]])}\n" + ] + } + ], + "source": [ + "tokenizer.padding_side=\"left\"\n", + "encoded = tokenizer([\"The capital of France is \", \"Die Hauptstadt von\"], return_tensors=\"pt\", padding=\"longest\")\n", + "print(encoded)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 2, 12, 6, 64]\n", + "tensor([ 464, 3139, 286, 4881, 318, 220, 1849])\n", + "tensor([50256, 32423, 49696, 457, 38863, 18042, 509])\n" + ] + } + ], + "source": [ + "output = model.generate(**encoded, **model_config)\n", + "print([len(extracted[\"past_key_values\"]), len(extracted[\"past_key_values\"][0])] + list(extracted[\"past_key_values\"][0][0].size()))\n", + "print(output.sequences[0])\n", + "print(output.sequences[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "padded_kv_cache = copy.deepcopy(extracted[\"past_key_values\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([32423, 49696, 457, 38863, 18042, 509])\n" + ] + } + ], + "source": [ + "encoded = tokenizer([\"Die Hauptstadt von\"], return_tensors=\"pt\")\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 1, 12, 5, 64]\n" + ] + } + ], + "source": [ + "def print_kv_dims(kv):\n", + " print([len(kv), len(kv[0])] + list(kv[0][0].size()))\n", + "print_kv_dims(extracted[\"past_key_values\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12, 2, 2, 12, 6, 64]\n" + ] + } + ], + "source": [ + "\n", + "print_kv_dims(padded_kv_cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 
0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-4.7684e-07, 1.1921e-06, 2.3842e-07, 1.1921e-07, 5.3644e-07,\n", + " -2.3842e-07, 1.1921e-06, 0.0000e+00, -5.9605e-07, 8.9407e-08,\n", + " -8.9407e-08, 1.1921e-07, -6.7055e-08, 3.5763e-07, 4.7684e-07,\n", + " 2.3842e-07, -7.1526e-07, 2.9802e-07, 4.7684e-07, -4.7684e-07,\n", + " 4.7684e-07, -4.4703e-08, -5.9605e-07, 1.1921e-07, 0.0000e+00,\n", + " 1.7881e-07, -1.7881e-07, 
-7.7486e-07, 2.3842e-07, 1.1921e-07,\n", + " 7.1526e-07, -1.1921e-07, -7.1526e-07, 5.9605e-08, 5.3644e-07,\n", + " 2.3842e-07, 4.7684e-07, -5.9605e-07, -3.5763e-07, 5.9605e-08,\n", + " 4.1723e-07, 3.5763e-07, 1.1921e-06, -2.3842e-07, 8.9407e-07,\n", + " 9.5367e-07, -4.1723e-07, 2.3842e-07, 7.4506e-08, -4.7684e-07,\n", + " 2.0862e-07, -5.9605e-08, 1.1921e-07, -3.5763e-07, -1.7881e-07,\n", + " 9.5367e-07, -8.3074e-07, -3.3528e-08, -2.9802e-07, -3.5763e-07,\n", + " 1.1921e-07, 1.7881e-07, -2.6822e-07, -9.5367e-07]])\n" + ] + } + ], + "source": [ + "print(extracted[\"past_key_values\"][0][0][0,0,...] - padded_kv_cache[0][0][1,0,1:,:])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-4.7684e-07, 1.1921e-06, 2.3842e-07, 1.1921e-07, 5.3644e-07,\n", + " -2.3842e-07, 1.1921e-06, 0.0000e+00, -5.9605e-07, 8.9407e-08,\n", + " -8.9407e-08, 1.1921e-07, -6.7055e-08, 3.5763e-07, 4.7684e-07,\n", + " 2.3842e-07, -7.1526e-07, 2.9802e-07, 4.7684e-07, -4.7684e-07,\n", + " 4.7684e-07, -4.4703e-08, -5.9605e-07, 1.1921e-07, 0.0000e+00,\n", + " 1.7881e-07, -1.7881e-07, -7.7486e-07, 2.3842e-07, 1.1921e-07,\n", + " 7.1526e-07, -1.1921e-07, -7.1526e-07, 5.9605e-08, 5.3644e-07,\n", + " 2.3842e-07, 4.7684e-07, -5.9605e-07, -3.5763e-07, 5.9605e-08,\n", + " 4.1723e-07, 3.5763e-07, 1.1921e-06, -2.3842e-07, 8.9407e-07,\n", + " 9.5367e-07, -4.1723e-07, 2.3842e-07, 7.4506e-08, -4.7684e-07,\n", + " 2.0862e-07, -5.9605e-08, 1.1921e-07, -3.5763e-07, -1.7881e-07,\n", + " 9.5367e-07, -8.3074e-07, -3.3528e-08, -2.9802e-07, -3.5763e-07,\n", + " 1.1921e-07, 1.7881e-07, -2.6822e-07, -9.5367e-07])\n", + "tensor([-1.9769, 2.8057, 1.7984, 1.7875, 0.5844, 2.1871, 1.4393, -0.5568,\n", + " -0.9254, -0.3672, 0.2673, 1.1119, -0.0763, 1.2123, -1.3547, 0.5947,\n", + " 0.7469, -0.6633, 1.7078, -0.8085, -1.6846, 0.0351, -1.0112, -0.9357,\n", + " 0.3067, -0.8318, -0.5093, -0.7956, -0.5246, -1.0272, 0.7018, -0.6455,\n", + " -2.2052, 0.5388, 0.8386, -0.5252, 1.3803, 1.6268, 1.2225, -0.5823,\n", + " 0.5009, 1.0283, -1.7727, 1.0943, 0.5510, -2.3148, 0.8457, -1.1288,\n", + " 0.0967, 1.5846, 0.4326, -0.2651, 1.6881, 0.5560, -0.3775, 2.8351,\n", + " -0.0109, 0.0034, 0.8708, 0.7571, -1.3306, -0.8162, 0.2832, 2.1278])\n", + "tensor([-1.9769, 2.8057, 1.7984, 1.7875, 0.5844, 2.1871, 1.4393, -0.5568,\n", + " -0.9254, -0.3672, 0.2673, 1.1119, -0.0763, 1.2123, -1.3547, 0.5947,\n", + " 0.7469, -0.6633, 1.7078, -0.8085, -1.6846, 0.0351, -1.0112, -0.9357,\n", + " 0.3067, -0.8318, -0.5093, -0.7956, -0.5246, -1.0272, 0.7018, -0.6455,\n", + " -2.2052, 0.5388, 0.8386, -0.5252, 1.3803, 1.6268, 1.2225, -0.5823,\n", + " 0.5009, 1.0283, -1.7727, 1.0943, 0.5510, -2.3148, 0.8457, -1.1288,\n", + " 0.0967, 1.5846, 0.4326, -0.2651, 1.6881, 0.5560, -0.3775, 2.8351,\n", + " -0.0109, 0.0034, 0.8708, 0.7571, -1.3306, -0.8162, 0.2832, 2.1278])\n" + ] + } + ], + "source": [ + "print(extracted[\"past_key_values\"][0][0][0,0,-1] - padded_kv_cache[0][0][1,0,-1,:])\n", + "print(extracted[\"past_key_values\"][0][0][0,0,-1])\n", + "print(padded_kv_cache[0][0][1,0,-1,:])\n", + "# Could the difference be the leakiness of the attention mask in the attention block? 
mask is not binary but 1 and float32.min\n", + "# see https://github.com/huggingface/transformers/blob/c3ecf2d95d6a9f614d968af2f8b4e317f381e5ec/src/transformers/models/gpt2/modeling_gpt2.py#L823C82-L823C82" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 464, 3139, 286, 4881],\n", + " [ 464, 3139, 286, 4881]]), 'attention_mask': tensor([[1, 1, 1, 1],\n", + " [1, 1, 1, 1]])}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 464, 3139, 286, 4881, 11, 6342, 11, 318, 1363])\n", + "tensor([ 464, 3139, 286, 4881, 11, 6342, 11, 318, 1363])\n", + "{'input_ids': tensor([[50256, 50256, 464, 3139, 286, 4881]]), 'attention_mask': tensor([[0, 0, 1, 1, 1, 1]])}\n", + "tensor([50256, 50256, 464, 3139, 286, 4881, 11, 6342, 11, 318,\n", + " 1363])\n", + "The capital of France, Paris, is home\n" + ] + } + ], + "source": [ + "model_config[\"max_new_tokens\"]=5\n", + "encoded = tokenizer([\"The capital of France\", \"The capital of France\"], return_tensors=\"pt\")\n", + "print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])\n", + "print(output.sequences[1])\n", + "\n", + "encoded = tokenizer([\"The capital of France\"], return_tensors=\"pt\", max_length=6, padding='max_length', truncation=True)\n", + "print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])\n", + "print(tokenizer.decode(output.sequences[0],skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inference time: 0.414\n", + "tensor([ 464, 3139, 286, 4881, 11, 6342, 11, 318, 1363])\n", + "tensor([ 464, 3139, 286, 4881, 11, 6342, 11, 318, 1363])\n" + ] + } + ], + "source": [ + "model_config[\"max_new_tokens\"]=1\n", + "encoded = tokenizer([\"The capital of France\", \"The capital of France\"], return_tensors=\"pt\")\n", + "st = time.perf_counter()\n", + "for _ in range(5):\n", + " output = model.generate(**encoded, **model_config)\n", + " encoded = {\n", + " \"input_ids\": output.sequences,\n", + " \"attention_mask\": torch.concat((encoded[\"attention_mask\"], torch.ones((2,1), dtype=torch.int64)), dim=1),\n", + " \"past_key_values\": extracted[\"past_key_values\"],\n", + " }\n", + " # print([len(extracted[\"past_key_values\"]), len(extracted[\"past_key_values\"][0])] + list(extracted[\"past_key_values\"][0][0].size()))\n", + "print(f\"Inference time: {time.perf_counter()-st:.3f}\")\n", + "print(output.sequences[0])\n", + "print(output.sequences[1])" + ] + }, + { + "cell_type": "code", + 
"execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[50256, 50256, 50256, 50256, 50256, 464, 3139, 286, 4881],\n", + " [ 464, 3139, 286, 4881, 11, 6342, 11, 318, 1363]]), 'attention_mask': tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([50256, 50256, 50256, 50256, 50256, 464, 3139, 286, 4881, 11,\n", + " 6342, 11, 318, 1363])\n", + "tensor([ 464, 3139, 286, 4881, 11, 6342, 11, 318, 1363, 284, 262, 995,\n", + " 338, 4387])\n", + "{'input_ids': tensor([[50256, 50256, 464, 3139, 286, 4881]]), 'attention_mask': tensor([[0, 0, 1, 1, 1, 1]])}\n", + "tensor([50256, 50256, 464, 3139, 286, 4881, 11, 6342, 11, 318,\n", + " 1363])\n", + "The capital of France, Paris, is home\n" + ] + } + ], + "source": [ + "model_config[\"max_new_tokens\"]=5\n", + "encoded = tokenizer([\"The capital of France\", \"The capital of France, Paris, is home\"], return_tensors=\"pt\", padding=True)\n", + "print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])\n", + "print(output.sequences[1])\n", + "\n", + "encoded = tokenizer([\"The capital of France\"], return_tensors=\"pt\", max_length=6, padding='max_length', truncation=True)\n", + "print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])\n", + "print(tokenizer.decode(output.sequences[0],skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "padded_kv_cache = copy.deepcopy(extracted[\"past_key_values\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/miniconda3/envs/serve/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/ubuntu/miniconda3/envs/serve/lib/python3.10/site-packages/accelerate/utils/imports.py:245: UserWarning: Intel Extension for PyTorch 2.0 needs to work with PyTorch 2.0.*, but PyTorch 2.2.0.dev20230922+cu118 is found. Please switch to the matching version and run again.\n", + " warnings.warn(\n", + "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.09it/s]\n", + "/home/ubuntu/miniconda3/envs/serve/lib/python3.10/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n", + " warnings.warn(\n", + "/home/ubuntu/miniconda3/envs/serve/lib/python3.10/site-packages/transformers/generation/utils.py:1452: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. 
You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 1, 450, 7483, 310, 3444]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}\n", + "tensor([ 1, 450, 7483, 310, 3444, 29892, 3681, 29892, 338, 263,\n", + " 4272, 310, 6017, 749, 29892, 1616, 29892, 322, 9257, 29889])\n", + "The capital of France, Paris, is a city of romance, art, and culture.\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " 'meta-llama/Llama-2-7b-hf',\n", + " device_map=\"balanced\",\n", + " low_cpu_mem_usage=True,\n", + " torch_dtype=torch.float16,\n", + " load_in_8bit=True,\n", + " )\n", + "\n", + "encoded = tokenizer([\"The capital of France\"], return_tensors=\"pt\", return_token_type_ids=False)\n", + "print(encoded)\n", + "output = model.generate(**encoded, use_cache=True, return_dict_in_generate=True)\n", + "print(output.sequences[0])\n", + "print(tokenizer.decode(output.sequences[0],skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "old_update = model._update_model_kwargs_for_generation\n", + "extracted = {}\n", + "import types\n", + "def new_func(self,*args, **kwargs):\n", + " extracted[\"past_key_values\"] = args[0][\"past_key_values\"]\n", + " return old_update(*args, **kwargs)\n", + "\n", + "model._update_model_kwargs_for_generation = types.MethodType(new_func, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[32, 2, 1, 32, 5, 128]\n", + "tensor([ 1, 450, 7483, 310, 3444, 29892])\n" + ] + } + ], + "source": [ + "model_config={\n", + " \"use_cache\":True,\n", + " \"return_dict_in_generate\":True,\n", + " \"max_new_tokens\":1,\n", + "}\n", + "output = model.generate(**encoded, **model_config)\n", + "print([len(extracted[\"past_key_values\"]), len(extracted[\"past_key_values\"][0])] + list(extracted[\"past_key_values\"][0][0].size()))\n", + "print(output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 1, 450, 7483, 310, 3444, 29892, 3681])\n" + ] + } + ], + "source": [ + "import torch\n", + "encoded = {\n", + " \"input_ids\": output.sequences,\n", + " \"attention_mask\": torch.concat((encoded[\"attention_mask\"], torch.ones((1,1), dtype=torch.int64)), dim=1),\n", + " \"past_key_values\": extracted[\"past_key_values\"],\n", + "}\n", + "# print(encoded)\n", + "output = model.generate(**encoded, **model_config)\n", + "print(output.sequences[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inference time: 7.655\n", + "tensor([ 1, 450, 7483, 310, 3444, 29892, 3681, 29892, 338, 263,\n", + " 4272, 310, 6017, 749, 29892, 1616, 29892, 322, 9257, 29889,\n", + " 739, 338, 884, 263, 4272, 310, 13460, 29892, 9687, 29892,\n", + " 322, 2090, 
29889, 3681, 338, 263, 4272, 393, 756, 1554,\n", + " 363, 14332, 29889, 26460, 366, 526, 3063, 363, 263, 6017,\n", + " 7716, 679, 21694, 29892, 263])\n", + "The capital of France, Paris, is a city of romance, art, and culture. It is also a city of fashion, food, and fun. Paris is a city that has something for everyone. Whether you are looking for a romantic getaway, a\n" + ] + } + ], + "source": [ + "import time\n", + "encoded = tokenizer(\"The capital of France\", return_tensors=\"pt\", return_token_type_ids=False)\n", + "st = time.perf_counter()\n", + "for _ in range(50):\n", + " output = model.generate(**encoded, **model_config)\n", + " if output.sequences[0][-1] == tokenizer.eos_token_id:\n", + " break\n", + " encoded = {\n", + " \"input_ids\": output.sequences,\n", + " \"attention_mask\": torch.concat((encoded[\"attention_mask\"], torch.ones((1,1), dtype=torch.int64)), dim=1),\n", + " \"past_key_values\": extracted[\"past_key_values\"],\n", + " }\n", + " # print([len(extracted[\"past_key_values\"]), len(extracted[\"past_key_values\"][0])] + list(extracted[\"past_key_values\"][0][0].size()))\n", + "print(f\"Inference time: {time.perf_counter()-st:.3f}\")\n", + "print(output.sequences[0])\n", + "print(tokenizer.decode(output.sequences[0],skip_special_tokens=True))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "serve", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/pytest/test_data/streaming/stream_handler.py b/test/pytest/test_data/streaming/stream_handler.py new file mode 100644 index 0000000000..23ec25a79b --- /dev/null +++ b/test/pytest/test_data/streaming/stream_handler.py @@ -0,0 +1,292 @@ +import json +import logging +import types + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) + + +class StreamingHandler(BaseHandler): + def initialize(self, ctx): + super().initialize(ctx) + + ctx.cache = {} + + logger.info(f"Initialized {self.__class__}") + + # Initialize model + self.tokenizer = AutoTokenizer.from_pretrained( + ctx.model_yaml_config["handler"]["modelId"] + ) + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.model = AutoModelForCausalLM.from_pretrained( + ctx.model_yaml_config["handler"]["modelId"] + ) + if torch.cuda.is_available(): + self.model.to("cuda") + self.model.eval() + + # Replace _update_model_kwargs_for_generation of model with a method that extracts the kv cache for us + old_update = self.model._update_model_kwargs_for_generation + ctx.kv_cache = {} + + def extract_past_key_values_func(self, *args, **kwargs): + ctx.kv_cache["past_key_values"] = args[0]["past_key_values"] + return old_update(*args, **kwargs) + + self.model._update_model_kwargs_for_generation = types.MethodType( + extract_past_key_values_func, self.model + ) + + def preprocess(self, data): + assert len(self.context.request_ids.values()) <= 2 + self._clean_cache() + + prefill, decode = [], [] + for req_id, req_data in zip(self.context.request_ids.values(), data): + # Tokenizer requests which are 
not prefilled yet + if not req_id in self.context.cache: + data = json.loads(req_data["body"]) + encoded = self.tokenizer( + data["prompt"], return_tensors="pt", return_token_type_ids=False + ) + if torch.cuda.is_available(): + encoded = {k: v.to("cuda") for k, v in encoded.items()} + encoded["past_key_values"] = None + self.context.cache[req_id] = { + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens=data["max_new_tokens"] + ), + "encoded": encoded, + "prompt_length": len(encoded["input_ids"]), + } + prefill.append(req_id) + else: + decode.append(req_id) + return prefill, decode + + def inference(self, *args): + prefill, decode_ids = args[0] + + # Prefill requests + results = {} + for req_id in prefill: + results[req_id] = self._run_prefill(req_id) + + # Decode the rest + decode_result = self._run_decode(decode_ids) if decode_ids else {} + results.update(decode_result) + return [results[i] for i in self.context.request_ids.values()] + + def postprocess(self, x): + self.context.stopping_criteria = [ + self.context.cache[i]["stopping_criteria"] + for i in self.context.request_ids.values() + ] + return x + + @torch.no_grad() + def _run_prefill(self, req_id): + assert ( + self.context.cache[req_id]["encoded"]["past_key_values"] is None + ), "There should be no cached values" + self.context.cache[req_id]["encoded"] + output = self.model.generate( + **self.context.cache[req_id]["encoded"], + max_new_tokens=1, + return_dict_in_generate=True, + use_cache=True, + ) + # Save extracted kv cache values and adjust attention mask for next call + self.context.cache[req_id]["encoded"][ + "past_key_values" + ] = self.context.kv_cache["past_key_values"] + del self.context.kv_cache["past_key_values"] + self.context.cache[req_id]["encoded"]["input_ids"] = output.sequences + + device = next(iter(self.model.parameters())).device + dtype = torch.int64 + config = {"device": device, "dtype": dtype} + attention_mask = self.context.cache[req_id]["encoded"]["attention_mask"] + attention_mask = torch.concat( + (attention_mask, torch.ones((1, 1), **config)), dim=1 + ) + self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask + + result = { + "text": self.tokenizer.decode( + output.sequences[0], skip_special_tokens=True + ), + "ids": output.sequences[0].tolist(), + } + return result + + @torch.no_grad() + def _run_decode(self, ids): + assert len(ids) + + encoded = self._prepare_model_inputs(ids) + + outputs = self.model.generate( + **encoded, max_new_tokens=1, return_dict_in_generate=True, use_cache=True + ) + + device = next(iter(self.model.parameters())).device + dtype = torch.int64 + config = {"device": device, "dtype": dtype} + + results = {} + for idx, req_id in enumerate(ids): + self.context.cache[req_id]["encoded"][ + "past_key_values" + ] = self._collect_kv_cache_of_idx_in_batch(idx) + self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[ + idx + ].unsqueeze(0) + attention_mask = encoded["attention_mask"][idx].unsqueeze(0) + attention_mask = torch.concat( + (attention_mask, torch.ones((1, 1), **config)), dim=1 + ) + self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask + results[req_id] = { + "text": self.tokenizer.decode( + outputs.sequences[idx][-1], skip_special_tokens=True + ), + "ids": [outputs.sequences[idx][-1].item()], + } + del self.context.kv_cache["past_key_values"] + return results + + def _prepare_model_inputs(self, ids): + lengths = list( + torch.sum(self.context.cache[i]["encoded"]["attention_mask"], dim=1).item() + 
for i in ids + ) + max_len = max(lengths) + + device = next(iter(self.model.parameters())).device + dtype = torch.int64 + config = {"device": device, "dtype": dtype} + + input_ids = [] + attention_mask = [] + kv_cache = {} + for req_id, seq_len in zip(ids, lengths): + input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) + attention_mask.append( + self.context.cache[req_id]["encoded"]["attention_mask"] + ) + + for layer_idx, layer_kv in enumerate( + self.context.cache[req_id]["encoded"]["past_key_values"] + ): + k, v = layer_kv + kv_cache[layer_idx] = kv_cache.get(layer_idx, {}) + kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get(0, []) + [k] + kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get(1, []) + [v] + padded_len = input_ids[-1].size()[-1] + if padded_len < max_len: + # Apply padding to input_ids, attention_mask and past_key_values + n = max_len - seq_len + input_ids[-1] = torch.concat( + ( + self.tokenizer.pad_token_id + torch.zeros((1, n), **config), + input_ids[-1], + ), + dim=1, + ) + attention_mask[-1] = torch.concat( + (torch.zeros((1, n), **config), attention_mask[-1]), dim=1 + ) + + size_delta = list(kv_cache[0][0][-1].size()) + size_delta[2] = n + dtype = kv_cache[0][0][-1].dtype + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0][-1] = torch.concat( + (torch.zeros(size_delta, **config), kv_cache[layer_idx][0][-1]), + dim=2, + ) + kv_cache[layer_idx][1][-1] = torch.concat( + (torch.zeros(size_delta, **config), kv_cache[layer_idx][1][-1]), + dim=2, + ) + + elif padded_len > max_len: + # Truncate padding from input_ids, attention_mask and past_key_values + input_ids[-1] = input_ids[-1][:, -max_len:] + attention_mask[-1] = attention_mask[-1][:, -max_len:] + + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0][-1] = kv_cache[layer_idx][0][-1][ + :, :, (-max_len + 1) :, : + ] + kv_cache[layer_idx][1][-1] = kv_cache[layer_idx][1][-1][ + :, :, (-max_len + 1) :, : + ] + del self.context.cache[req_id]["encoded"]["past_key_values"] + + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0] = torch.concat(kv_cache[layer_idx][0], dim=0) + kv_cache[layer_idx][1] = torch.concat(kv_cache[layer_idx][1], dim=0) + + kv_cache = tuple( + (kv_cache[layer_idx][0], kv_cache[layer_idx][1]) + for layer_idx in range(len(kv_cache)) + ) + + encoded = { + "input_ids": torch.concat(input_ids, dim=0), + "attention_mask": torch.concat(attention_mask, dim=0), + "past_key_values": kv_cache, + } + return encoded + + def _collect_kv_cache_of_idx_in_batch(self, idx): + # The materialization of the tuple here is important for some reason (TODO: figure out why); Otherwise prediction differ + return tuple( + tuple(kv[idx, ...].unsqueeze(0) for kv in layers) + for layers in self.context.kv_cache["past_key_values"] + ) + + def _create_stopping_criteria(self, req_id, max_new_tokens=25): + class StoppingCriteria(object): + def __init__( + self, + cache, + req_id, + stop_token, + max_new_tokens, + ): + self.req_id = req_id + self.cache = cache + self.max_new_tokens = max_new_tokens + self.stop_token = stop_token + + def __call__(self, res): + self.max_new_tokens -= 1 + + if self.max_new_tokens == 0 or res["ids"][-1] == self.stop_token: + self.clean_up() + return True + return False + + def clean_up(self): + del self.cache[self.req_id] + + return StoppingCriteria( + self.context.cache, + req_id, + self.tokenizer.eos_token_id, + max_new_tokens, + ) + + def _clean_cache(self): + new_ids = set(self.context.request_ids.keys()) + for idx in 
self.context.kv_cache.keys(): + if idx not in new_ids: + del self.context.kv_cache[idx] diff --git a/test/pytest/test_handler.py b/test/pytest/test_handler.py index a14af050a4..eea6461534 100644 --- a/test/pytest/test_handler.py +++ b/test/pytest/test_handler.py @@ -51,6 +51,9 @@ def setup_module(module): response = requests.get( "https://torchserve.pytorch.org/mar_files/mnist.mar", allow_redirects=True ) + + os.makedirs(test_utils.MODEL_STORE, exist_ok=True) + with open(os.path.join(test_utils.MODEL_STORE, "mnist.mar"), "wb") as f: f.write(response.content) @@ -397,23 +400,3 @@ def test_huggingface_bert_model_parallel_inference(): "Running model parallel inference requuires more than one gpu, number of available gpus on thi machine is: ", number_of_gpus, ) - - -def test_echo_stream_inference(): - test_utils.start_torchserve(no_config_snapshots=True, gen_mar=False) - test_utils.register_model( - "echo_stream", "https://torchserve.pytorch.org/mar_files/echo_stream.mar" - ) - - response = requests.post( - TF_INFERENCE_API + "/predictions/echo_stream", data="foo", stream=True - ) - assert response.headers["Transfer-Encoding"] == "chunked" - - prediction = [] - for chunk in response.iter_content(chunk_size=None): - if chunk: - prediction.append(chunk.decode("utf-8")) - - assert str(" ".join(prediction)) == "hello hello hello hello world " - test_utils.unregister_model("echo_stream") diff --git a/test/pytest/test_utils.py b/test/pytest/test_utils.py index 23bd45ab7b..1deb0dd2b0 100644 --- a/test/pytest/test_utils.py +++ b/test/pytest/test_utils.py @@ -6,8 +6,10 @@ import sys import tempfile import threading +from io import TextIOWrapper from os import path from pathlib import Path +from queue import Queue from subprocess import PIPE, STDOUT, Popen import requests @@ -22,14 +24,32 @@ CODEBUILD_WD = path.abspath(path.join(__file__, "../../..")) -class PrintPipeTillTheEnd(threading.Thread): - def __init__(self, pipe): +class PrintTillTheEnd(threading.Thread): + def __init__(self, queue): super().__init__() - self.pipe = pipe + self._queue = queue def run(self): - for line in self.pipe.stdout: - print(line.decode("utf-8").strip()) + while True: + line = self._queue.get() + if not line: + break + print(line.strip()) + + +class Tee(threading.Thread): + def __init__(self, reader): + super().__init__() + self.reader = reader + self.queue1 = Queue() + self.queue2 = Queue() + + def run(self): + for line in self.reader: + self.queue1.put(line) + self.queue2.put(line) + self.queue1.put(None) + self.queue2.put(None) def start_torchserve( @@ -53,9 +73,14 @@ def start_torchserve( print(line.decode("utf8").strip()) if "Model server started" in str(line).strip(): break - print_thread = PrintPipeTillTheEnd(p) + + splitter = Tee(TextIOWrapper(p.stdout)) + splitter.start() + print_thread = PrintTillTheEnd(splitter.queue1) print_thread.start() + return splitter.queue2 + def stop_torchserve(): subprocess.run(["torchserve", "--stop", "--foreground"]) diff --git a/ts/context.py b/ts/context.py index aa5d9babda..72db4de3c5 100644 --- a/ts/context.py +++ b/ts/context.py @@ -39,8 +39,7 @@ def __init__( self._limit_max_image_pixels = True self.metrics = metrics self.model_yaml_config = model_yaml_config - # add client socket variable cl_socket to be used for send_intermediate_predict_response - self.cl_socket = None + self.stopping_criteria = None @property def system_properties(self): diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index d05c472b6b..91eb206f80 100644 --- 
a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -1,6 +1,7 @@ """ OTF Codec """ + import io import json import logging @@ -75,15 +76,26 @@ def create_predict_response( msg += struct.pack("!i", len(req_id)) msg += req_id - # Encoding Content-Type if context is None: + # Encoding Content-Type msg += struct.pack("!i", 0) # content_type + + # Encoding the per prediction HTTP response code + # status code and reason phrase set to none + msg += struct.pack("!i", code) + msg += struct.pack("!i", 0) # No code phrase is returned + # Response headers none + msg += struct.pack("!i", 0) else: if ts_stream_next is True: context.set_response_header(idx, "ts_stream_next", "true") - else: - if "true" == context.get_response_headers(idx).get("ts_stream_next"): - context.set_response_header(idx, "ts_stream_next", "false") + elif context.stopping_criteria: + ts_stream_next = ( + "false" if context.stopping_criteria[idx](ret[idx]) else "true" + ) + context.set_response_header(idx, "ts_stream_next", ts_stream_next) + elif "true" == context.get_response_headers(idx).get("ts_stream_next"): + context.set_response_header(idx, "ts_stream_next", "false") content_type = context.get_response_content_type(idx) if content_type is None or len(content_type) == 0: @@ -92,14 +104,6 @@ def create_predict_response( msg += struct.pack("!i", len(content_type)) msg += content_type.encode("utf-8") - # Encoding the per prediction HTTP response code - if context is None: - # status code and reason phrase set to none - msg += struct.pack("!i", code) - msg += struct.pack("!i", 0) # No code phrase is returned - # Response headers none - msg += struct.pack("!i", 0) - else: sc, phrase = context.get_response_status(idx) http_code = sc if sc is not None else 200 http_phrase = phrase if phrase is not None else "" diff --git a/ts/tests/unit_tests/test_otf_codec_protocol.py b/ts/tests/unit_tests/test_otf_codec_protocol.py index ba6bb8fc39..5df1c4644a 100644 --- a/ts/tests/unit_tests/test_otf_codec_protocol.py +++ b/ts/tests/unit_tests/test_otf_codec_protocol.py @@ -1,51 +1,58 @@ # coding=utf-8 - - """ On The Fly Codec tester """ +import struct +from builtins import bytes from collections import namedtuple import pytest import ts.protocol.otf_message_handler as codec -from builtins import bytes @pytest.fixture() def socket_patches(mocker): - Patches = namedtuple('Patches', ['socket']) - mock_patch = Patches(mocker.patch('socket.socket')) - mock_patch.socket.recv.return_value = b'1' + Patches = namedtuple("Patches", ["socket"]) + mock_patch = Patches(mocker.patch("socket.socket")) + mock_patch.socket.recv.return_value = b"1" return mock_patch # noinspection PyClassHasNoInit class TestOtfCodecHandler: - def test_retrieve_msg_unknown(self, socket_patches): socket_patches.socket.recv.side_effect = [b"U", b"\x00\x00\x00\x03"] with pytest.raises(ValueError, match=r"Invalid command: .*"): codec.retrieve_msg(socket_patches.socket) def test_retrieve_msg_load_gpu(self, socket_patches): - expected = {"modelName": b"model_name", "modelPath": b"model_path", - "batchSize": 1, "handler": b"handler", "gpu": 1, - "envelope": b"envelope", - "limitMaxImagePixels": True} + expected = { + "modelName": b"model_name", + "modelPath": b"model_path", + "batchSize": 1, + "handler": b"handler", + "gpu": 1, + "envelope": b"envelope", + "limitMaxImagePixels": True, + } socket_patches.socket.recv.side_effect = [ b"L", - b"\x00\x00\x00\x0a", b"model_name", - b"\x00\x00\x00\x0a", b"model_path", + b"\x00\x00\x00\x0a", + b"model_name", + 
b"\x00\x00\x00\x0a", + b"model_path", b"\x00\x00\x00\x01", - b"\x00\x00\x00\x07", b"handler", + b"\x00\x00\x00\x07", + b"handler", b"\x00\x00\x00\x01", - b"\x00\x00\x00\x08", b"envelope", - b"\x01" + b"\x00\x00\x00\x08", + b"envelope", + b"\x01", ] cmd, ret = codec.retrieve_msg(socket_patches.socket) @@ -53,20 +60,28 @@ def test_retrieve_msg_load_gpu(self, socket_patches): assert ret == expected def test_retrieve_msg_load_no_gpu(self, socket_patches): - expected = {"modelName": b"model_name", "modelPath": b"model_path", - "batchSize": 1, "handler": b"handler", - "envelope": b"envelope", - "limitMaxImagePixels": True} + expected = { + "modelName": b"model_name", + "modelPath": b"model_path", + "batchSize": 1, + "handler": b"handler", + "envelope": b"envelope", + "limitMaxImagePixels": True, + } socket_patches.socket.recv.side_effect = [ b"L", - b"\x00\x00\x00\x0a", b"model_name", - b"\x00\x00\x00\x0a", b"model_path", + b"\x00\x00\x00\x0a", + b"model_name", + b"\x00\x00\x00\x0a", + b"model_path", b"\x00\x00\x00\x01", - b"\x00\x00\x00\x07", b"handler", + b"\x00\x00\x00\x07", + b"handler", b"\xFF\xFF\xFF\xFF", - b"\x00\x00\x00\x08", b"envelope", - b"\x01" + b"\x00\x00\x00\x08", + b"envelope", + b"\x01", ] cmd, ret = codec.retrieve_msg(socket_patches.socket) @@ -74,92 +89,201 @@ def test_retrieve_msg_load_no_gpu(self, socket_patches): assert ret == expected def test_retrieve_msg_predict(self, socket_patches): - expected = [{ - "requestId": b"request_id", "headers": [], "parameters": [ - {"name": "input_name", - "contentType": "application/json", - "value": {"data": "value"} - } - ] - }] + expected = [ + { + "requestId": b"request_id", + "headers": [], + "parameters": [ + { + "name": "input_name", + "contentType": "application/json", + "value": {"data": "value"}, + } + ], + } + ] socket_patches.socket.recv.side_effect = [ b"I", - b"\x00\x00\x00\x0a", b"request_id", + b"\x00\x00\x00\x0a", + b"request_id", b"\xFF\xFF\xFF\xFF", - b"\x00\x00\x00\x0a", b"input_name", - b"\x00\x00\x00\x0F", b"application/json", - b"\x00\x00\x00\x0F", b'{"data":"value"}', + b"\x00\x00\x00\x0a", + b"input_name", + b"\x00\x00\x00\x0F", + b"application/json", + b"\x00\x00\x00\x0F", + b'{"data":"value"}', b"\xFF\xFF\xFF\xFF", # end of parameters - b"\xFF\xFF\xFF\xFF" # end of batch + b"\xFF\xFF\xFF\xFF", # end of batch ] cmd, ret = codec.retrieve_msg(socket_patches.socket) - assert cmd == b'I' + assert cmd == b"I" assert ret == expected def test_retrieve_msg_predict_text(self, socket_patches): - expected = [{ - "requestId": b"request_id", "headers": [], "parameters": [ - {"name": "input_name", - "contentType": "text/plain", - "value": u"text_value测试" - } - ] - }] + expected = [ + { + "requestId": b"request_id", + "headers": [], + "parameters": [ + { + "name": "input_name", + "contentType": "text/plain", + "value": "text_value测试", + } + ], + } + ] socket_patches.socket.recv.side_effect = [ b"I", - b"\x00\x00\x00\x0a", b"request_id", + b"\x00\x00\x00\x0a", + b"request_id", b"\xFF\xFF\xFF\xFF", - b"\x00\x00\x00\x0a", b"input_name", - b"\x00\x00\x00\x0a", b"text/plain", - b"\x00\x00\x00\x0a", bytes(u"text_value测试", "utf-8"), + b"\x00\x00\x00\x0a", + b"input_name", + b"\x00\x00\x00\x0a", + b"text/plain", + b"\x00\x00\x00\x0a", + bytes("text_value测试", "utf-8"), b"\xFF\xFF\xFF\xFF", # end of parameters - b"\xFF\xFF\xFF\xFF" # end of batch + b"\xFF\xFF\xFF\xFF", # end of batch ] cmd, ret = codec.retrieve_msg(socket_patches.socket) - assert cmd == b'I' + assert cmd == b"I" assert ret == expected def 
test_retrieve_msg_predict_binary(self, socket_patches): - expected = [{ - "requestId": b"request_id", "headers": [], "parameters": [ - {"name": "input_name", - "contentType": "", - "value": b"binary" - } - ] - }] + expected = [ + { + "requestId": b"request_id", + "headers": [], + "parameters": [ + {"name": "input_name", "contentType": "", "value": b"binary"} + ], + } + ] socket_patches.socket.recv.side_effect = [ b"I", - b"\x00\x00\x00\x0a", b"request_id", + b"\x00\x00\x00\x0a", + b"request_id", b"\xFF\xFF\xFF\xFF", - b"\x00\x00\x00\x0a", b"input_name", + b"\x00\x00\x00\x0a", + b"input_name", b"\x00\x00\x00\x00", - b"\x00\x00\x00\x06", b"binary", + b"\x00\x00\x00\x06", + b"binary", b"\xFF\xFF\xFF\xFF", # end of parameters - b"\xFF\xFF\xFF\xFF" # end of batch + b"\xFF\xFF\xFF\xFF", # end of batch ] cmd, ret = codec.retrieve_msg(socket_patches.socket) - assert cmd == b'I' + assert cmd == b"I" assert ret == expected def test_create_load_model_response(self): msg = codec.create_load_model_response(200, "model_loaded") - assert msg == b'\x00\x00\x00\xc8\x00\x00\x00\x0cmodel_loaded\xff\xff\xff\xff' + assert msg == b"\x00\x00\x00\xc8\x00\x00\x00\x0cmodel_loaded\xff\xff\xff\xff" def test_create_predict_response(self): msg = codec.create_predict_response(["OK"], {0: "request_id"}, "success", 200) - assert msg == b'\x00\x00\x00\xc8\x00\x00\x00\x07success\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00' \ - b'\x00\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02OK\xff\xff\xff\xff' + assert ( + msg + == b"\x00\x00\x00\xc8\x00\x00\x00\x07success\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00" + b"\x00\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02OK\xff\xff\xff\xff" + ) def test_create_predict_response_with_error(self): msg = codec.create_predict_response(None, {0: "request_id"}, "failed", 200) - assert msg == b'\x00\x00\x00\xc8\x00\x00\x00\x06failed\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00\x00' \ - b'\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05error\xff\xff\xff\xff' + assert ( + msg + == b"\x00\x00\x00\xc8\x00\x00\x00\x06failed\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00\x00" + b"\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05error\xff\xff\xff\xff" + ) + + def test_create_predict_response_with_context(self): + # context = MagicMock("Context") + # context.stopping_criteria = {0: lambda x: True} + # context.set_response_headers + # get_response_headers + from ts.context import Context, RequestProcessor + + ctx = Context( + "model_name", + "model_dir", + "manifest", + batch_size=2, + gpu=0, + mms_version=1.0, + ) + ctx.stopping_criteria = {0: lambda _: True, 1: lambda _: False} + ctx.request_processor = {0: RequestProcessor({}), 1: RequestProcessor({})} + + msg = codec.create_predict_response( + ["OK", "NOT OK"], + {0: "request_0", 1: "request_1"}, + "success", + 200, + context=ctx, + ) + + def read_int(m): + a = struct.unpack("!i", m[:4])[0] + del msg[:4] + return a + + def read_string(m, n): + a = m[:n].decode("utf-8") + del msg[:n] + return a + + def read_map(m, n): + ret = {} + while n: + l = read_int(m) + k = read_string(m, l) + l = read_int(m) + v = read_string(m, l) + ret[k] = v + n -= 1 + return ret + + assert read_int(msg) == 200 # code + + assert read_int(msg) == 7 # msg length + + assert read_string(msg, 7) == "success" # msg + + length = read_int(msg) + expected = ["request_0", "false", "OK", "request_1", "true", "NOT OK"] + while length != -1: + req_id = read_string(msg, length) + assert req_id == expected.pop(0) + + length = read_int(msg) + content_type = 
read_string(msg, length) + assert content_type == "" + + http_code = read_int(msg) + assert http_code == 200 + + length = read_int(msg) + http_phrase = read_string(msg, length) + assert http_phrase == "" + + length = read_int(msg) + kv = read_map(msg, length) + assert kv["ts_stream_next"] == expected.pop(0) + + length = read_int(msg) + pred = read_string(msg, length) + assert pred == expected.pop(0) + + length = read_int(msg) + assert length == -1
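
For reference, below is a minimal client sketch for exercising the streaming handler added above, modeled on the removed test_echo_stream_inference test (chunked HTTP response read with requests' stream=True). It assumes the handler is registered on a local TorchServe instance under the hypothetical model name "stream_model"; the "prompt" and "max_new_tokens" request fields are the ones StreamingHandler.preprocess reads from the request body, while the URL, the model name, and the exact chunk payload format are assumptions, not part of this diff. Two requests are sent concurrently because the handler asserts a batch of at most two and interleaves their decode steps.

import json
import threading

import requests

# Assumption: the streaming handler is served by a local TorchServe instance
# under the hypothetical model name "stream_model".
URL = "http://localhost:8080/predictions/stream_model"


def stream_prompt(prompt, max_new_tokens, out):
    # The JSON fields mirror what StreamingHandler.preprocess() reads from
    # the request body: "prompt" and "max_new_tokens".
    body = json.dumps({"prompt": prompt, "max_new_tokens": max_new_tokens})
    with requests.post(URL, data=body, stream=True) as response:
        # Streamed predictions arrive as a chunked response, as in the
        # removed echo_stream test.
        assert response.headers["Transfer-Encoding"] == "chunked"
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                # Each chunk should carry the handler's per-step result
                # (its "text"/"ids" dict, serialized by the backend);
                # collect the raw chunks here.
                out.append(chunk.decode("utf-8"))


# Send two prompts concurrently so the worker can batch their decode steps.
results = [[], []]
threads = [
    threading.Thread(
        target=stream_prompt, args=("The capital of France", 5, results[0])
    ),
    threading.Thread(
        target=stream_prompt,
        args=("The capital of France, Paris, is home", 5, results[1]),
    ),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print("".join(results[0]))
print("".join(results[1]))

The stream for each request ends once its per-request stopping criteria (max_new_tokens reached or the EOS token emitted) returns True on the worker side, at which point the backend sets ts_stream_next to "false" and the frontend closes the chunked response.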