[LLaMA] uncaught error while running! details: [LayerImpl] Unknown Layer Properties count 2 #2782

dyna-bytes opened this issue Nov 2, 2024 · 7 comments

dyna-bytes commented Nov 2, 2024

Hi, I tried to run LLaMA on a Galaxy Z Flip5 (Android 14) and got some errors.

  1. ko_KR.UTF-8 locale error

terminating with uncaught exception of type std::runtime_error: collate_byname<char>::collate_byname failed to construct for ko_KR.UTF-8

I got the above error the first time I tried to execute LLaMA on nntrainer.
I assumed the error was related to the NDK compiler and simply commented out the locale-setting line in Applications/LLaMA/jni/main.cpp:

   // Setting locale
-  std::locale::global(std::locale("ko_KR.UTF-8"));
+  // std::locale::global(std::locale("ko_KR.UTF-8"));

That made the error message disappear, but I don't think it is a fundamental solution.
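
Instead of deleting the line, a less destructive workaround might be to catch the failure and fall back to the default locale. This is only a sketch of the idea, not a confirmed fix; it assumes the rest of main.cpp behaves correctly under the classic "C" locale:

#include <iostream>
#include <locale>
#include <stdexcept>

// Hypothetical fallback: try the Korean locale first and degrade to the
// classic "C" locale if the NDK's libc++ cannot construct it, instead of
// terminating with an uncaught std::runtime_error.
static void setupLocale() {
  try {
    std::locale::global(std::locale("ko_KR.UTF-8"));
  } catch (const std::runtime_error &e) {
    std::cerr << "ko_KR.UTF-8 unavailable (" << e.what()
              << "), falling back to the \"C\" locale" << std::endl;
    std::locale::global(std::locale::classic());
  }
}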

  2. uncaught error while running! details: [LayerImpl] Unknown Layer Properties count 2

The real problem was in creating the multi_head_attention layer of the embedding0 block.

transformer = createTransformerDecoder(i, "embedding0");

The error message was generated while running createAttentionLayer().
I added a print line for debugging, as below.

@@ -399,6 +399,7 @@ std::vector<LayerHandle> createAttentionLayer(const int layer_id, int seq_len,
        withKey("input_layers",
                "layer" + std::to_string(layer_id) + "_attention_flatten")}));
   } else {
+    std::cout << __func__ << " " << __LINE__ << std::endl;
     layers.push_back(createLayer(
       "multi_head_attention",
       {withKey("name", "layer" + std::to_string(layer_id) + "_attention_out"),

I got the error message below while executing the LLaMA binary, but couldn't figure out the reason for the Unknown Properties error.

b5q:/data/local/tmp $ ./nntr/nntrainer_llama
createAttentionLayer 402
uncaught error while running! details: [LayerImpl] Unknown Layer Properties count 2

Environment:
I used android-ndk-r22b with Ubuntu 20.04 in a Docker container.
You can download my environment settings from the URL below.
https://hub.docker.com/repository/docker/jetpark/nntrainer/general

Please help me solve this problem. Thank you.


taos-ci commented Nov 2, 2024

:octocat: cibot: Thank you for posting issue #2782. The person in charge will reply soon.

DonghakPark (Member) commented

@dyna-bytes Hi, thank you for reporting the issue.
We will check the issue you mentioned and let you know the results.
Did you build the application in our repository using ndk-build and then run it on the device, or did you use Android Studio?

dyna-bytes (Author) commented

Hi @DonghakPark, I ran it on the device using adb shell.

dyna-bytes reopened this Nov 6, 2024
DonghakPark (Member) commented Nov 6, 2024

@dyna-bytes

I think the error is caused by using nntrainer's built-in multi_head_attention instead of custom_multi_head_attention in the section below.

Have you tried building with meson build -Dplatform=android?

{
    layers.push_back(createLayer(
      "multi_head_attention",
      {withKey("name", "layer" + std::to_string(layer_id) + "_attention_out"),
       withKey("num_heads", std::to_string(NUM_HEADS)),
       withKey("max_timestep", std::to_string(MAX_SEQ_LEN)),
       withKey("disable_bias", "true"),
       withKey("input_layers", {query_name, key_name, value_name})}));
  }
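
If that is the cause, the fix would presumably be to request the custom layer type at this call site. A minimal sketch, assuming the custom implementation is registered under the type name custom_multi_head_attention (an assumption based on the libcustom_multi_head_attention_layer.so that the application build produces):

// Hypothetical variant of the snippet above: request the application's
// custom attention layer instead of nntrainer's built-in one, so that
// properties the built-in layer does not recognize are accepted.
layers.push_back(createLayer(
  "custom_multi_head_attention",
  {withKey("name", "layer" + std::to_string(layer_id) + "_attention_out"),
   withKey("num_heads", std::to_string(NUM_HEADS)),
   withKey("max_timestep", std::to_string(MAX_SEQ_LEN)),
   withKey("disable_bias", "true"),
   withKey("input_layers", {query_name, key_name, value_name})}));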

One more thing: to run the application, you need the following bin file:

  std::string weight_path = "./llama_fp16.bin";
  g_model->load(weight_path);

baek2sm (Contributor) commented Nov 6, 2024

I also recommend trying a meson build (since the currently released version has been tested with the meson build).

To run your own LLaMA model, you need to prepare three additional files.
For the tokenizer, you will need a vocab.json file and a merges.txt file tailored to your specific model.
For loading a pre-trained LLaMA model, you need a weight.bin file compatible with NNTrainer. You can convert your PyTorch model's weight file into an NNTrainer weight file using the Applications/LLaMA/PyTorch/weights_converter.py script.

skykongkong8 (Member) commented

FYI, nntrainer provides some guidelines for debugging remotely in an adb shell with lldb-server.
You can follow the guide here to attach a debugger (lldb) in your environment.

dyna-bytes (Author) commented

Hi, thank you all for your sincere replies. But I still get the same error message even though I built with meson build -Dplatform=android.

I think the build didn't complete correctly.
Please help me debug or correct the build sequence.

This is my build sequence:

  1. rm -r build under the nntrainer root directory for a fresh build.
  2. Run meson build -Dplatform=android.
  3. Run ninja -C build.
  4. Hit a flatbuffers version error (the static_assert below means the included flatbuffers headers are version 1.12.0, while the generated tf_schema_generated.h requires 23.5.26):

/workspace/nnbuilder/nntrainer/build/nntrainer/compiler/tf_schema_generated.h:11:1: error: static_assert failed due to requirement '1 == 23 && 12 == 5 && 0 == 26' "Non-compatible flatbuffers version included"
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&

  5. Go to nntrainer/build/tensorflow-2.3.0/tensorflow-lite/include/ and remove flatbuffers/.
  6. Symlink the flatbuffers headers from my system: ln -s /usr/include/flatbuffers/
  7. Go to build/jni and run ndk-build.
  8. Go to Applications/LLaMA/jni and run ndk-build.

These are the files that got built:

/workspace/nnbuilder/nntrainer# ninja -C build
ninja: Entering directory `build'
[1/1] Generating ndk-build with a custom command.
[arm64-v8a] Compile++      : nntrainer <= identity_layer.cpp
[arm64-v8a] Compile++      : nntrainer <= reduce_mean_layer.cpp
[arm64-v8a] Compile++      : nntrainer <= reshape_layer.cpp
[arm64-v8a] Compile++      : nntrainer <= dropout.cpp
[arm64-v8a] Compile++      : nntrainer <= layer_context.cpp
[arm64-v8a] Compile++      : nntrainer <= centroid_knn.cpp
[arm64-v8a] Compile++      : nntrainer <= positional_encoding_layer.cpp
[arm64-v8a] Compile++      : nntrainer <= layer_impl.cpp
[arm64-v8a] Compile++      : nntrainer <= optimizer_devel.cpp
[arm64-v8a] Compile++      : nntrainer <= sgd.cpp
[arm64-v8a] Compile++      : nntrainer <= dynamic_training_optimization.cpp
[arm64-v8a] Compile++      : nntrainer <= model_common_properties.cpp
[arm64-v8a] Compile++      : nntrainer <= upsample2d_layer.cpp
[arm64-v8a] Compile++      : nntrainer <= optimizer_context.cpp
[arm64-v8a] Compile++      : nntrainer <= adam.cpp
[arm64-v8a] Compile++      : nntrainer <= tflite_layer.cpp
[arm64-v8a] Compile++      : nntrainer <= blas_interface.cpp
[arm64-v8a] Compile++      : nntrainer <= cache_elem.cpp
[arm64-v8a] Compile++      : nntrainer <= lr_scheduler_cosine.cpp
[arm64-v8a] Compile++      : nntrainer <= model_loader.cpp
[arm64-v8a] Compile++      : nntrainer <= lr_scheduler_constant.cpp
[arm64-v8a] Compile++      : nntrainer <= lr_scheduler_step.cpp
[arm64-v8a] Compile++      : nntrainer <= lr_scheduler_exponential.cpp
[arm64-v8a] Compile++      : nntrainer <= cache_loader.cpp
[arm64-v8a] Compile++      : nntrainer <= optimizer_wrapped.cpp
[arm64-v8a] Compile++      : nntrainer <= tensor_base.cpp
[arm64-v8a] Compile++      : nntrainer <= lazy_tensor.cpp
[arm64-v8a] Compile++      : nntrainer <= cache_pool.cpp
[arm64-v8a] Compile++      : nntrainer <= char_tensor.cpp
[arm64-v8a] Compile++      : nntrainer <= var_grad.cpp
[arm64-v8a] Compile++      : nntrainer <= tensor.cpp
[arm64-v8a] Compile++      : nntrainer <= basic_planner.cpp
[arm64-v8a] Compile++      : nntrainer <= weight.cpp
[arm64-v8a] Compile++      : nntrainer <= swap_device.cpp
[arm64-v8a] Compile++      : nntrainer <= optimized_v1_planner.cpp
[arm64-v8a] Compile++      : nntrainer <= tensor_dim.cpp
[arm64-v8a] Compile++      : nntrainer <= memory_pool.cpp
[arm64-v8a] Compile++      : nntrainer <= float_tensor.cpp
[arm64-v8a] Compile++      : nntrainer <= optimized_v2_planner.cpp
[arm64-v8a] Compile++      : nntrainer <= tensor_pool.cpp
[arm64-v8a] Compile++      : nntrainer <= task_executor.cpp
[arm64-v8a] Compile++      : nntrainer <= optimized_v3_planner.cpp
[arm64-v8a] Compile++      : nntrainer <= manager.cpp
[arm64-v8a] Compile++      : nntrainer <= fp16.cpp
[arm64-v8a] Compile++      : nntrainer <= util_simd.cpp
[arm64-v8a] Compile++      : nntrainer <= neuralnet.cpp
[arm64-v8a] Compile++      : nntrainer <= base_properties.cpp
[arm64-v8a] Compile++      : nntrainer <= profiler.cpp
[arm64-v8a] Compile        : iniparser <= iniparser.c
[arm64-v8a] Compile        : iniparser <= dictionary.c
[arm64-v8a] Install        : libnnstreamer-native.so => /workspace/nnbuilder/nntrainer/build/jni/arm64-v8a/libnnstreamer-native.so
[arm64-v8a] Compile++      : nntrainer <= nntr_threads.cpp
[arm64-v8a] StaticLibrary  : libiniparser.a
[arm64-v8a] Install        : libc++_shared.so => /workspace/nnbuilder/nntrainer/build/jni/arm64-v8a/libc++_shared.so
[arm64-v8a] Compile++      : nntrainer <= util_func.cpp
[arm64-v8a] Compile++      : nntrainer <= ini_wrapper.cpp
[arm64-v8a] Compile++      : nntrainer <= graph_core.cpp
[arm64-v8a] Compile++      : nntrainer <= node_exporter.cpp
[arm64-v8a] Compile++      : nntrainer <= connection.cpp
[arm64-v8a] Compile++      : nntrainer <= network_graph.cpp
[arm64-v8a] SharedLibrary  : libnntrainer.so
[arm64-v8a] SharedLibrary  : libccapi-nntrainer.so
[arm64-v8a] Install        : libccapi-nntrainer.so => /workspace/nnbuilder/nntrainer/build/jni/arm64-v8a/libccapi-nntrainer.so
[arm64-v8a] SharedLibrary  : libcapi-nntrainer.so
[arm64-v8a] Install        : libcapi-nntrainer.so => /workspace/nnbuilder/nntrainer/build/jni/arm64-v8a/libcapi-nntrainer.so
[arm64-v8a] Install        : libnntrainer.so => /workspace/nnbuilder/nntrainer/build/jni/arm64-v8a/libnntrainer.so

and I got these .so files too:

/workspace/nnbuilder/nntrainer/build/jni# ndk-build NDK_PROJECT_PATH=./ APP_BUILD_SCRIPT=./Android.mk NDK_APPLICATION_MK=./Application.mk -j
[arm64-v8a] Install        : libnnstreamer-native.so => libs/arm64-v8a/libnnstreamer-native.so
[arm64-v8a] Install        : libc++_shared.so => libs/arm64-v8a/libc++_shared.so
[arm64-v8a] Compile++      : nntrainer <= layer_impl.cpp
[arm64-v8a] Compile++      : nntrainer <= multi_head_attention_layer.cpp
[arm64-v8a] SharedLibrary  : libnntrainer.so
[arm64-v8a] SharedLibrary  : libccapi-nntrainer.so
[arm64-v8a] Install        : libccapi-nntrainer.so => libs/arm64-v8a/libccapi-nntrainer.so
[arm64-v8a] SharedLibrary  : libcapi-nntrainer.so
[arm64-v8a] Install        : libcapi-nntrainer.so => libs/arm64-v8a/libcapi-nntrainer.so
[arm64-v8a] Install        : libnntrainer.so => libs/arm64-v8a/libnntrainer.so

I can also confirm that libcustom_multi_head_attention_layer.so was generated:

/workspace/nnbuilder/nntrainer/Applications/LLaMA/jni# ndk-build NDK_PROJECT_PATH=./ APP_BUILD_SCRIPT=./Android.mk NDK_APPLICATION_MK=./Application.mk -j 8
[arm64-v8a] Install        : libccapi-nntrainer.so => libs/arm64-v8a/libccapi-nntrainer.so
[arm64-v8a] Install        : libcustom_multi_head_attention_layer.so => libs/arm64-v8a/libcustom_multi_head_attention_layer.so
[arm64-v8a] Install        : librms_norm_layer.so => libs/arm64-v8a/librms_norm_layer.so
[arm64-v8a] Install        : nntrainer_llama => libs/arm64-v8a/nntrainer_llama
[arm64-v8a] Install        : libswiglu_layer.so => libs/arm64-v8a/libswiglu_layer.so
[arm64-v8a] Install        : libnntrainer.so => libs/arm64-v8a/libnntrainer.so
[arm64-v8a] Install        : libc++_shared.so => libs/arm64-v8a/libc++_shared.so

Please help me figure out what is going wrong in this build sequence.

Thank you for your time.
Best regards,
