From 56c665f5003e62f38ef0fc1d57c0524ae38de9f2 Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Mon, 9 Oct 2023 15:23:22 +0800 Subject: [PATCH 1/8] feat: add metal setting --- CMakeLists.txt | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 10305297..818bb172 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,7 @@ option(RWKV_OPENBLAS "rwkv: use OpenBLAS" option(RWKV_CUBLAS "rwkv: use cuBLAS" OFF) option(RWKV_CLBLAST "rwkv: use CLBlast" OFF) option(RWKV_HIPBLAS "rwkv: use hipBLAS" OFF) +option(RWKV_METAL "rwkv: use Metal" ON) # Build only shared library without building tests and extras option(RWKV_STANDALONE "rwkv: build only RWKV library" OFF) @@ -86,6 +87,25 @@ if (APPLE AND RWKV_ACCELERATE) endif() endif() +if (APPLE AND RWKV_METAL) + find_library(METAL_FRAMEWORK Metal) + find_library(METALKIT_FRAMEWORK MetalKit) + if (METAL_FRAMEWORK AND METALKIT_FRAMEWORK) + message(STATUS "Metal framework found") + + set(GGML_HEADERS_METAL ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.h) + set(GGML_SOURCES_METAL ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.m) + + add_compile_definitions(GGML_USE_METAL) + + configure_file(${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.metal bin/ggml-metal.metal COPYONLY) + + set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + else() + message(WARNING "Metal not found") + endif() +endif() + if (RWKV_OPENBLAS) if (RWKV_STATIC) set(BLA_STATIC ON) From f9f3a81feae8edbed9255cd4b9a590c8f5450c4e Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Mon, 9 Oct 2023 15:31:20 +0800 Subject: [PATCH 2/8] feat: update new lapack for Accelerate --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 818bb172..206fd580 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,6 +81,8 @@ if (APPLE AND RWKV_ACCELERATE) message(STATUS "Accelerate framework found") add_compile_definitions(GGML_USE_ACCELERATE) + add_compile_definitions(ACCELERATE_NEW_LAPACK) + add_compile_definitions(ACCELERATE_LAPACK_ILP64) set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) else() message(WARNING "Accelerate framework not found") From 0d9e23d8574e088afa67a9b00f402f2d0a5a9a5d Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Mon, 9 Oct 2023 22:16:10 +0800 Subject: [PATCH 3/8] feat: add foundation library and add metal source into library --- CMakeLists.txt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 206fd580..d7ca10cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,19 +90,20 @@ if (APPLE AND RWKV_ACCELERATE) endif() if (APPLE AND RWKV_METAL) + find_library(FOUNDATION_LIBRARY Foundation) find_library(METAL_FRAMEWORK Metal) find_library(METALKIT_FRAMEWORK MetalKit) + if (METAL_FRAMEWORK AND METALKIT_FRAMEWORK) message(STATUS "Metal framework found") - set(GGML_HEADERS_METAL ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.h) - set(GGML_SOURCES_METAL ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.m) + set(GGML_METAL_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.h ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.m) add_compile_definitions(GGML_USE_METAL) configure_file(${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.metal bin/ggml-metal.metal COPYONLY) - set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) else() message(WARNING "Metal not found") endif() @@ -419,6 +420,7 @@ add_library(ggml OBJECT ${CMAKE_SOURCE_DIR}/ggml/src/ggml-alloc.c ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml-alloc.h + ${GGML_METAL_SOURCES} ${GGML_CUDA_SOURCES} ${GGML_OPENCL_SOURCES}) From d73d7a3ccf6b1b44faacbda23f217049326d0301 Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Thu, 9 Nov 2023 12:54:45 +0800 Subject: [PATCH 4/8] add metal functions --- rwkv.h | 4 ++++ rwkv_eval.inc | 5 +++++ rwkv_graph.inc | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/rwkv.h b/rwkv.h index 40b9266c..6d7f2cde 100644 --- a/rwkv.h +++ b/rwkv.h @@ -29,6 +29,10 @@ // Default file version is the latest version. #define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX +#if defined(GGML_USE_METAL) +#include "ggml/src/ggml-metal.h" +#endif + #if defined(__cplusplus) extern "C" { #endif diff --git a/rwkv_eval.inc b/rwkv_eval.inc index 37f3a9cd..33257c8f 100644 --- a/rwkv_eval.inc +++ b/rwkv_eval.inc @@ -28,6 +28,10 @@ static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_ graph.cgraph->n_leafs = graph.post_logits_leafs; } +#ifdef GGML_USE_METAL + ggml_metal_set_n_cb (graph.ggml_metal_ctx, n_threads); + ggml_metal_graph_compute(graph.ggml_metal_ctx, graph.cgraph.get()); +#else struct ggml_cplan * plan = ggml_graph_plan(graph.cgraph.get(), n_threads); std::unique_ptr work_data{ new(std::nothrow) uint8_t[plan->work_size] }; @@ -36,6 +40,7 @@ static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_ ggml_graph_compute(graph.cgraph.get(), plan); free(plan); +#endif } // API function. diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 095e5038..96c433b4 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -30,6 +30,10 @@ struct rwkv_computation_graph { // ggml graph counters after the graph was extended with logits tensor. int post_logits_nodes; int post_logits_leafs; + +#ifdef GGML_USE_METAL + struct ggml_metal_context * ggml_metal_ctx; +#endif }; // The context holds the model and both serial and sequential computation graphs. @@ -314,6 +318,12 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str graph.ggml_ctx = NULL; } +#ifdef GGML_USE_METAL + if (graph.ggml_metal_ctx) { + ggml_metal_free(graph.ggml_metal_ctx); + } +#endif + // 1. Measure the space required for the ggml context. graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); @@ -336,6 +346,17 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str // 2. Create the real ggml context. graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); +#ifdef GGML_USE_METAL + graph.ggml_metal_ctx = ggml_metal_init(1); + + void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx); + size_t data_size = ggml_get_mem_size(graph.ggml_ctx); + + const size_t max_size = ggml_get_max_tensor_size(graph.ggml_ctx); + + ggml_metal_add_buffer(graph.ggml_metal_ctx, "data", data_ptr, data_size, max_size); +#endif + RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); return true; @@ -437,6 +458,12 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, graph.ggml_ctx = NULL; } +#ifdef GGML_USE_METAL + if (graph.ggml_metal_ctx) { + ggml_metal_free(graph.ggml_metal_ctx); + } +#endif + // 1. Measure the space required for the ggml context. graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); @@ -459,6 +486,17 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, // 2. Create the real ggml context. graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); +#ifdef GGML_USE_METAL + graph.ggml_metal_ctx = ggml_metal_init(1); + + void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx); + size_t data_size = ggml_get_mem_size(graph.ggml_ctx); + + const size_t max_size = ggml_get_max_tensor_size(graph.ggml_ctx); + + ggml_metal_add_buffer(graph.ggml_metal_ctx, "data", data_ptr, data_size, max_size); +#endif + RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); return true; From 170b113a489f19ba9429e3a33d8b54d1e9193912 Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Fri, 10 Nov 2023 17:02:29 +0800 Subject: [PATCH 5/8] fix: init metal context once, add also binding model weight data into metal --- rwkv.cpp | 20 ++++++++++++++++++++ rwkv_graph.inc | 28 ++++++---------------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/rwkv.cpp b/rwkv.cpp index f7406bf5..451a326f 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -62,6 +62,20 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t ctx->n_threads = n_threads; +#ifdef GGML_USE_METAL + ctx->ggml_metal_ctx = ggml_metal_init(1); + + void * data_ptr = ggml_get_mem_buffer(ctx->model->ggml_ctx); + size_t data_size = ggml_get_mem_size(ctx->model->ggml_ctx); + + const size_t max_size = ggml_get_max_tensor_size(ctx->model->ggml_ctx); + + ggml_metal_add_buffer(ctx->ggml_metal_ctx, "data", data_ptr, data_size, max_size); + + ctx->serial_graph.ggml_metal_ctx = ctx->ggml_metal_ctx; + ctx->sequential_graph.ggml_metal_ctx = ctx->ggml_metal_ctx; +#endif + RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*ctx->model, ctx->serial_graph)); return ctx.release(); @@ -147,6 +161,12 @@ void rwkv_free(struct rwkv_context * ctx) { ggml_free(ctx->sequential_graph.ggml_ctx); } +#ifdef GGML_USE_METAL + if (ctx->ggml_metal_ctx) { + ggml_metal_free(ctx->ggml_metal_ctx); + } +#endif + std::unique_ptr rwkv_ctx(ctx); } diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 96c433b4..27b70520 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -51,6 +51,10 @@ struct rwkv_context { enum rwkv_error_flags last_error; bool print_errors; + +#ifdef GGML_USE_METAL + struct ggml_metal_context * ggml_metal_ctx; +#endif }; static void rwkv_carry_x( @@ -318,12 +322,6 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str graph.ggml_ctx = NULL; } -#ifdef GGML_USE_METAL - if (graph.ggml_metal_ctx) { - ggml_metal_free(graph.ggml_metal_ctx); - } -#endif - // 1. Measure the space required for the ggml context. graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); @@ -347,14 +345,10 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); #ifdef GGML_USE_METAL - graph.ggml_metal_ctx = ggml_metal_init(1); - void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx); size_t data_size = ggml_get_mem_size(graph.ggml_ctx); - const size_t max_size = ggml_get_max_tensor_size(graph.ggml_ctx); - - ggml_metal_add_buffer(graph.ggml_metal_ctx, "data", data_ptr, data_size, max_size); + ggml_metal_add_buffer(graph.ggml_metal_ctx, "serial_buffer", data_ptr, data_size, 0); #endif RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); @@ -458,12 +452,6 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, graph.ggml_ctx = NULL; } -#ifdef GGML_USE_METAL - if (graph.ggml_metal_ctx) { - ggml_metal_free(graph.ggml_metal_ctx); - } -#endif - // 1. Measure the space required for the ggml context. graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); @@ -487,14 +475,10 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); #ifdef GGML_USE_METAL - graph.ggml_metal_ctx = ggml_metal_init(1); - void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx); size_t data_size = ggml_get_mem_size(graph.ggml_ctx); - const size_t max_size = ggml_get_max_tensor_size(graph.ggml_ctx); - - ggml_metal_add_buffer(graph.ggml_metal_ctx, "data", data_ptr, data_size, max_size); + ggml_metal_add_buffer(graph.ggml_metal_ctx, "sequential_buffer", data_ptr, data_size, 0); #endif RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); From 05386f7c32e30b659969a0cf8ce4e31bad0744ef Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Fri, 10 Nov 2023 22:38:45 +0800 Subject: [PATCH 6/8] feat: update buffer name --- ggml | 2 +- rwkv.cpp | 2 +- rwkv_graph.inc | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml b/ggml index d925ed7a..a0fec8ff 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit d925ed7a96767192d422a97645f08ad86d5cc6f0 +Subproject commit a0fec8ffa8b64fe67face8cc7d4af3dac370965d diff --git a/rwkv.cpp b/rwkv.cpp index 451a326f..e0310e4a 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -70,7 +70,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t const size_t max_size = ggml_get_max_tensor_size(ctx->model->ggml_ctx); - ggml_metal_add_buffer(ctx->ggml_metal_ctx, "data", data_ptr, data_size, max_size); + ggml_metal_add_buffer(ctx->ggml_metal_ctx, "weight_data", data_ptr, data_size, max_size); ctx->serial_graph.ggml_metal_ctx = ctx->ggml_metal_ctx; ctx->sequential_graph.ggml_metal_ctx = ctx->ggml_metal_ctx; diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 27b70520..0653e655 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -348,7 +348,7 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx); size_t data_size = ggml_get_mem_size(graph.ggml_ctx); - ggml_metal_add_buffer(graph.ggml_metal_ctx, "serial_buffer", data_ptr, data_size, 0); + ggml_metal_add_buffer(graph.ggml_metal_ctx, "serial_computer_buffer", data_ptr, data_size, 0); #endif RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); @@ -478,7 +478,7 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx); size_t data_size = ggml_get_mem_size(graph.ggml_ctx); - ggml_metal_add_buffer(graph.ggml_metal_ctx, "sequential_buffer", data_ptr, data_size, 0); + ggml_metal_add_buffer(graph.ggml_metal_ctx, "sequential_computer_buffer", data_ptr, data_size, 0); #endif RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); From 8e9c00db193a06663e13ee84d01184838717244a Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Fri, 10 Nov 2023 22:44:50 +0800 Subject: [PATCH 7/8] feat: update clone with metal setting --- rwkv.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rwkv.cpp b/rwkv.cpp index e0310e4a..03abd192 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -91,6 +91,13 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32 clone->n_threads = n_threads; +#ifdef GGML_USE_METAL + clone->ggml_metal_ctx = ctx->ggml_metal_ctx; + + clone->serial_graph.ggml_metal_ctx = clone->ggml_metal_ctx; + clone->sequential_graph.ggml_metal_ctx = clone->ggml_metal_ctx; +#endif + RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*clone->model, clone->serial_graph)); clone->last_used_sequence_length = 0; From 5c43a05fbbf4deabda5ceed27ba755d7bddd3934 Mon Sep 17 00:00:00 2001 From: Jason Fan Date: Fri, 10 Nov 2023 22:50:10 +0800 Subject: [PATCH 8/8] revert ggml verision --- ggml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml b/ggml index a0fec8ff..d925ed7a 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit a0fec8ffa8b64fe67face8cc7d4af3dac370965d +Subproject commit d925ed7a96767192d422a97645f08ad86d5cc6f0