Configurable sparse prediction threshold #7

Merged · 4 commits · Dec 18, 2023
33 changes: 32 additions & 1 deletion convert.py
@@ -3,6 +3,7 @@

import argparse
import concurrent.futures
import dataclasses
import enum
import faulthandler
import functools
@@ -138,6 +139,28 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
# hparams loading
#

@dataclass
class PredictorParams:
sparse_threshold: float | None = None

@staticmethod
def loadPredictorJson(model: LazyModel, config_path: Path) -> PredictorParams:
config = json.load(open(config_path))
return PredictorParams(
sparse_threshold = config.get("sparse_threshold"),
)

@staticmethod
def load(model_plus: ModelPlus) -> PredictorParams:
config_path = model_plus.paths[0].parent / "config.json"

if config_path.exists():
params = PredictorParams.loadPredictorJson(model_plus.model, config_path)
else:
params = PredictorParams()

return params

@dataclass
class Params:
n_vocab: int
@@ -160,6 +183,9 @@ class Params:
# path to the directory containing the model files
path_model: Path | None = None

# MLP predictor parameters
predictor_params: PredictorParams = dataclasses.field(default_factory=PredictorParams)

@staticmethod
def guessed(model: LazyModel) -> Params:
# try transformer naming first
@@ -843,6 +869,9 @@ def add_meta_arch(self, params: Params) -> None:
if params.ftype is not None:
self.gguf.add_file_type(params.ftype)

if params.predictor_params.sparse_threshold is not None:
self.gguf.add_sparse_threshold(params.predictor_params.sparse_threshold)

def add_meta_vocab(self, vocab: Vocab) -> None:
tokens = []
scores = []
@@ -1181,10 +1210,13 @@ def main(args_in: list[str] | None = None) -> None:

if not args.vocab_only:
model_plus = load_some_model(args.model)
params = Params.load(model_plus)
mlp_predictor_plus = load_mlp_model(args.mlp_model)
params.predictor_params = PredictorParams.load(mlp_predictor_plus)
model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
params = Params.load(model_plus)

if args.dump:
do_dump_model(model_plus)
@@ -1193,7 +1225,6 @@ def main(args_in: list[str] | None = None) -> None:
if args.bigendian:
endianess = gguf.GGUFEndian.BIG

params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
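For reference, the lookup convert.py now performs reduces to the following minimal sketch (the "sparse_threshold" key matches the code above; the directory argument and the example value are illustrative):

import json
from pathlib import Path

def read_sparse_threshold(predictor_dir: Path) -> float | None:
    # Mirrors PredictorParams.load(): look for config.json next to the
    # MLP predictor weights and return None when no threshold is present.
    config_path = predictor_dir / "config.json"
    if not config_path.exists():
        return None
    with open(config_path) as f:
        return json.load(f).get("sparse_threshold")  # e.g. 0.5 (illustrative)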
21 changes: 10 additions & 11 deletions ggml-cuda.cu
@@ -108,6 +108,8 @@
// max batch size to use MMQ kernels when tensor cores are available
#define MMQ_MAX_BATCH_SIZE 32

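// Sparsity threshold in CUDA constant memory; set once from the host via
// ggml_cuda_set_device_constants() and read by the sparse kernels below.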
__constant__ float dev_sparse_threshold;

#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300

@@ -4483,7 +4485,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__
// printf("row in gpu %d cols %d, value %d %d %d\n", id, ncols, *d, *(d+1), *(d+4095));
// }
// int id = row;
if (idx[id] < 0.0f) {
if (idx[id] < dev_sparse_threshold) {
return;
}

@@ -4552,12 +4554,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restrict__
return;
}
int id = lst[row];
// int id = row;
// if (idx[id] < 0.0f) {
// return;
// }
const int bid = blockIdx.y;
// if (bid == 0) global_lock = 0;

extern __shared__ float shared_dst[]; // TODO:dynamic

@@ -4578,7 +4575,7 @@
// __syncthreads();
for (int col_id = 0; col_id < src1_ncols; col_id++) {
__syncthreads();
if (loop_idx[id] < 0.0f) {
if (loop_idx[id] < dev_sparse_threshold) {
loop_dst += ncols;
loop_idx += src1_ne0;
loop_y += src1_ne0;
@@ -4640,7 +4637,7 @@ static __global__ void dequantize_axpy_sparse(const void * __restrict__ vx, const
return;
}
int id = lst[row];
if (idx[id] < 0.0f) {
if (idx[id] < dev_sparse_threshold) {
return;
}

@@ -4689,8 +4686,7 @@ static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ vx
return;
}
int id = lst[row];
// int id = row;
if (idx[id] < 0.0f) {
if (idx[id] < dev_sparse_threshold) {
return;
}

@@ -4782,7 +4778,7 @@ static __global__ void dequantize_mul_mat_batch_sparse(const void * __restrict__
{
__syncthreads();
tmp = 0.0f;
if (loop_idx[id] < 0.0f)
if (loop_idx[id] < dev_sparse_threshold)
{
loop_dst += dst_ne0;
loop_idx += dst_ne0;
@@ -9618,3 +9614,6 @@ ggml_backend_t ggml_backend_cuda_init() {
return cuda_backend;
}

void ggml_cuda_set_device_constants(float sparse_pred_threshold) {
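    // Copy the host-side value into the __constant__ symbol dev_sparse_threshold on the active device.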
CUDA_CHECK(cudaMemcpyToSymbol(dev_sparse_threshold, &sparse_pred_threshold, sizeof(float)));
}
2 changes: 2 additions & 0 deletions ggml-cuda.h
@@ -53,6 +53,8 @@ GGML_API int ggml_cuda_get_device_count(void);
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API size_t ggml_cuda_get_free_memory(int device);

GGML_API void ggml_cuda_set_device_constants(float sparse_pred_threshold);

// backend API
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use

17 changes: 9 additions & 8 deletions ggml.c
@@ -14059,6 +14059,8 @@ static void ggml_compute_forward_mul_mat_sparse(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

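// Snapshot the global threshold once per call; rows whose predicted
// activation (ffdata) falls below it are zeroed and skipped.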
const float threshold = sparse_pred_threshold;

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
@@ -14262,7 +14264,7 @@
float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));

// if (ffdata[ir0] <= 0.0f) {
if (gid[ir0] == 1 || ffdata[ir0] < -0.0f) {
if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
dst_col[ir0] = 0;
continue;
}
@@ -14413,11 +14415,6 @@ static void ggml_compute_forward_mul_mat_axpy_dense(
const int ir0 = atomic_fetch_add(params->aic, dr);
for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
if (ir1 >= nr) break;
// if (gid[ir1] == 1)
// continue;
// if (idx[ir1] < 0.0f)
// continue;
// ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
}
if (ir0 + dr >= nr)
@@ -14482,6 +14479,8 @@ static void ggml_compute_forward_mul_mat_axpy(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

const float threshold = sparse_pred_threshold;

// GGML_ASSERT(ne0 == ne01);
// GGML_ASSERT(ne1 == ne11);
// GGML_ASSERT(ne2 == ne12);
@@ -14569,7 +14568,7 @@
if (gid[ir1] == 1) {
continue;
}
if (idx[ir1] < -0.0f)
if (idx[ir1] < threshold)
continue;
// ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, src1_ptr[ir1]);
@@ -14632,6 +14631,8 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

const float threshold = sparse_pred_threshold;

// GGML_ASSERT(ne0 == ne01);
// GGML_ASSERT(ne1 == ne11);
// GGML_ASSERT(ne2 == ne12);
@@ -14713,7 +14714,7 @@
break;
if (gid[ir1] == 1)
continue;
if (idx[ir1] < 0.0f)
if (idx[ir1] < threshold)
continue;
int bid = ir1 / QK8_0;
int qsid = ir1 % QK8_0;
6 changes: 6 additions & 0 deletions ggml.h
@@ -2196,6 +2196,12 @@ extern "C" {
GGML_API int ggml_cpu_has_ssse3 (void);
GGML_API int ggml_cpu_has_vsx (void);

//
// global variables
//
// TODO: these should be moved to the context
extern float sparse_pred_threshold;

//
// Internal types and functions exposed for tests and benchmarks
//
3 changes: 3 additions & 0 deletions gguf-py/gguf/constants.py
@@ -70,6 +70,9 @@ class Tokenizer:
ADD_EOS = "tokenizer.ggml.add_eos_token"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"

class PowerInfer:
SPARSE_THRESHOLD = "powerinfer.sparse_threshold"


#
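(This literal string must stay in sync with the LLM_KV_SPARSE_THRESHOLD entry registered in llama.cpp below, since the C++ loader looks the value up by the exact key "powerinfer.sparse_threshold".)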
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -399,6 +399,9 @@ def add_add_bos_token(self, value: bool) -> None:
def add_add_eos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_EOS, value)

def add_sparse_threshold(self, value: float) -> None:
self.add_float32(Keys.PowerInfer.SPARSE_THRESHOLD, value)

def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
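For illustration, writing the key through the new API looks like this (a minimal sketch; the output path, architecture, and threshold value are placeholders, and it assumes the in-tree gguf-py is installed, e.g. via the editable install in requirements.txt below):

import gguf

writer = gguf.GGUFWriter("demo.gguf", arch="llama")
writer.add_sparse_threshold(0.5)  # stored as float32 under "powerinfer.sparse_threshold"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()  # this toy example writes no tensors
writer.close()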
50 changes: 30 additions & 20 deletions llama.cpp
@@ -93,6 +93,13 @@

#define LLAMA_MAX_NODES 4096

//
// global variables
//

// sparsity threshold for sparse matrix multiplication prediction
float sparse_pred_threshold = 0.;

//
// logging
//
@@ -257,6 +264,8 @@ enum llm_kv {
LLM_KV_TOKENIZER_PAD_ID,
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,

LLM_KV_SPARSE_THRESHOLD,
};

static std::map<llm_kv, std::string> LLM_KV_NAMES = {
Expand Down Expand Up @@ -305,6 +314,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },

{ LLM_KV_SPARSE_THRESHOLD, "powerinfer.sparse_threshold" },
};

struct LLM_KV {
@@ -1150,6 +1161,9 @@ struct llama_hparams {

float f_clamp_kqv;
float f_max_alibi_bias;

// sparse predictor threshold if sparse inference is enabled
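// (seeded from the LLAMA_SPARSE_PRED_THRESHOLD environment variable when set; "?:" here is a GNU extension)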
float sparse_pred_threshold = atof(getenv("LLAMA_SPARSE_PRED_THRESHOLD") ?: "0.0");

bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true;
@@ -2220,6 +2234,11 @@ static void llm_load_hparams(
// gpt-j n_rot = rotary_dim
}

if (gguf_get_sparse_deriv(ctx)) {
// read sparse threshold override if sparse deriv is enabled
GGUF_GET_KEY(ctx, hparams.sparse_pred_threshold, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_SPARSE_THRESHOLD));
}

// arch-specific KVs
switch (model.arch) {
case LLM_ARCH_LLAMA:
@@ -2607,6 +2626,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }

// sparse inference
LLAMA_LOG_INFO("%s: sparse_pred_threshold = %.2f\n", __func__, hparams.sparse_pred_threshold);
}


@@ -2808,7 +2830,7 @@ struct llama_augmentation_model_loader {
return NULL;
}
// allocate and copy selected weights to gpu
#ifdef GGML_USE_CUBLAS
#ifdef GGML_USE_CUBLAS
int64_t row_len = src->ne[0];
int64_t gpu_rows = gpu_bucket->ne[0];
if (gpu_rows == 0)
@@ -2841,10 +2863,9 @@
ggml_set_no_alloc(aux_ctx, false);

return gpu_dst;
#else
printf("As you do not support CUDA. Split to GPU is not allowed.\n");
#else
return NULL;
#endif
#endif
}

void slice_ffn_mat_to_gpu(llama_layer & layer) {
@@ -2882,22 +2903,11 @@
const int64_t t_start_aug_us = ggml_time_us();
std::vector<uint8_t> work_buffer;

// transpose ffn_down to use axpy
// ggml_cgraph * tmp_transpose_gf = ggml_new_graph(aux_ctx);
// for (llama_layer &model_layer : model -> layers) {
// // gpu_w2 transpose load
// ggml_tensor * ffn_down_t = ggml_cont(aux_ctx, ggml_transpose(aux_ctx, model_layer.ffn_down));
// ggml_build_forward_expand(tmp_transpose_gf, ffn_down_t);
// model_layer.ffn_down_t = ffn_down_t;
// LLAMA_LOG_INFO(".");
// }
// ggml_graph_compute_helper(work_buffer, tmp_transpose_gf, 2);
// for (llama_layer &model_layer : model -> layers) {
// model_layer.ffn_down_t->op = GGML_OP_NONE;
// model_layer.ffn_down_t->src[0] = NULL;
// model_layer.ffn_down_t->src[1] = NULL;
// model_layer.ffn_down_t->src[2] = NULL;
// }
// Set sparsity threshold via global variables
sparse_pred_threshold = model->hparams.sparse_pred_threshold;
#if defined (GGML_USE_CUBLAS)
ggml_cuda_set_device_constants(model->hparams.sparse_pred_threshold);
#endif

// load gpu_idx and slice mat to gpu
for (llama_layer &model_layer : model -> layers) {
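Net effect in llama.cpp: the threshold starts from the LLAMA_SPARSE_PRED_THRESHOLD environment variable (default 0.0), is overridden by the powerinfer.sparse_threshold metadata when the model is a sparse derivative and the loaded GGUF carries that key, and is then propagated to the CPU kernels through the sparse_pred_threshold global and, in CUBLAS builds, to GPU constant memory via ggml_cuda_set_device_constants().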
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
numpy==1.24.4
sentencepiece==0.1.98
gguf>=0.1.0
-e ./gguf-py
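(The pinned gguf wheel is replaced by an editable install of the in-tree gguf-py, so convert.py picks up the new add_sparse_threshold() API without waiting for a package release.)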