diff --git a/docs/modules/ROOT/pages/llama3.adoc b/docs/modules/ROOT/pages/llama3.adoc index 507b9b82c..b1b5468ca 100644 --- a/docs/modules/ROOT/pages/llama3.adoc +++ b/docs/modules/ROOT/pages/llama3.adoc @@ -39,6 +39,17 @@ WARNING: Models are huge, so make sure you have enough disk space. NOTE: Due to model's large size, pulling them can take time +=== Native mode + +Currently, Llama3.java only works in native mode with Early Access version's of Oracle GraalVM 24 (which can be easily downloaded with https://sdkman.io[SDKMan]). + +To achieve the best performance in native mode, it is suggested to configure the application with the following: + +[source,properties,subs=attributes+] +---- +quarkus.native.additional-build-args=-O3,-march=native +---- + == Using Llama3.java To let Llama3.java running inference on your models, add the following dependency into your project: diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java index e305d5750..2ef5d7619 100644 --- a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java @@ -96,7 +96,7 @@ public Response generate(List messages) { private InferenceResponse runInference(Llama model, Sampler sampler, Llama3.Options options, List messages) { - Llama.State state = model.createNewState(); + Llama.State state = model.createNewState(Llama3.BATCH_SIZE); ChatFormat chatFormat = new ChatFormat(model.tokenizer()); List promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages)); diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java index f29df7a31..df8d24b0c 100644 --- a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java @@ -2,6 +2,7 @@ import static dev.langchain4j.data.message.AiMessage.aiMessage; import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message; +import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE; import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler; import java.io.IOException; @@ -84,7 +85,7 @@ public void generate(List messages, StreamingResponseHandler messages, StreamingResponseHandler handler) { - Llama.State state = model.createNewState(); + Llama.State state = model.createNewState(BATCH_SIZE); ChatFormat chatFormat = new ChatFormat(model.tokenizer()); List promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages)); diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/AOT.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/AOT.java index c9d22d528..eea99d1ab 100644 --- a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/AOT.java +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/AOT.java @@ -51,7 +51,7 @@ public static PartialModel preLoadGGUF(String modelPath) { * No checksum/hash is checked for performance reasons. */ public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException { - PartialModel preLoaded = AOT.PRELOADED_GGUF; + AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF; if (preLoaded == null) { return null; // no pre-loaded model stored } diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/GGUF.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/GGUF.java index c74099e37..adad9504b 100644 --- a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/GGUF.java +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/GGUF.java @@ -108,7 +108,7 @@ private void loadModelImpl(FileChannel fileChannel) throws IOException { // gguf_tensor_info_t tensor_infos[header.tensor_count]; this.tensorInfos = HashMap.newHashMap(tensorCount); for (int i = 0; i < tensorCount; ++i) { - GGUFTensorInfo ti = readTensorInfo(fileChannel); + GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel); assert !tensorInfos.containsKey(ti.name); tensorInfos.put(ti.name, ti); } @@ -156,7 +156,7 @@ private GGMLType readGGMLType(FileChannel fileChannel) throws IOException { return GGMLType.fromId(ggmlTypeId); } - private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException { + private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException { // The name of the tensor. It is a standard GGUF string, with the caveat that // it must be at most 64 bytes long. String name = readString(fileChannel); // gguf_string_t name; @@ -180,7 +180,7 @@ private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOExceptio // Must be a multiple of `ALIGNMENT`. long offset = readLong(fileChannel); // uint64_t offset; assert offset % getAlignment() == 0; - return new GGUFTensorInfo(name, dimensions, ggmlType, offset); + return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset); } private String readString(FileChannel fileChannel) throws IOException { diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java index abf1bf8c5..1275ef37e 100644 --- a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java @@ -2,15 +2,17 @@ import java.nio.FloatBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.IntConsumer; +import java.util.stream.IntStream; import java.util.stream.Stream; public record Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) { - public State createNewState() { - State state = new State(configuration()); + public State createNewState(int batchsize) { + State state = new State(configuration(), batchsize); state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); return state; } @@ -97,32 +99,40 @@ public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, public static final class State { // current wave of activations - public final FloatTensor x; // activation at current time stamp (dim,) - public final FloatTensor xb; // same, but inside a residual branch (dim,) - public final FloatTensor xb2; // an additional buffer just for convenience (dim,) - public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,) - public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,) - public final FloatTensor q; // query (dim,) - public final FloatTensor k; // key (dim,) - public final FloatTensor v; // value (dim,) - public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len) + public final int batchsize; + public final FloatTensor[] x; // activation at current time stamp (dim,) + public final FloatTensor[] xb; // same, but inside a residual branch (dim,) + public final FloatTensor[] xb2; // an additional buffer just for convenience (dim,) + public final FloatTensor[] hb; // buffer for hidden dimension in the ffn (hidden_dim,) + public final FloatTensor[] hb2; // buffer for hidden dimension in the ffn (hidden_dim,) + public final FloatTensor[] q; // query (dim,) + public final FloatTensor[] k; // key (dim,) + public final FloatTensor[] v; // value (dim,) + public final FloatTensor[] att; // buffer for scores/attention values (n_heads, seq_len) public final FloatTensor logits; // output logits + // kv cache public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim) public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim) + /** last index in previous block */ + int idxPrevBlock; + public int latestToken; - State(Configuration config) { - this.x = ArrayFloatTensor.allocate(config.dim); - this.xb = ArrayFloatTensor.allocate(config.dim); - this.xb2 = ArrayFloatTensor.allocate(config.dim); - this.hb = ArrayFloatTensor.allocate(config.hiddenDim); - this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim); - this.q = ArrayFloatTensor.allocate(config.dim); - this.k = ArrayFloatTensor.allocate(config.dim); - this.v = ArrayFloatTensor.allocate(config.dim); - this.att = ArrayFloatTensor.allocate(config.numberOfHeads, config.contextLength); + State(Configuration config, int batchsize) { + this.batchsize = batchsize; + this.x = allocate(batchsize, config.dim); + this.xb = allocate(batchsize, config.dim); + this.xb2 = allocate(batchsize, config.dim); + this.hb = allocate(batchsize, config.hiddenDim); + this.hb2 = allocate(batchsize, config.hiddenDim); + this.q = allocate(batchsize, config.dim); + this.k = allocate(batchsize, config.dim); + this.v = allocate(batchsize, config.dim); + this.att = allocate(batchsize, config.numberOfHeads, config.contextLength); + idxPrevBlock = -1; + this.logits = ArrayFloatTensor.allocate(config.vocabularySize); int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)) @@ -132,6 +142,12 @@ public static final class State { } } + static FloatTensor[] allocate(int numTokens, int... dims) { + return IntStream.range(0, numTokens) + .mapToObj(i -> ArrayFloatTensor.allocate(dims)) + .toArray(FloatTensor[]::new); + } + static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { // calculate sum of squares float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi); @@ -143,7 +159,7 @@ static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index))); } - static FloatTensor forward(Llama model, State state, int token, int position) { + static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) { // a few convenience variables Configuration config = model.configuration(); Weights weights = model.weights(); @@ -152,44 +168,58 @@ static FloatTensor forward(Llama model, State state, int token, int position) { int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery float sqrtHeadSize = (float) Math.sqrt(headSize); + final int nTokens = tokens.length; // copy the token embedding into x - weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + Parallel.parallelFor(0, nTokens, t -> weights.token_embedding_table.copyTo(tokens[t] * dim, state.x[t], 0, dim)); // forward all the layers for (int l = 0; l < config.numberOfLayers; l++) { // attention rmsnorm - rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps); + // rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps); + final int curLayer = l; + Parallel.parallelFor(0, nTokens, + t -> rmsnorm(state.xb[t], state.x[t], weights.rms_att_weight[curLayer], dim, config.rmsNormEps)); // qkv matmuls for this position - weights.wq[l].matmul(state.xb, state.q, dim, dim); - weights.wk[l].matmul(state.xb, state.k, kvDim, dim); - weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + weights.wq[l].matmul(nTokens, state.xb, state.q, dim, dim); + weights.wk[l].matmul(nTokens, state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(nTokens, state.xb, state.v, kvDim, dim); // RoPE relative positional encoding: complex-valued rotate q and k in each head - for (int i = 0; i < dim; i += 2) { - int head_dim = i % headSize; - float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2)); - float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2)); - int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only - for (int v = 0; v < rotn; v++) { - FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key) - float v0 = vec.getFloat(i); - float v1 = vec.getFloat(i + 1); - vec.setFloat(i, v0 * fcr - v1 * fci); - vec.setFloat(i + 1, v0 * fci + v1 * fcr); + Parallel.parallelFor(0, nTokens, t -> { + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.get((position + t) * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.get((position + t) * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only + for (int vi = 0; vi < rotn; vi++) { + FloatTensor vec = vi == 0 ? state.q[t] : state.k[t]; // the vector to rotate (query or key) + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } } - } + }); // save key,value at this time step (position) to our kv cache //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience - state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); - state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + Parallel.parallelFor(0, nTokens, t -> { + state.k[t].copyTo(0, state.keyCache[curLayer], (position + t) * kvDim, kvDim); + state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim); + }); - int curLayer = l; + // If the logits are not required, the attention and FFN of the last layer can be skipped entirely. + if (!computeLogits && curLayer == config.numberOfLayers - 1) { + state.idxPrevBlock = nTokens - 1; + return null; + } // multihead attention. iterate over all heads - Parallel.parallelFor(0, config.numberOfHeads, h -> { + Parallel.parallelForLong(0, (long) nTokens * (long) config.numberOfHeads, ht -> { + int token = (int) (ht / config.numberOfHeads); + int h = (int) (ht % config.numberOfHeads); // get the query vector for this head // float* q = s.q + h * headSize; int qOffset = h * headSize; @@ -199,70 +229,83 @@ static FloatTensor forward(Llama model, State state, int token, int position) { int attOffset = h * config.contextLength; // iterate over all timesteps, including the current one - for (int t = 0; t <= position; t++) { + for (int t = 0; t <= position + token; t++) { // get the key vector for this head and at this timestep // float* k = s.key_cache + loff + t * dim + h * headSize; int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; // calculate the attention score as the dot product of q and k - float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + float score = state.q[token].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); score /= sqrtHeadSize; // save the score to the attention buffer - state.att.setFloat(attOffset + t, score); + state.att[token].setFloat(attOffset + t, score); } // softmax the scores to get attention weights, from 0..position inclusively - state.att.softmaxInPlace(attOffset, position + 1); + state.att[token].softmaxInPlace(attOffset, position + token + 1); // weighted sum of the values, store back into xb // float* xb = s.xb + h * headSize; int xbOffset = h * headSize; // memset(xb, 0, headSize * sizeof(float)); - state.xb.fillInPlace(xbOffset, headSize, 0f); + state.xb[token].fillInPlace(xbOffset, headSize, 0f); - for (int t = 0; t <= position; t++) { + for (int t = 0; t <= position + token; t++) { // get the value vector for this head and at this timestep // float* v = s.value_cache + loff + t * dim + h * headSize; int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; // get the attention weight for this timestep - float a = state.att.getFloat(attOffset + t); + float a = state.att[token].getFloat(attOffset + t); // accumulate the weighted value into xb - state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + state.xb[token].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); } }); // final matmul to get the output of the attention - weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + weights.wo[l].matmul(nTokens, state.xb, state.xb2, dim, dim); // residual connection back into x - state.x.addInPlace(state.xb2); + Parallel.parallelFor(0, nTokens, t -> { + state.x[t].addInPlace(state.xb2[t]); + }); // ffn rmsnorm - rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps); + Parallel.parallelFor(0, nTokens, t -> { + rmsnorm(state.xb[t], state.x[t], weights.rms_ffn_weight[curLayer], dim, config.rmsNormEps); + }); // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) // first calculate self.w1(x) and self.w3(x) - weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim, dim); - weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim, dim); + weights.w1[l].matmul(nTokens, state.xb, state.hb, config.hiddenDim, dim); + weights.w3[l].matmul(nTokens, state.xb, state.hb2, config.hiddenDim, dim); // SwiGLU non-linearity // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid - state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + Parallel.parallelFor(0, nTokens, t -> { + state.hb[t].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + }); // elementwise multiply with w3(x) - state.hb.multiplyInPlace(state.hb2); + Parallel.parallelFor(0, nTokens, t -> { + state.hb[t].multiplyInPlace(state.hb2[t]); + }); // final matmul to get the output of the ffn - weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim); + weights.w2[l].matmul(nTokens, state.hb, state.xb, dim, config.hiddenDim); // residual connection - state.x.addInPlace(state.xb); + Parallel.parallelFor(0, nTokens, t -> { + state.x[t].addInPlace(state.xb[t]); + }); } // final rmsnorm - rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps); + Parallel.parallelFor(0, nTokens, t -> { + rmsnorm(state.x[t], state.x[t], weights.rms_final_weight, dim, config.rmsNormEps); + }); // classifier into logits - weights.wcls.matmul(state.x, state.logits, config.vocabularySize, dim); + weights.wcls.matmul(state.x[nTokens - 1], state.logits, config.vocabularySize, dim); + state.idxPrevBlock = nTokens - 1; return state.logits; } @@ -294,6 +337,7 @@ public static List generateTokens(Llama model, State state, int startPo Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { long startNanos = System.nanoTime(); + long startGen = 0; if (maxTokens < 0 || model.configuration().contextLength < maxTokens) { maxTokens = model.configuration().contextLength; } @@ -302,34 +346,54 @@ public static List generateTokens(Llama model, State state, int startPo int nextToken; int promptIndex = 0; for (int position = startPosition; position < maxTokens; ++position) { - forward(model, state, token, position); if (promptIndex < promptTokens.size()) { - // Force-pick token from prompt. - nextToken = promptTokens.get(promptIndex++); - if (echo) { - // log prompt token (different color?) - System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + final int nTokens = Math.min(maxTokens - position, + Math.min(promptTokens.size() - promptIndex, state.batchsize)); + final int[] tokens = new int[nTokens]; + for (int i = 0; i < nTokens; i++) { + tokens[i] = promptTokens.get(promptIndex + i); + if (echo) { + // log prompt token (different color?) + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(tokens[i])))); + } } - } else { - nextToken = sampler.sampleToken(state.logits); if (echo) { - // log inferred token - System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + System.out.format("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n", position, promptIndex, + promptTokens.size(), Arrays.toString(tokens)); } - generatedTokens.add(nextToken); - if (onTokenGenerated != null) { - onTokenGenerated.accept(nextToken); - } - if (stopTokens.contains(nextToken)) { - break; + // Only compute logits on the very last batch. + boolean computeLogits = promptIndex + nTokens >= promptTokens.size(); + forward(model, state, tokens, position, computeLogits); + position += nTokens - 1; // -1 -> incremented later in the for loop + promptIndex += nTokens; + if (promptIndex < promptTokens.size()) { + continue; } + startGen = System.nanoTime(); + } else { + forward(model, state, new int[] { token }, position, true); + } + nextToken = sampler.sampleToken(state.logits); + if (echo) { + // log inferred token + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + generatedTokens.add(nextToken); + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + if (stopTokens.contains(nextToken)) { + break; } state.latestToken = token = nextToken; } long elapsedNanos = System.nanoTime() - startNanos; - int totalTokens = promptIndex + generatedTokens.size(); - System.err.printf("%n%.2f tokens/s (%d)%n", totalTokens / (elapsedNanos / 1_000_000_000.0), totalTokens); + long promptNanos = startGen - startNanos; + long genNanos = elapsedNanos - startGen + startNanos; + System.err.printf("%nprompt: %.2f tokens/s (%d) generation: %.2f tokens/s (%d)%n", + promptTokens.size() / (promptNanos / 1_000_000_000.0), promptTokens.size(), + generatedTokens.size() / (genNanos / 1_000_000_000.0), generatedTokens.size()); return generatedTokens; } diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java index 5f2c6c303..5671c81fb 100755 --- a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java @@ -27,24 +27,20 @@ import java.nio.ByteOrder; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Scanner; -import java.util.Set; +import java.util.*; import java.util.function.IntConsumer; +import java.util.function.LongConsumer; import java.util.random.RandomGenerator; import java.util.random.RandomGeneratorFactory; import java.util.stream.IntStream; +import java.util.stream.LongStream; -import jdk.incubator.vector.ByteVector; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorSpecies; +import jdk.incubator.vector.*; import sun.misc.Unsafe; public class Llama3 { + // Batch-size used in prompt evaluation. + public static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16); public static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) { Sampler sampler; @@ -84,15 +80,24 @@ static void runInteractive(Llama model, Sampler sampler, Options options) { } int startPosition = 0; Scanner in = new Scanner(System.in); - while (true) { + loop: while (true) { System.out.print("> "); System.out.flush(); String userText = in.nextLine(); - if (List.of("quit", "exit").contains(userText)) { - break; + switch (userText) { + case "/quit": + case "/exit": + break loop; + case "/context": { + System.out.printf("%d out of %d context tokens used (%d tokens remaining)%n", + conversationTokens.size(), + options.maxTokens(), + options.maxTokens() - conversationTokens.size()); + continue; + } } if (state == null) { - state = model.createNewState(); + state = model.createNewState(BATCH_SIZE); } conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); @@ -126,7 +131,7 @@ static void runInteractive(Llama model, Sampler sampler, Options options) { } static void runInstructOnce(Llama model, Sampler sampler, Options options) { - Llama.State state = model.createNewState(); + Llama.State state = model.createNewState(BATCH_SIZE); ChatFormat chatFormat = new ChatFormat(model.tokenizer()); List promptTokens = new ArrayList<>(); @@ -278,8 +283,20 @@ public static void main(String[] args) throws IOException { final class Parallel { public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) { + if (startInclusive == 0 && endExclusive == 1) { + action.accept(0); + return; + } IntStream.range(startInclusive, endExclusive).parallel().forEach(action); } + + public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) { + if (startInclusive == 0 && endExclusive == 1) { + action.accept(0); + return; + } + LongStream.range(startInclusive, endExclusive).parallel().forEach(action); + } } final class Float16 { @@ -293,7 +310,8 @@ final class Float16 { * e.g. can represent a sequence of quantized floats. */ abstract class FloatTensor { - static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); + static final int VECTOR_BIT_SIZE = Integer.getInteger("llama.VectorBitSize", VectorShape.preferredShape().vectorBitSize()); + static final boolean USE_VECTOR_API = VECTOR_BIT_SIZE != 0; // static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.LITTLE_ENDIAN); // static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); @@ -323,8 +341,9 @@ static byte readByte(MemorySegment memorySegment, long offset) { // Preferred vector size for the fast multiplication routines. // (Apple Silicon) NEON only supports up-to 128bit vectors. - static final VectorSpecies F_SPECIES = FloatVector.SPECIES_PREFERRED.vectorBitSize() == 128 ? FloatVector.SPECIES_128 - : FloatVector.SPECIES_256; + static final VectorSpecies F_SPECIES = USE_VECTOR_API + ? VectorShape.forBitSize(VECTOR_BIT_SIZE).withLanes(float.class) + : null; abstract int size(); @@ -357,6 +376,17 @@ void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) { Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1))); } + void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) { + if (that.length != out.length) { + throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length)); + } + Parallel.parallelForLong(0, dim0 * context, ti -> { + int idxArr = (int) (ti / dim0); + int i = (int) (ti % dim0); + out[idxArr].setFloat(i, dot(i * dim1, that[idxArr], 0, dim1)); + }); + } + @FunctionalInterface interface AggregateFunction { float apply(float acc, float value); @@ -423,7 +453,7 @@ FloatTensor mapInPlace(MapFunction mapFunction) { return mapInPlace(0, size(), mapFunction); } - FloatTensor mapWithIndexInPlace(int thisOffset, int size, MapWithIndexFunction mapWithIndexFunction) { + FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) { int endOffset = thisOffset + size; for (int i = thisOffset; i < endOffset; ++i) { setFloat(i, mapWithIndexFunction.apply(getFloat(i), i)); @@ -557,37 +587,45 @@ private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatT for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) { float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); - var B_SPECIES = ByteVector.SPECIES_128; - var wBytes = ByteVector.fromMemorySegment(B_SPECIES, thiz.memorySegment, blockOffset + Float16.BYTES, + var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); var loBytes = wBytes.and((byte) 0xF).sub((byte) 8); var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8); - if (F_SPECIES.vectorBitSize() == 256) { - var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) - .mul(loBytes.castShape(F_SPECIES, 0)); - var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) - .mul(loBytes.castShape(F_SPECIES, 1)); - var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) - .mul(hiBytes.castShape(F_SPECIES, 0)); - var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) - .mul(hiBytes.castShape(F_SPECIES, 1)); - val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); - } else if (F_SPECIES.vectorBitSize() == 128) { - // This loop cannot be unrolled, why? - for (int i = 0; i < 2; ++i) { - var tmp = i == 0 ? loBytes : hiBytes; - var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()) - .mul(tmp.castShape(F_SPECIES, 0)); - var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()) - .mul(tmp.castShape(F_SPECIES, 1)); - var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()) - .mul(tmp.castShape(F_SPECIES, 2)); - var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()) - .mul(tmp.castShape(F_SPECIES, 3)); + switch (F_SPECIES.vectorBitSize()) { + case 512 -> { + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) + .mul(loBytes.castShape(F_SPECIES, 0)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) + .mul(hiBytes.castShape(F_SPECIES, 0)); + val = sum0.add(sum2).fma(wScale, val); + } + case 256 -> { + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) + .mul(loBytes.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) + .mul(loBytes.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) + .mul(hiBytes.castShape(F_SPECIES, 0)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) + .mul(hiBytes.castShape(F_SPECIES, 1)); val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); } - } else { - throw new UnsupportedOperationException(F_SPECIES.toString()); + case 128 -> { + // This loop cannot be unrolled, why? + for (int i = 0; i < 2; ++i) { + var tmp = i == 0 ? loBytes : hiBytes; + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 2)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 3)); + val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); + } + } + default -> throw new UnsupportedOperationException(F_SPECIES.toString()); } } result += val.reduceLanes(VectorOperators.ADD); @@ -672,36 +710,47 @@ private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatT for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) { float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); - if (F_SPECIES.vectorBitSize() == 256) { - var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, - blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); - var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) - .mul(wBytes.castShape(F_SPECIES, 0)); - var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) - .mul(wBytes.castShape(F_SPECIES, 1)); - var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) - .mul(wBytes.castShape(F_SPECIES, 2)); - var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) - .mul(wBytes.castShape(F_SPECIES, 3)); - val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); - } else if (F_SPECIES.vectorBitSize() == 128) { - VectorSpecies B_128 = ByteVector.SPECIES_128; - // This loop cannot be unrolled, why? - for (int i = 0; i < 2; ++i) { - var wBytes = ByteVector.fromMemorySegment(B_128, thiz.memorySegment, - blockOffset + Float16.BYTES + i * B_128.vectorByteSize(), ByteOrder.LITTLE_ENDIAN); - var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()) + switch (F_SPECIES.vectorBitSize()) { + case 512 -> { + var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, + blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 0)); - var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()) + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 1)); - var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()) + val = sum0.add(sum1).fma(wScale, val); + } + case 256 -> { + var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, + blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 2)); - var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()) + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 3)); val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); } - } else { - throw new UnsupportedOperationException(F_SPECIES.toString()); + case 128 -> { + // This loop cannot be unrolled, why? + for (int i = 0; i < 2; ++i) { + var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, + blockOffset + Float16.BYTES + i * ByteVector.SPECIES_128.vectorByteSize(), + ByteOrder.LITTLE_ENDIAN); + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 2)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 3)); + val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); + } + } + default -> throw new UnsupportedOperationException(F_SPECIES.toString()); } } result += val.reduceLanes(VectorOperators.ADD);