Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: support apple metal framework to compute graph #137

Closed
wants to merge 9 commits into from
24 changes: 24 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ option(RWKV_OPENBLAS "rwkv: use OpenBLAS"
option(RWKV_CUBLAS "rwkv: use cuBLAS" OFF)
option(RWKV_CLBLAST "rwkv: use CLBlast" OFF)
option(RWKV_HIPBLAS "rwkv: use hipBLAS" OFF)
option(RWKV_METAL "rwkv: use Metal" ON)

# Build only shared library without building tests and extras
option(RWKV_STANDALONE "rwkv: build only RWKV library" OFF)
Expand Down Expand Up @@ -80,12 +81,34 @@ if (APPLE AND RWKV_ACCELERATE)
message(STATUS "Accelerate framework found")

add_compile_definitions(GGML_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK)
add_compile_definitions(ACCELERATE_LAPACK_ILP64)
set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
else()
message(WARNING "Accelerate framework not found")
endif()
endif()

# Optional Apple Metal backend.
# When enabled on Apple platforms and all required frameworks are located, this:
#   - sets GGML_METAL_SOURCES to the ggml Metal header/implementation files,
#   - defines GGML_USE_METAL for all subsequently compiled sources,
#   - copies the Metal shader source into <build dir>/bin,
#   - appends the frameworks to RWKV_EXTRA_LIBS.
if (APPLE AND RWKV_METAL)
    find_library(FOUNDATION_LIBRARY Foundation)
    find_library(METAL_FRAMEWORK Metal)
    find_library(METALKIT_FRAMEWORK MetalKit)

    # All three results are appended to RWKV_EXTRA_LIBS below, so all three must
    # have been found: a *-NOTFOUND value must never reach the link line.
    if (FOUNDATION_LIBRARY AND METAL_FRAMEWORK AND METALKIT_FRAMEWORK)
        message(STATUS "Metal framework found")

        set(GGML_METAL_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.h ${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.m)

        add_compile_definitions(GGML_USE_METAL)

        # The relative destination is resolved against the current binary directory.
        # NOTE(review): presumably ggml-metal.m loads this shader file at run time
        # from next to the executable -- confirm the lookup path it uses.
        configure_file(${CMAKE_SOURCE_DIR}/ggml/src/ggml-metal.metal bin/ggml-metal.metal COPYONLY)

        set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
    else()
        message(WARNING "Metal not found")
    endif()
endif()

if (RWKV_OPENBLAS)
if (RWKV_STATIC)
set(BLA_STATIC ON)
Expand Down Expand Up @@ -409,6 +432,7 @@ add_library(ggml OBJECT
${CMAKE_SOURCE_DIR}/ggml/src/ggml-alloc.c
${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h
${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml-alloc.h
${GGML_METAL_SOURCES}
${GGML_CUDA_SOURCES}
${GGML_OPENCL_SOURCES})

Expand Down
27 changes: 27 additions & 0 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t

ctx->n_threads = n_threads;

#ifdef GGML_USE_METAL
ctx->ggml_metal_ctx = ggml_metal_init(1);

void * data_ptr = ggml_get_mem_buffer(ctx->model->ggml_ctx);
size_t data_size = ggml_get_mem_size(ctx->model->ggml_ctx);

const size_t max_size = ggml_get_max_tensor_size(ctx->model->ggml_ctx);

ggml_metal_add_buffer(ctx->ggml_metal_ctx, "weight_data", data_ptr, data_size, max_size);

ctx->serial_graph.ggml_metal_ctx = ctx->ggml_metal_ctx;
ctx->sequential_graph.ggml_metal_ctx = ctx->ggml_metal_ctx;
#endif

RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*ctx->model, ctx->serial_graph));

return ctx.release();
Expand All @@ -77,6 +91,13 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32

clone->n_threads = n_threads;

#ifdef GGML_USE_METAL
clone->ggml_metal_ctx = ctx->ggml_metal_ctx;

clone->serial_graph.ggml_metal_ctx = clone->ggml_metal_ctx;
clone->sequential_graph.ggml_metal_ctx = clone->ggml_metal_ctx;
#endif

RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*clone->model, clone->serial_graph));

clone->last_used_sequence_length = 0;
Expand Down Expand Up @@ -151,6 +172,12 @@ void rwkv_free(struct rwkv_context * ctx) {
ggml_free(ctx->sequential_graph.ggml_ctx);
}

#ifdef GGML_USE_METAL
if (ctx->ggml_metal_ctx) {
ggml_metal_free(ctx->ggml_metal_ctx);
}
#endif

std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
}

Expand Down
4 changes: 4 additions & 0 deletions rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
// Default file version is the latest version.
#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX

#if defined(GGML_USE_METAL)
#include "ggml/src/ggml-metal.h"
#endif

#if defined(__cplusplus)
extern "C" {
#endif
Expand Down
5 changes: 5 additions & 0 deletions rwkv_eval.inc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_
graph.cgraph->n_leafs = graph.post_logits_leafs;
}

#ifdef GGML_USE_METAL
ggml_metal_set_n_cb (graph.ggml_metal_ctx, n_threads);
ggml_metal_graph_compute(graph.ggml_metal_ctx, graph.cgraph.get());
#else
struct ggml_cplan * plan = ggml_graph_plan(graph.cgraph.get(), n_threads);

std::unique_ptr<uint8_t[]> work_data{ new(std::nothrow) uint8_t[plan->work_size] };
Expand All @@ -36,6 +40,7 @@ static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_
ggml_graph_compute(graph.cgraph.get(), plan);

free(plan);
#endif
}

// API function.
Expand Down
22 changes: 22 additions & 0 deletions rwkv_graph.inc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ struct rwkv_computation_graph {
// ggml graph counters after the graph was extended with logits tensor.
int post_logits_nodes;
int post_logits_leafs;

#ifdef GGML_USE_METAL
struct ggml_metal_context * ggml_metal_ctx;
#endif
};

// The context holds the model and both serial and sequential computation graphs.
Expand All @@ -50,6 +54,10 @@ struct rwkv_context {

enum rwkv_error_flags last_error;
bool print_errors;

#ifdef GGML_USE_METAL
struct ggml_metal_context * ggml_metal_ctx;
#endif
};

static void rwkv_carry_x(
Expand Down Expand Up @@ -544,6 +552,13 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str
// 2. Create the real ggml context.
graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false);

#ifdef GGML_USE_METAL
void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx);
size_t data_size = ggml_get_mem_size(graph.ggml_ctx);

ggml_metal_add_buffer(graph.ggml_metal_ctx, "serial_computer_buffer", data_ptr, data_size, 0);
#endif

RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph));

return true;
Expand Down Expand Up @@ -685,6 +700,13 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model,
// 2. Create the real ggml context.
graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false);

#ifdef GGML_USE_METAL
void * data_ptr = ggml_get_mem_buffer(graph.ggml_ctx);
size_t data_size = ggml_get_mem_size(graph.ggml_ctx);

ggml_metal_add_buffer(graph.ggml_metal_ctx, "sequential_computer_buffer", data_ptr, data_size, 0);
#endif

RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length));

return true;
Expand Down