Skip to content

Commit

Permalink
llama.android: add field formatChat to control whether to parse speci…
Browse files Browse the repository at this point in the history
…al tokens when send message (ggerganov#11270)
  • Loading branch information
codezjx authored Jan 17, 2025
1 parent 667d728 commit 3edfa7d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
jlong context_pointer,
jlong batch_pointer,
jstring jtext,
jboolean format_chat,
jint n_len
) {

Expand All @@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);

const auto tokens_list = common_tokenize(context, text, 1);
bool parse_special = (format_chat == JNI_TRUE);
const auto tokens_list = common_tokenize(context, text, true, parse_special);

auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
Expand All @@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
}

for (auto id : tokens_list) {
LOGi("%s", common_token_to_piece(context, id).c_str());
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
}

common_batch_clear(*batch);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class LLamaAndroid {
context: Long,
batch: Long,
text: String,
formatChat: Boolean,
nLen: Int
): Int

Expand Down Expand Up @@ -115,10 +116,10 @@ class LLamaAndroid {
}
}

fun send(message: String): Flow<String> = flow {
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) {
Expand Down

0 comments on commit 3edfa7d

Please sign in to comment.