Skip to content

Commit

Permalink
Use latest version of the Llama3.java code
Browse files (browse the repository at this point in the history)
  • Loading branch information
geoand committed Nov 14, 2024
1 parent da4251b commit 1c1a2d1
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 155 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,7 +96,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {

private InferenceResponse runInference(Llama model, Sampler sampler, Llama3.Options options,
List<ChatFormat.Message> messages) {
Llama.State state = model.createNewState();
Llama.State state = model.createNewState(Llama3.BATCH_SIZE);
ChatFormat chatFormat = new ChatFormat(model.tokenizer());

List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE;
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;

import java.io.IOException;
Expand Down Expand Up @@ -84,7 +85,7 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMess
private void runInference(Llama model, Sampler sampler, Llama3.Options options,
List<ChatFormat.Message> messages,
StreamingResponseHandler<AiMessage> handler) {
Llama.State state = model.createNewState();
Llama.State state = model.createNewState(BATCH_SIZE);
ChatFormat chatFormat = new ChatFormat(model.tokenizer());

List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,7 +51,7 @@ public static PartialModel preLoadGGUF(String modelPath) {
* No checksum/hash is checked for performance reasons.
*/
public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
PartialModel preLoaded = AOT.PRELOADED_GGUF;
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
if (preLoaded == null) {
return null; // no pre-loaded model stored
}
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -108,7 +108,7 @@ private void loadModelImpl(FileChannel fileChannel) throws IOException {
// gguf_tensor_info_t tensor_infos[header.tensor_count];
this.tensorInfos = HashMap.newHashMap(tensorCount);
for (int i = 0; i < tensorCount; ++i) {
GGUFTensorInfo ti = readTensorInfo(fileChannel);
GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
assert !tensorInfos.containsKey(ti.name);
tensorInfos.put(ti.name, ti);
}
Expand Down Expand Up @@ -156,7 +156,7 @@ private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
return GGMLType.fromId(ggmlTypeId);
}

private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
// The name of the tensor. It is a standard GGUF string, with the caveat that
// it must be at most 64 bytes long.
String name = readString(fileChannel); // gguf_string_t name;
Expand All @@ -180,7 +180,7 @@ private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
// Must be a multiple of `ALIGNMENT`.
long offset = readLong(fileChannel); // uint64_t offset;
assert offset % getAlignment() == 0;
return new GGUFTensorInfo(name, dimensions, ggmlType, offset);
return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset);
}

private String readString(FileChannel fileChannel) throws IOException {
Expand Down
Loading

0 comments on commit 1c1a2d1

Please sign in to comment.