diff --git a/Makefile b/Makefile index 05062ed3f9ca6..9a396f1e491ee 100644 --- a/Makefile +++ b/Makefile @@ -48,6 +48,12 @@ LDFLAGS = FASTCFLAGS = $(subst -O3,-Ofast,$(CFLAGS)) FASTCXXFLAGS = $(subst -O3,-Ofast,$(CXXFLAGS)) +# MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon +# MK_CFLAGS = -std=c11 -fPIC +# MK_CXXFLAGS = -std=c++17 -fPIC +# MK_NVCCFLAGS = -std=c++11 + + # these are used on windows, to build some libraries with extra old device compatibility SIMPLECFLAGS = FULLCFLAGS = @@ -69,10 +75,117 @@ OBJS_FULL += ggml-alloc.o ggml-aarch64.o ggml-quants.o unicode.o unicode-data.o OBJS_SIMPLE += ggml-alloc.o ggml-aarch64.o ggml-quants_noavx2.o unicode.o unicode-data.o sgemm_noavx2.o common.o sampling.o grammar-parser.o OBJS_FAILSAFE += ggml-alloc.o ggml-aarch64.o ggml-quants_failsafe.o unicode.o unicode-data.o sgemm_failsafe.o common.o sampling.o grammar-parser.o + #lets try enabling everything CFLAGS += -pthread -s -Wno-deprecated -Wno-deprecated-declarations -Wno-unused-variable CXXFLAGS += -pthread -s -Wno-multichar -Wno-write-strings -Wno-deprecated -Wno-deprecated-declarations -Wno-unused-variable +RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +and on macOS its availability depends on enabling Darwin extensions +similarly on DragonFly, enabling BSD extensions is necessary +# ifeq ($(UNAME_S),Darwin) + # MK_CPPFLAGS += -D_DARWIN_C_SOURCE +# endif +# ifeq ($(UNAME_S),DragonFly) + # MK_CPPFLAGS += -D__BSD_VISIBLE +# endif + +alloca is a non-standard interface that is not visible on BSDs when +POSIX conformance is specified, but not all of them provide a clean way +to enable it in such cases +# ifeq ($(UNAME_S),FreeBSD) + # MK_CPPFLAGS += -D__BSD_VISIBLE +# endif +# ifeq ($(UNAME_S),NetBSD) + # MK_CPPFLAGS += -D_NETBSD_SOURCE +# endif +# ifeq ($(UNAME_S),OpenBSD) + # MK_CPPFLAGS += -D_BSD_SOURCE +# endif + +# ifdef GGML_SCHED_MAX_COPIES + # MK_CPPFLAGS += -DGGML_SCHED_MAX_COPIES=$(GGML_SCHED_MAX_COPIES) +# endif + +# ifdef LLAMA_DEBUG + # MK_CFLAGS += -O0 -g + # MK_CXXFLAGS += -O0 -g + # MK_LDFLAGS += -g + # MK_NVCCFLAGS += -O0 -g + + # ifeq ($(UNAME_S),Linux) + # MK_CPPFLAGS += -D_GLIBCXX_ASSERTIONS + # endif +# else + # MK_CPPFLAGS += -DNDEBUG + # MK_CFLAGS += -O3 + # MK_CXXFLAGS += -O3 + # MK_NVCCFLAGS += -O3 +# endif + +# ifdef LLAMA_SANITIZE_THREAD + # MK_CFLAGS += -fsanitize=thread -g + # MK_CXXFLAGS += -fsanitize=thread -g + # MK_LDFLAGS += -fsanitize=thread -g +# endif + +# ifdef LLAMA_SANITIZE_ADDRESS + # MK_CFLAGS += -fsanitize=address -fno-omit-frame-pointer -g + # MK_CXXFLAGS += -fsanitize=address -fno-omit-frame-pointer -g + # MK_LDFLAGS += -fsanitize=address -fno-omit-frame-pointer -g +# endif + +# ifdef LLAMA_SANITIZE_UNDEFINED + # MK_CFLAGS += -fsanitize=undefined -g + # MK_CXXFLAGS += -fsanitize=undefined -g + # MK_LDFLAGS += -fsanitize=undefined -g +# endif + +# ifdef LLAMA_SERVER_VERBOSE + # MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) +# endif + +# ifdef LLAMA_SERVER_SSL + # MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT + # MK_LDFLAGS += -lssl -lcrypto +# endif + +# ifdef LLAMA_DISABLE_LOGS + # MK_CPPFLAGS += -DLOG_DISABLE_LOGS +# endif # LLAMA_DISABLE_LOGS + +warnings +# WARN_FLAGS = \ + # -Wall \ + # -Wextra \ + # -Wpedantic \ + # -Wcast-qual \ + # -Wno-unused-function + +# MK_CFLAGS += \ + # $(WARN_FLAGS) \ + # -Wshadow \ + # -Wstrict-prototypes \ + # -Wpointer-arith \ + # -Wmissing-prototypes \ + # -Werror=implicit-int \ + # -Werror=implicit-function-declaration + +# MK_CXXFLAGS += \ + # $(WARN_FLAGS) \ + # 
-Wmissing-declarations \ + # -Wmissing-noreturn + +# ifeq ($(LLAMA_FATAL_WARNINGS),1) + # MK_CFLAGS += -Werror + # MK_CXXFLAGS += -Werror +# endif + +this version of Apple ld64 is buggy +# ifneq '' '$(findstring dyld-1015.7,$(shell $(CC) $(LDFLAGS) -Wl,-v 2>&1))' + # MK_CPPFLAGS += -DHAVE_BUGGY_APPLE_LINKER +# endif + # OS specific # TODO: support Windows ifeq ($(UNAME_S),Linux) @@ -371,6 +484,74 @@ ifeq ($(OS),Windows_NT) ifdef LLAMA_HIPBLAS HIPBLAS_BUILD = $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o $@.dll $(HIPLDFLAGS) $(LDFLAGS) endif + +# endif # GGML_NO_ACCELERATE + +# ifdef GGML_MUSA + # CC := clang + # CXX := clang++ + # GGML_CUDA := 1 + # MK_CPPFLAGS += -DGGML_USE_MUSA +# endif + +# ifndef GGML_NO_OPENMP + # MK_CPPFLAGS += -DGGML_USE_OPENMP + # MK_CFLAGS += -fopenmp + # MK_CXXFLAGS += -fopenmp + # ifdef GGML_MUSA + # MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp + # MK_LDFLAGS += -L/usr/lib/llvm-10/lib + # endif # GGML_MUSA +# endif # GGML_NO_OPENMP + +# ifdef GGML_OPENBLAS + # MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas) + # MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas) + # MK_LDFLAGS += $(shell pkg-config --libs openblas) + # OBJ_GGML += ggml/src/ggml-blas.o +# endif # GGML_OPENBLAS + +# ifdef GGML_OPENBLAS64 + # MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas64) + # MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas64) + # MK_LDFLAGS += $(shell pkg-config --libs openblas64) + # OBJ_GGML += ggml/src/ggml-blas.o +# endif # GGML_OPENBLAS64 + +# ifdef GGML_BLIS + # MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis + # MK_LDFLAGS += -lblis -L/usr/local/lib + # OBJ_GGML += ggml/src/ggml-blas.o +# endif # GGML_BLIS + +# ifdef GGML_NVPL + # MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas + # MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp + # OBJ_GGML += ggml/src/ggml-blas.o +# endif # GGML_NVPL + +# OBJ_GGML += ggml/src/iqk/iqk_quantize.o +# ifndef GGML_NO_IQKMULMAT + # MK_CPPFLAGS += -DGGML_USE_IQK_MULMAT + # OBJ_GGML += ggml/src/iqk/iqk_mul_mat.o +# endif + +# ifndef GGML_NO_LLAMAFILE + # MK_CPPFLAGS += -DGGML_USE_LLAMAFILE + # OBJ_GGML += ggml/src/llamafile/sgemm.o +# endif + +# ifdef GGML_RPC + # MK_CPPFLAGS += -DGGML_USE_RPC + # OBJ_GGML += ggml/src/ggml-rpc.o +# endif # GGML_RPC + +# OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu)) +# OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu)) + +# ifdef GGML_CUDA_FA_ALL_QUANTS + # OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*.cu)) + else DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS) ifdef LLAMA_PORTABLE @@ -550,6 +731,27 @@ ggml-vulkan.o: ggml/src/ggml-vulkan.cpp ggml/include/ggml-vulkan.h ggml/src/ggml # intermediate objects llama.o: src/llama.cpp ggml/include/ggml.h ggml/include/ggml-alloc.h ggml/include/ggml-backend.h ggml/include/ggml-cuda.h ggml/include/ggml-metal.h include/llama.h otherarch/llama-util.h + +# ggml/src/iqk/iqk_quantize.o: \ + # ggml/src/iqk/iqk_quantize.cpp \ + # ggml/src/iqk/iqk_quantize.h \ + # ggml/src/ggml-quants.h ggml/src/ggml-common.h ggml/include/ggml.h ggml/src/ggml-impl.h + +# ifndef GGML_NO_IQKMULMAT +# ggml/src/iqk/iqk_mul_mat.o: \ + # ggml/src/iqk/iqk_mul_mat.cpp \ + # ggml/src/iqk/iqk_mul_mat.h \ + # 
ggml/src/iqk/iqk_quantize.h \ + # ggml/src/ggml-quants.h ggml/src/ggml-common.h ggml/include/ggml.h ggml/src/ggml-impl.h + # $(CXX) $(CXXFLAGS) -c $< -o $@ +# endif # GGML_NO_IQKMULMAT + +# ifndef GGML_NO_LLAMAFILE +# ggml/src/llamafile/sgemm.o: \ + # ggml/src/llamafile/sgemm.cpp \ + # ggml/src/llamafile/sgemm.h \ + # ggml/include/ggml.h + $(CXX) $(CXXFLAGS) -c $< -o $@ common.o: common/common.cpp common/common.h common/log.h $(CXX) $(CXXFLAGS) -c $< -o $@ diff --git a/README.md b/README.md index e9fc17e2c2780..99664af3474b2 100644 --- a/README.md +++ b/README.md @@ -379,4 +379,276 @@ when you can't use the precompiled binary directly, we provide an automated buil - [Stable Diffusion 1.5 and SDXL safetensor models](https://github.com/LostRuins/koboldcpp/wiki#can-i-generate-images-with-koboldcpp) - [LLaVA based Vision models and multimodal projectors (mmproj)](https://github.com/LostRuins/koboldcpp/wiki#what-is-llava-and-mmproj) - [Whisper models for Speech-To-Text](https://huggingface.co/koboldcpp/whisper/tree/main) - \ No newline at end of file + +####### + +IK_LLAMA.CPP README (Llama.cpp by GGerganov, this cloned version optimized by Ikawrakow) + +####### +# llama.cpp clone with better CPU performance + +[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) + +## TL;DR + +This repository is a clone of [llama.cpp](https://github.com/ggerganov/llama.cpp) with the following improvements +* Better implementation of CPU matrix multiplications (`AVX2` and `ARM_NEON`) for `fp16/fp32` and all k-, i-, and legacy `llama.cpp` quants, that leads to a significant improvement in prompt processing (PP) speed, typically in the range of 2X, but up to 4X for some quantization types. Token generation (TG) also benefits, but to a lesser extent due to TG being memory bound +* Faster CPU inference for MoE models with similar performance gains +* Implementation of the [Bitnet b1.58](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) model for the CPU (`AVX2` and `ARM_NEON`) and GPU (`CUDA` and `Metal`). This implementation is much faster than the unmerged `llama.cpp` [PR-8151](https://github.com/ggerganov/llama.cpp/pull/8151) + +If you are not already familiar with [llama.cpp](https://github.com/ggerganov/llama.cpp), it is better to start there. For those familiar with `llama.cpp`, everything here works the same as in `llama.cpp` (or at least the way `llama.cpp` worked when I last synced on Aug 12 2024). + +Note that I have published some, but not all, of the code in this repository in a series of [llamafile](https://github.com/Mozilla-Ocho/llamafile) PRs ([394](https://github.com/Mozilla-Ocho/llamafile/pull/394), [405](https://github.com/Mozilla-Ocho/llamafile/pull/405), [428](https://github.com/Mozilla-Ocho/llamafile/pull/428), [435](https://github.com/Mozilla-Ocho/llamafile/pull/435), [453](https://github.com/Mozilla-Ocho/llamafile/pull/453), and [464](https://github.com/Mozilla-Ocho/llamafile/pull/464)) + +The implementation of matrix-matrix and matrix-vector multiplications is in a single C++ source file (`iqk_mul_mat.cpp`) with just two interface functions `iqk_mul_mat` (`fp16/fp32` and quantized matrix multiplications) and `iqk_mul_mat_moe` (as `iqk_mul_mat` but meant to be used for the FFN part of a MoE model). Under the hood `iqk_mul_mat_moe` uses the same implementation as `iqk_mul_mat`, with the only difference being where results are stored in memory. Bitnet quantization related stuff is in `iqk-quantize.cpp`. + +## Why? 
+ +Mostly out of curiosity: +* Justine Tunney's `tinyBLAS`, which she contributed to `llama.cpp` in [PR 6414](https://github.com/ggerganov/llama.cpp/pull/6414), only works for `Q4_0`, `Q8_0` and `fp16/bf16` models. In the surrounding discussion about possibly extending `tinyBLAS` to k- and i-quants, she felt that k-quants are [not amenable to block-tiling](https://github.com/ggerganov/llama.cpp/pull/6840#issuecomment-2072995387), which is required to improve performance. This statement piqued my curiosity, so here we are. +* Bitnet-1.58b has been one of the [most discussed topics](https://github.com/ggerganov/llama.cpp/issues/5761#issuecomment-2198380366) in the `llama.cpp` project, so eventually I decided to see how efficiently one can implement a ternary model + +Curiosity aside, improved CPU performance may be (or may become) important in practice. According to The Register, 70% of AI inference [is done on the CPU of mobile phones](https://www.theregister.com/2024/05/30/arm_cortex_x925_ai_cores/?td=rt-3a), at least in the Android world (but I haven't come around to actually comparing performance on a phone). With ever increasing number of LLM model parameters, and with Meta's 400B model just released, the CPU may become the only viable option for people not willing (or not able to) rent/buy uber expensive GPU instances capable of running such models. Granted, one would need a pretty beefy computer to run a 400B model, and inference speed will be sluggish, but at least one will not need to spend the equivalent of a luxury apartment in the downtown of the city where I live to buy the GPU system capable of running the model. + +## Performance comparison to llama.cpp + +The results in the following tables are obtained with these parameters: +* Model is LLaMA-v3-8B for `AVX2` and LLaMA-v2-7B for `ARM_NEON` +* The `AVX2` CPU is a 16-core Ryzen-7950X +* The `ARM_NEON` CPU is M2-Max +* `tinyBLAS` is enabled in `llama.cpp` +* `llama.cpp` results are for `build: 081fe431 (3441)`, which was the current `llama.cpp` master branch when I pulled on July 23 2024. +* The projects are built without `CUDA` support, no `BLAS`, and Accelerate framework disabled + +### Prompt processing + +Here I set the number of threads to be equal to the number of (performance) cores of the CPU, so 16 threads for the Ryzen-7950X and 8 threads for the M2-Max. The following table summarizes the results. To not make the table too long, I have listed only quantized models containing predominantly one quantization type (i.e., excluded the `QX_K - Medium/Large` variants, which are typically a mix of `QX_K` and `Q(X+1)_K`, as well as `IQ2_S` and `IQ3_XS`). 
+ +The command line to generate the benchmark data is +``` +./bin/llama-bench -m $model -p 512 -n 0 -t $num_threads -ngl 0 + +| Quantization| size | backend | threads | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | +| ----------- | ---------: | ---------- | ------: | ---------------: | ---------------: | ------: | +| 8B F16 | 14.96 GiB | AVX2 | 16 | 112.37 ± 0.40 | 131.27 ± 0.38 | 1.168 | +| 7B F16 | 12.55 GiB | NEON | 8 | 90.28 ± 1.25 | 95.34 ± 0.15 | 1.056 | +| 8B Q8_0 | 7.95 GiB | AVX2 | 16 | 118.07 ± 0.53 | 134.00 ± 0.47 | 1.135 | +| 7B Q8_0 | 6.67 GiB | NEON | 8 | 77.25 ± 1.81 | 94.14 ± 1.15 | 1.219 | +| 8B Q4_0 | 4.35 GiB | AVX2 | 16 | 104.46 ± 0.33 | 130.20 ± 0.29 | 1.246 | +| 7B Q4_0 | 3.57 GiB | NEON | 8 | 65.46 ± 0.79 | 76.22 ± 0.71 | 1.164 | +| 8B Q4_1 | 4.77 GiB | AVX2 | 16 | 57.83 ± 0.24 | 160.69 ± 0.49 | 2.779 | +| 7B Q4_1 | 3.95 GiB | NEON | 8 | 37.40 ± 0.50 | 65.83 ± 0.98 | 1.760 | +| 8B Q5_0 | 5.22 GiB | AVX2 | 16 | 53.50 ± 0.35 | 122.62 ± 0.48 | 2.292 | +| 7B Q5_0 | 4.34 GiB | NEON | 8 | 29.31 ± 0.51 | 67.51 ± 1.17 | 2.303 | +| 8B Q5_1 | 5.64 GiB | AVX2 | 16 | 50.85 ± 0.36 | 147.15 ± 0.47 | 2.894 | +| 7B Q5_1 | 4.72 GiB | NEON | 8 | 26.02 ± 0.37 | 58.49 ± 0.85 | 2.248 | +| 8B Q2_K_S | 2.78 GiB | AVX2 | 16 | 110.11 ± 0.28 | 192.47 ± 1.35 | 1.748 | +| 7B Q2_K_S | 2.16 GiB | NEON | 8 | 35.44 ± 0.06 | 77.93 ± 1.64 | 2.199 | +| 8B Q3_K_S | 3.41 GiB | AVX2 | 16 | 77.42 ± 0.36 | 181.64 ± 0.44 | 2.346 | +| 7B Q3_K_S | 2.75 GiB | NEON | 8 | 26.79 ± 0.03 | 59.38 ± 1.08 | 2.216 | +| 8B Q4_K_S | 4.36 GiB | AVX2 | 16 | 98.92 ± 0.34 | 185.35 ± 0.39 | 1.874 | +| 7B Q4_K_S | 3.59 GiB | NEON | 8 | 46.55 ± 0.67 | 76.31 ± 0.38 | 1.639 | +| 8B Q5_K_S | 5.21 GiB | AVX2 | 16 | 69.44 ± 0.31 | 179.62 ± 0.69 | 2.587 | +| 7B Q5_K_S | 4.33 GiB | NEON | 8 | 30.18 ± 0.23 | 65.34 ± 0.79 | 2.165 | +| 8B Q6_K | 6.14 GiB | AVX2 | 16 | 74.89 ± 0.26 | 181.86 ± 0.55 | 2.428 | +| 7B Q6_K | 5.15 GiB | NEON | 8 | 28.12 ± 1.24 | 60.75 ± 1.15 | 2.160 | +| 8B IQ2_XXS | 2.23 GiB | AVX2 | 16 | 42.57 ± 0.16 | 126.63 ± 0.55 | 2.975 | +| 7B IQ2_XXS | 1.73 GiB | NEON | 8 | 20.87 ± 0.20 | 64.29 ± 1.12 | 3.080 | +| 8B IQ2_XS | 2.42 GiB | AVX2 | 16 | 46.45 ± 0.27 | 125.46 ± 0.43 | 2.701 | +| 7B IQ2_XS | 1.89 GiB | NEON | 8 | 22.77 ± 0.21 | 51.15 ± 0.24 | 2.246 | +| 8B IQ2_M | 2.74 GiB | AVX2 | 16 | 40.76 ± 0.18 | 113.07 ± 0.48 | 2.774 | +| 7B IQ2_M | 2.20 GiB | NEON | 8 | 14.95 ± 0.26 | 44.87 ± 0.50 | 3.001 | +| 8B IQ3_XXS | 3.04 GiB | AVX2 | 16 | 31.95 ± 0.20 | 109.86 ± 0.45 | 3.438 | +| 7B IQ3_XXS | 2.41 GiB | NEON | 8 | 14.40 ± 0.10 | 53.58 ± 0.85 | 3.721 | +| 8B IQ3_S | 3.42 GiB | AVX2 | 16 | 28.04 ± 0.08 | 96.28 ± 0.45 | 3.434 | +| 7B IQ3_S | 2.75 GiB | NEON | 8 | 12.08 ± 0.30 | 49.72 ± 0.06 | 4.116 | +| 8B IQ4_XS | 4.13 GiB | AVX2 | 16 | 68.98 ± 0.31 | 180.34 ± 0.55 | 2.614 | +| 7B IQ4_XS | 3.37 GiB | NEON | 8 | 40.67 ± 1.97 | 75.11 ± 1.97 | 1.847 | +| 8B IQ4_NL | 4.35 GiB | AVX2 | 16 | 59.94 ± 0.21 | 129.06 ± 0.43 | 2.153 | +| 7B IQ4_NL | 3.56 GiB | NEON | 8 | 34.36 ± 0.81 | 76.02 ± 1.36 | 2.212 | + +We see that `llama.cpp` achieves respectable performance for `fp16`, `Q8_0`, and `Q4_0`, being only up to 25% slower than this implementation. This is thanks to the use of Justine Tunney's `tinyBLAS`, which is utilized for these quantization types. For all other quants we observe performance gains in the `1.75X - 4X` range, which is not a small feat considering that the `ggml` matrix multiplication functions has been rewritten several times since `llama.cpp` was first published. 
Performance gains are larger for i-quants due to the higher quant unpacking cost (see discussion in "To tile or not to tile") + +### Token generation + +On the Ryzen-7950X TG is memory bound, and for many quantization types peak performance is achieved at just 4 threads. Hence, only results for 2 and 4 threads are shown for `AVX2`. The M2-Max has a much more capable memory subsystem and as a result performance keep increasing up to 8 threads. Thus, results are given for up to 8 threads for `ARM_NEON`. + +The command line to generate the data was +``` +./bin/llama-bench -m $model -p 0 -n 128 -t $num_threads -ngl 0 +``` + +| Quantization| size | backend | threads | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | +| ---------- | ---------: | ---------- | ------: | ---------------: | ---------------: | ------: | +| 8B F16 | 14.96 GiB | AVX2 | 1 | 2.20 ± 0.00 | 2.25 ± 0.00 | 1.023 | +| | | | 2 | 3.63 ± 0.00 | 3.68 ± 0.00 | 1.014 | +| | | | 4 | 4.20 ± 0.00 | 4.20 ± 0.00 | 1.000 | +| 7B F16 | 12.55 GiB | NEON | 2 | 6.94 ± 0.27 | 7.40 ± 0.01 | 1.066 | +| | | | 4 | 8.73 ± 0.01 | 8.83 ± 0.01 | 1.011 | +| | | | 6 | 9.05 ± 0.02 | 9.05 ± 0.01 | 1.000 | +| 8B Q8_0 | 7.95 GiB | AVX2 | 2 | 5.03 ± 0.00 | 7.87 ± 0.00 | 1.565 | +| | | | 4 | 7.40 ± 0.00 | 7.82 ± 0.00 | 1.057 | +| 7B Q8_0 | 6.67 GiB | NEON | 2 | 8.29 ± 0.44 | 12.07 ± 0.10 | 1.456 | +| | | | 4 | 13.53 ± 0.03 | 15.77 ± 0.08 | 1.166 | +| | | | 8 | 16.24 ± 0.10 | 16.94 ± 0.04 | 1.043 | +| 8B Q4_0 | 4.35 GiB | AVX2 | 2 | 6.36 ± 0.00 | 10.28 ± 0.00 | 1.616 | +| | | | 4 | 10.97 ± 0.06 | 13.55 ± 0.07 | 1.235 | +| 7B Q4_0 | 3.57 GiB | NEON | 2 | 9.77 ± 0.02 | 13.69 ± 0.03 | 1.401 | +| | | | 4 | 17.82 ± 0.06 | 23.98 ± 0.11 | 1.346 | +| | | | 8 | 26.63 ± 0.41 | 29.86 ± 0.04 | 1.121 | +| 8B Q4_1 | 4.77 GiB | AVX2 | 2 | 5.11 ± 0.00 | 11.45 ± 0.00 | 2.241 | +| | | | 4 | 9.08 ± 0.02 | 12.58 ± 0.00 | 1.385 | +| 7B Q4_1 | 3.95 GiB | NEON | 2 | 9.11 ± 0.06 | 14.62 ± 0.04 | 1.605 | +| | | | 4 | 17.04 ± 0.09 | 24.08 ± 0.28 | 1.413 | +| | | | 8 | 25.26 ± 0.24 | 27.23 ± 0.14 | 1.078 | +| 8B Q5_0 | 5.22 GiB | AVX2 | 2 | 5.31 ± 0.01 | 8.30 ± 0.01 | 1.563 | +| | | | 4 | 9.40 ± 0.01 | 11.47 ± 0.00 | 1.220 | +| 7B Q5_0 | 4.34 GiB | NEON | 2 | 7.26 ± 0.06 | 7.52 ± 0.00 | 1.036 | +| | | | 4 | 13.63 ± 0.18 | 14.16 ± 0.10 | 1.039 | +| | | | 8 | 22.55 ± 0.35 | 24.34 ± 0.22 | 1.079 | +| 8B Q5_1 | 5.64 GiB | AVX2 | 2 | 4.52 ± 0.00 | 8.86 ± 0.00 | 1.960 | +| | | | 4 | 7.72 ± 0.05 | 10.68 ± 0.03 | 1.383 | +| 7B Q5_1 | 4.72 GiB | NEON | 2 | 6.51 ± 0.01 | 6.42 ± 0.03 | 0.986 | +| | | | 4 | 12.26 ± 0.18 | 12.21 ± 0.14 | 0.996 | +| | | | 8 | 20.33 ± 0.52 | 21.85 ± 0.22 | 1.075 | +| 8B Q2_K_S | 2.78 GiB | AVX2 | 2 | 11.30 ± 0.00 | 13.06 ± 0.01 | 1.156 | +| | | | 4 | 18.70 ± 0.00 | 19.04 ± 0.65 | 1.014 | +| 7B Q2_K_S | 2.16 GiB | NEON | 2 | 8.42 ± 0.05 | 11.97 ± 0.10 | 1.422 | +| | | | 4 | 15.74 ± 0.01 | 22.09 ± 0.08 | 1.403 | +| | | | 8 | 27.35 ± 0.05 | 38.32 ± 0.05 | 1.401 | +| 8B Q3_K_S | 3.41 GiB | AVX2 | 2 | 8.58 ± 0.00 | 10.82 ± 0.00 | 1.261 | +| | | | 4 | 15.26 ± 0.01 | 16.25 ± 0.01 | 1.065 | +| 7B Q3_K_S | 2.75 GiB | NEON | 2 | 6.40 ± 0.02 | 9.12 ± 0.09 | 1.425 | +| | | | 4 | 12.17 ± 0.00 | 17.11 ± 0.03 | 1.406 | +| | | | 8 | 22.04 ± 0.08 | 31.39 ± 0.31 | 1.424 | +| 8B Q4_K_S | 4.36 GiB | AVX2 | 2 | 9.61 ± 0.00 | 10.72 ± 0.01 | 1.116 | +| | | | 4 | 13.24 ± 0.31 | 13.28 ± 0.01 | 1.003 | +| 7B Q4_K_S | 3.59 GiB | NEON | 2 | 11.15 ± 0.05 | 12.93 ± 0.09 | 1.160 | +| | | | 4 | 20.24 ± 0.16 | 23.49 ± 0.29 | 1.161 | +| | | | 8 | 25.76 ± 0.07 | 28.31 ± 0.22 | 1.099 | +| 8B Q5_K_S | 
5.21 GiB | AVX2 | 2 | 7.45 ± 0.00 | 9.73 ± 0.00 | 1.306 | +| | | | 4 | 11.05 ± 0.33 | 11.43 ± 0.02 | 1.034 | +| 7B Q5_K_S | 4.33 GiB | NEON | 2 | 7.20 ± 0.04 | 8.81 ± 0.04 | 1.224 | +| | | | 4 | 13.62 ± 0.15 | 16.81 ± 0.16 | 1.234 | +| | | | 8 | 20.56 ± 0.19 | 23.96 ± 0.14 | 1.165 | +| 8B Q6_K | 6.14 GiB | AVX2 | 2 | 7.53 ± 0.00 | 9.42 ± 0.00 | 1.251 | +| | | | 4 | 9.74 ± 0.00 | 9.97 ± 0.01 | 1.024 | +| 7B Q6_K | 5.15 GiB | NEON | 2 | 6.85 ± 0.04 | 8.30 ± 0.06 | 1.212 | +| | | | 4 | 13.03 ± 0.05 | 15.47 ± 0.17 | 1.187 | +| | | | 8 | 18.52 ± 0.07 | 20.67 ± 0.08 | 1.116 | +| 8B IQ2_XXS | 2.23 GiB | AVX2 | 2 | 5.33 ± 0.01 | 6.40 ± 0.00 | 1.201 | +| | | | 4 | 10.06 ± 0.03 | 11.76 ± 0.03 | 1.169 | +| 7B IQ2_XXS | 1.73 GiB | NEON | 2 | 5.07 ± 0.04 | 5.22 ± 0.05 | 1.030 | +| | | | 4 | 9.63 ± 0.00 | 9.91 ± 0.07 | 1.029 | +| | | | 8 | 17.40 ± 0.50 | 18.65 ± 0.22 | 1.072 | +| 8B IQ2_XS | 2.42 GiB | AVX2 | 2 | 5.83 ± 0.00 | 6.55 ± 0.00 | 1.123 | +| | | | 4 | 10.88 ± 0.09 | 12.07 ± 0.07 | 1.109 | +| 7B IQ2_XS | 1.89 GiB | NEON | 2 | 5.52 ± 0.01 | 5.60 ± 0.00 | 1.014 | +| | | | 4 | 10.50 ± 0.01 | 11.15 ± 0.00 | 1.062 | +| | | | 8 | 18.19 ± 1.30 | 20.94 ± 0.19 | 1.151 | +| 8B IQ2_M | 2.74 GiB | AVX2 | 2 | 5.12 ± 0.01 | 5.17 ± 0.00 | 1.010 | +| | | | 4 | 9.60 ± 0.28 | 9.68 ± 0.16 | 1.008 | +| 7B IQ2_M | 2.20 GiB | NEON | 2 | 3.73 ± 0.02 | 4.53 ± 0.00 | 1.214 | +| | | | 4 | 7.14 ± 0.05 | 8.70 ± 0.06 | 1.218 | +| | | | 8 | 11.99 ± 0.48 | 16.41 ± 0.05 | 1.369 | +| 8B IQ3_XXS | 3.04 GiB | AVX2 | 2 | 4.06 ± 0.01 | 5.00 ± 0.00 | 1.232 | +| | | | 4 | 7.75 ± 0.02 | 9.13 ± 0.45 | 1.178 | +| 7B IQ3_XXS | 2.41 GiB | NEON | 2 | 3.53 ± 0.00 | 3.82 ± 0.00 | 1.082 | +| | | | 4 | 6.74 ± 0.04 | 7.42 ± 0.07 | 1.103 | +| | | | 8 | 11.96 ± 0.40 | 13.19 ± 0.29 | 1.103 | +| 8B IQ3_S | 3.42 GiB | AVX2 | 2 | 3.62 ± 0.00 | 4.06 ± 0.00 | 1.122 | +| | | | 4 | 6.80 ± 0.01 | 7.62 ± 0.10 | 1.121 | +| 7B IQ3_S | 2.75 GiB | NEON | 2 | 2.96 ± 0.01 | 3.21 ± 0.03 | 1.084 | +| | | | 4 | 5.68 ± 0.01 | 6.25 ± 0.05 | 1.100 | +| | | | 8 | 10.32 ± 0.25 | 11.11 ± 0.37 | 1.077 | +| 8B IQ4_XS | 4.13 GiB | AVX2 | 2 | 8.08 ± 0.00 | 11.35 ± 0.00 | 1.405 | +| | | | 4 | 13.36 ± 0.72 | 14.32 ± 0.24 | 1.072 | +| 7B IQ4_XS | 3.37 GiB | NEON | 2 | 9.87 ± 0.03 | 12.06 ± 0.00 | 1.222 | +| | | | 4 | 17.78 ± 0.23 | 22.06 ± 0.28 | 1.241 | +| | | | 8 | 27.62 ± 0.09 | 29.70 ± 0.39 | 1.075 | +| 8B IQ4_NL | 4.35 GiB | AVX2 | 2 | 5.52 ± 0.00 | 10.26 ± 0.00 | 1.859 | +| | | | 4 | 10.78 ± 0.01 | 13.69 ± 0.08 | 1.270 | +| 7B IQ4_NL | 3.56 GiB | NEON | 2 | 8.32 ± 0.01 | 13.54 ± 0.01 | 1.627 | +| | | | 4 | 15.89 ± 0.00 | 24.28 ± 0.29 | 1.528 | +| | | | 8 | 26.56 ± 0.36 | 29.87 ± 0.08 | 1.125 | + +Here gains are generally lower compared to PP due to TG performance being limited by memory bandwidth. Nevertheless, for some quants/architectures/threads the speedup is quite remarkable (e.g., almost a factor of 2 for `Q5_1` on `AVX2` with 2 threads). + +## MoE models + +There is [PR-6840](https://github.com/ggerganov/llama.cpp/pull/6840) from Justine Tunney in `llama.cpp`, but it has not been merged since April 23, so I'll compare performance to the master branch for Mixtral-8x7B. 
As Mixtral-8x7B quantization is quite a lengthy process, the following table shows data only for `Q4_K_S` (a commonly used k-quant, 4 bit), `Q5_0` (a legacy quant, 5 bit), and `IQ3_XXS` (a 3-bit i-quant).
+
+| model | size | backend | threads | test | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup |
+| ------------ | ---------: | ---------- | ------: | -------: | ---------------: | ---------------: | ------: |
+| 8x7B Q4_K_S | 48.75 GiB | AVX2 | 16 | pp512 | 54.92 ± 0.23 | 102.94 ± 0.37 | 1.874 |
+| | | NEON | 8 | pp512 | 23.54 ± 1.56 | 38.32 ± 0.54 | 1.628 |
+| | | AVX2 | 4 | tg128 | 7.80 ± 0.07 | 7.83 ± 0.09 | 1.004 |
+| | | NEON | 8 | tg128 | 14.95 ± 0.25 | 15.28 ± 0.24 | 1.022 |
+| 8x7B IQ3_XXS | 33.07 GiB | AVX2 | 16 | pp512 | 17.58 ± 0.04 | 68.45 ± 0.22 | 3.894 |
+| | | NEON | 8 | pp512 | 7.75 ± 0.04 | 34.67 ± 0.40 | 4.474 |
+| | | AVX2 | 4 | tg128 | 4.60 ± 0.01 | 5.45 ± 0.09 | 1.185 |
+| | | AVX2 | 8 | tg128 | 8.04 ± 0.65 | 9.83 ± 0.06 | 1.223 |
+| | | AVX2 | 16 | tg128 | 10.42 ± 0.01 | 10.57 ± 0.01 | 1.014 |
+| | | NEON | 8 | tg128 | 6.19 ± 1.16 | 7.27 ± 0.14 | 1.174 |
+| 8x7B Q5_0 | 59.11 GiB | AVX2 | 16 | pp512 | 29.06 ± 0.43 | 62.67 ± 0.32 | 2.157 |
+| | | NEON | 8 | pp512 | 15.17 ± 0.51 | 27.36 ± 1.03 | 1.804 |
+| | | AVX2 | 4 | tg128 | 5.44 ± 0.10 | 6.81 ± 0.06 | 1.252 |
+| | | NEON | 8 | tg128 | 12.03 ± 0.77 | 12.41 ± 1.27 | 1.032 |
+
+
+## Bitnet-1.58B
+
+Two implementations are provided:
+* `IQ1_BN` - uses 1.625 bits-per-weight (bpw)
+* `IQ2_BN` - uses 2.0 bpw
+
+`IQ2_BN` is faster for PP (CPU and GPU, although the PP performance difference on CUDA is very minor). `IQ1_BN` can reach a higher TG performance on the Ryzen-7950X (given enough threads) because of the smaller model size, but it is always slower on the GPU and on the M2-Max CPU.
+
+There is an unmerged [PR 8151](https://github.com/ggerganov/llama.cpp/pull/8151) in `llama.cpp` that implements Bitnet-1.58B for the CPU (`AVX` and `ARM_NEON`, no GPU implementation). The following table compares performance between this repo and `PR-8151` in `llama.cpp`. The CUDA results were obtained on an RTX-4080, the Metal results on a 30-core M2-Max GPU.
+ +| model | size | backend | threads | test | t/s (llama.cpp) | t/s (this repo)| Speedup | +| ----------- | ---------: | ---------- | ------: | -----: | ---------------: | -------------: | ------: | +| 3B - IQ1_BN | 729.64 MiB | AVX2 | 16 | pp512 | 120.61 ± 0.48 | 423.19 ± 1.28 | 3.509 | +| | | NEON | 8 | pp512 | 46.64 ± 0.02 | 205.90 ± 0.88 | 4.415 | +| | | CUDA | 8 | pp512 | - | 10660 ± 170 | - | +| | | Metal | 8 | pp512 | - | 698.25 ± 1.91 | - | +| | | AVX2 | 2 | tg128 | 15.79 ± 0.01 | 22.13 ± 0.02 | 1.402 | +| | | AVX2 | 4 | tg128 | 28.64 ± 1.72 | 40.14 ± 0.04 | 1.402 | +| | | AVX2 | 8 | tg128 | 48.91 ± 0.08 | 61.79 ± 0.09 | 1.263 | +| | | AVX2 | 16 | tg128 | 57.73 ± 0.05 | 60.79 ± 0.05 | 1.053 | +| | | NEON | 2 | tg128 | 11.43 ± 0.04 | 16.87 ± 0.02 | 1.476 | +| | | NEON | 4 | tg128 | 21.11 ± 0.05 | 30.66 ± 0.11 | 1.452 | +| | | NEON | 8 | tg128 | 37.36 ± 0.07 | 55.21 ± 0.16 | 1.478 | +| | | CUDA | 8 | tg128 | - | 301.44 ± 0.12 | - | +| | | Metal | 8 | tg128 | - | 76.70 ± 0.07 | - | +| 3B - IQ2_BN | 873.65 MiB | AVX2 | 16 | pp512 | 151.39 ± 0.35 | 540.82 ± 2.48 | 3.572 | +| | | NEON | 8 | pp512 | 46.54 ± 0.03 | 242.05 ± 0.34 | 5.201 | +| | | CUDA | 8 | pp512 | - | 10800 ± 160 | - | +| | | Metal | 8 | pp512 | - | 723.19 ± 0.53 | - | +| | | AVX2 | 2 | tg128 | 18.93 ± 0.02 | 38.34 ± 0.08 | 2.026 | +| | | AVX2 | 4 | tg128 | 34.54 ± 0.06 | 56.29 ± 0.07 | 1.630 | +| | | AVX2 | 8 | tg128 | 52.97 ± 0.07 | 53.44 ± 0.08 | 1.009 | +| | | AVX2 | 16 | tg128 | 51.84 ± 0.25 | 53.46 ± 0.07 | 1.031 | +| | | NEON | 2 | tg128 | 11.40 ± 0.02 | 32.01 ± 0.27 | 2.808 | +| | | NEON | 4 | tg128 | 20.99 ± 0.00 | 56.45 ± 0.11 | 2.689 | +| | | NEON | 8 | tg128 | 37.28 ± 0.08 | 89.77 ± 0.70 | 2.408 | +| | | CUDA | 8 | tg128 | - | 322.10 ± 0.07 | - | +| | | Metal | 8 | tg128 | - | 110.39 ± 0.13 | - | + +We can make the following observations: +* For prompt processing this Bitnet-1.58b implementation is massively better than PR-8151 in `llama.cpp`, with gains between 3.4X and 5.2X! +* We get `PP-512 = 520 t/s` for the 2.0 bpw variant on the Ryzen-7950X, which costs less than $500. Hey, who needs a GPU? +* For low number of threads (2), this implementation is also much faster than PR-8151 for TG, where speed gains are between 1.4X and 2.8X. As we become memory bound on the Ryzen-7950X, the speed advantage goes away there for sufficiently high number of threads. But on the M2-Max this implementation is 1.4X (1.625 bpw) or 2.4X faster even at 8 threads +* Looking at TG on the M2-Max, the GPU looks a bit like wasted silicon (90 vs 110 t/s for TG-128 and the 2.0 bpw variant). If the GPU transistors had been spent to double the M2 number of CPU cores (and all memory bandwidth is given to the CPU), the CPU would be wiping the floor with the GPU. +* I'm of course kidding with the above. Still, it seems there are massive inefficiencies in the `llama.cpp` Metal implementation that start showing up when matrix multiplications become very fast as is the case here. The difference between CPU and GPU prompt processing speed is typically at least a factor of 7 in favor of the GPU on the M2-Max, but it is only around a factor of 3 here. +* It is worth noting that one needs to offload the token embeddings tensor to the GPU, else performance on CUDA/Metal is significantly lower. Bitnet uses the same tensor for token embeddings and for output. Mainline `llama.cpp` currently puts the token embeddings tensor on the CPU, and this results in running the matrix multiplication with the output tensor on the CPU. 
This most likely affects other models as well (e.g., Gemma), but I haven't yet looked into this.
+
+To reproduce these results:
+* Clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B
+* Run `python3 convert_hf_to_gguf.py --outtype f16 path_to_bitnet` to convert to GGUF
+* Run `./bin/llama-quantize path_to_bitnet/ggml-model-f16.gguf quantized.gguf [iq1_bn | iq2_bn]`. Note: no imatrix is required (and, if you provide one, it is ignored)
+* Caveat: only the 3B Bitnet variant works. The smaller Bitnet models contain tensors with a number of columns that is not even a multiple of 32, so basically no `llama.cpp` quant will work for these.
+
+## To tile or not to tile
+
+The common wisdom for efficient matrix multiplications is to use block tiling, and this is also used here for `fp16/fp32` matrices. But block tiling does not somehow magically reduce the amount of computation that needs to get done. Performance gains are simply due to the better utilization of memory caches. When dealing with quantized matrix multiplications, there is an additional factor that comes into play: the quantized data needs to be unpacked to 8-bit integers before being used in the matrix multiplication multiply-add operations. Depending on the quantization type, this unpacking can represent a significant fraction of the overall computation cost. Hence, for best performance, one would want to reuse the unpacked quants as much as possible, thus spending some fraction of the available vector registers to hold the unpacked data. But when using block tiling, one also needs a certain number of vector registers for accumulating results. For instance, on `AVX2` (16 vector registers available), for `fp16/fp32` models best performance is achieved with `2 x 6` tiles (where the `2` refers to rows in the left matrix and is measured in units of the vector register size, so 16/8 floats for `fp16/fp32`, and `6` is the number of columns in the right matrix). Unpacking quantized data works best when done in blocks of 128 or 256 quants so that, if we wanted to keep unpacked quants for 2 rows, we would need at least 8 vector registers, thus being left with fewer than 8 registers for result accumulation, so at best `2 x 3` tiles. In practice one needs additional vector registers for various constants that are typically needed for de-quantization, so that, in the end, it becomes better to use `1 x N` "tiles", i.e., a row-wise multiplication where each row in the left matrix is multiplied with `N` columns in the right matrix, thus reusing the unpacked data `N` times. This (i.e., amortizing the de-quantization cost) is the main mechanism for speeding up quantized matrix multiplications. Having started with quantized matrices, and having gone from tiles to a row-wise implementation after some experimentation, I did try row-wise multiplication for float matrices first. Performance was not quite as good as for block-tiling, but I did get up to 90-95% of the speed of `tinyBLAS` that way before switching the `fp16/fp32` implementation to `2 x 6` (`AVX2`) or `5 x 5` (`AVX512` and `ARM_NEON`) block-tiles. But even for `Q8_0 x Q8_0` multiplications, where there is basically no de-quantization cost, row-wise multiplication is faster than tiling (and hence this implementation beats `tinyBLAS`, which uses block-tiling also for `Q8_0`).
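The sketch below illustrates the `1 x N` row-wise idea in plain scalar C++: each quantized row of the left matrix is unpacked once into a scratch buffer and then reused for all `N` right-hand columns, so the unpacking loop runs once per row instead of once per (row, column) pair. It is only a toy with a made-up `Q8_0`-style block layout (`BlockQ8` and `mul_mat_1xN` are illustrative names, not the actual `iqk_mul_mat.cpp` code, which stays in packed 8-bit integers and uses SIMD).

```
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int kBlock = 32;          // quants per block (like QK8_0)

struct BlockQ8 {                    // simplified Q8_0-style block: one scale + 32 int8 quants
    float  d;
    int8_t qs[kBlock];
};

// C[i][j] = dot(row i of A, column j of B).
// A is row-major and quantized, B is column-major float.
// nCols plays the role of N in the "1 x N tile": the unpacked row is reused nCols times.
void mul_mat_1xN(const std::vector<BlockQ8>& A, int nRows, int nBlocksPerRow,
                 const std::vector<float>& B, int nCols, std::vector<float>& C) {
    const int k = nBlocksPerRow * kBlock;             // inner dimension
    std::vector<float> rowBuf(k);                     // holds one unpacked row
    for (int i = 0; i < nRows; ++i) {
        // Unpack row i once - this is the de-quantization cost we want to amortize.
        const BlockQ8* row = &A[i * nBlocksPerRow];
        for (int b = 0; b < nBlocksPerRow; ++b)
            for (int l = 0; l < kBlock; ++l)
                rowBuf[b * kBlock + l] = row[b].d * row[b].qs[l];
        // Reuse the unpacked row for all N columns of the right matrix.
        for (int j = 0; j < nCols; ++j) {
            const float* col = &B[(size_t)j * k];
            float sum = 0.0f;
            for (int l = 0; l < k; ++l) sum += rowBuf[l] * col[l];
            C[i * nCols + j] = sum;
        }
    }
}

int main() {
    const int nRows = 2, nBlocksPerRow = 1, nCols = 3, k = nBlocksPerRow * kBlock;
    std::vector<BlockQ8> A(nRows * nBlocksPerRow);
    std::vector<float>   B(nCols * k, 1.0f);
    std::vector<float>   C(nRows * nCols);
    for (int i = 0; i < nRows * nBlocksPerRow; ++i) {
        A[i].d = 0.5f;
        for (int l = 0; l < kBlock; ++l) A[i].qs[l] = int8_t(l - 16);
    }
    mul_mat_1xN(A, nRows, nBlocksPerRow, B, nCols, C);
    printf("C[0][0] = %g\n", C[0]);
    return 0;
}
```

The same structure is what makes the de-quantization cost scale as `1/N` in the row-wise scheme: larger `N` means better amortization, limited only by how many columns' accumulators fit in registers.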
diff --git a/common/common.cpp b/common/common.cpp index 4133e1f1da366..386115933cb4a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1321,6 +1321,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.process_output = true; return true; } + if (arg == "--output-tensor-name") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.output_tensor_name = argv[i]; + return true; + } if (arg == "--no-ppl") { params.compute_ppl = false; return true; diff --git a/common/common.h b/common/common.h index 9c054ed7e4e22..1df1cfcdf3afd 100644 --- a/common/common.h +++ b/common/common.h @@ -268,6 +268,7 @@ struct gpt_params { // imatrix params std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file + std::string output_tensor_name = "output.weight"; // name of the output tensor int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 550dd5cfda99f..b470a0883b4a2 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1617,23 +1617,40 @@ def weight_quant(self, weight): return weight.type(dtype), scale.type(torch.float32) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - new_name = self.map_tensor_name(name) - - if any(self.match_model_tensor_name(new_name, key, bid) for key in [ - gguf.MODEL_TENSOR.ATTN_Q, - gguf.MODEL_TENSOR.ATTN_K, - gguf.MODEL_TENSOR.ATTN_V, - gguf.MODEL_TENSOR.ATTN_OUT, - gguf.MODEL_TENSOR.FFN_UP, - gguf.MODEL_TENSOR.FFN_DOWN, - gguf.MODEL_TENSOR.FFN_GATE, - ]): - # transform weight into 1/0/-1 (in fp32) + # transform weight into 1/0/-1 (in fp32) + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): weight_torch, scale_torch = self.weight_quant(data_torch) - yield (new_name, weight_torch) - yield (new_name.removesuffix(".weight") + ".scale", scale_torch) - else: - yield (new_name, data_torch) + + tensors: list[tuple[str, Tensor]] = [] + + if name.endswith("q_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch)) + elif name.endswith("k_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch)) + elif name.endswith("v_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch)) + elif name.endswith("o_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch)) + elif name.endswith("up_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch)) + elif name.endswith("down_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch)) + 
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch)) + elif name.endswith("gate_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch)) + + if len(tensors) == 0: + tensors.append((self.map_tensor_name(name), data_torch)) + + return tensors @Model.register("GrokForCausalLM") diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 68b01126b170d..f137ada5e749e 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -1,4 +1,11 @@ #include "build-info.h" +// +// Copyright (C) 2024 Iwan Kawrakow +// Copyright (C) 2023-2024 The ggml authors +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.h" #include "llama.h" @@ -84,7 +91,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (t->op != GGML_OP_MUL_MAT) return false; // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; - if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false; + //printf("wname = %s\n", wname.c_str()); + if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == m_params.output_tensor_name))) return false; return true; } diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index c4bf027ddd1c5..6e9e9beaa9819 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -342,6 +342,10 @@ int main(int argc, char ** argv) { if (!layer_included(params, kv_tensor.first)) { continue; } + if (kv_tensor.second->ne[0] == 1 || kv_tensor.second->ne[1] == 1) { + // we never quantize those + continue; + } if (params.verbose) { printf("%s: type %s, size %" PRId64 "\n", kv_tensor.first.c_str(), ggml_type_name(kv_tensor.second->type), ggml_nelements(kv_tensor.second)); } @@ -387,6 +391,10 @@ int main(int argc, char ** argv) { if (!layer_included(params, kv_tensor.first)) { continue; } + if (kv_tensor.second->ne[0] == 1 || kv_tensor.second->ne[1] == 1) { + // we never quantize those + continue; + } if (params.verbose) { printf(" %s ...\n", kv_tensor.first.c_str()); } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index b19653b7f79df..764402b77532e 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -17,10 +17,10 @@ struct quant_option { }; static const std::vector QUANT_OPTIONS = { - { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, - { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", }, - { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", }, - { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 5.65G, +0.1062 ppl @ Llama-3-8B", }, + { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 3.56G, +0.2166 ppl @ LLaMA-v1-7B", }, + { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", }, + { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", }, + { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", }, { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", }, { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", }, { "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", }, @@ -33,24 +33,32 @@ static const std::vector QUANT_OPTIONS = { { "Q2_2", 
LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", }, + { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, + { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, + { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", }, { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", }, - { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, - { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization", }, - { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 3.41G, +1.6321 ppl @ Llama-3-8B", }, - { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.74G, +0.6569 ppl @ Llama-3-8B", }, - { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 4.03G, +0.5562 ppl @ Llama-3-8B", }, + { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, + { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization" , }, + { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", }, + { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", }, + { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, - { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, - { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 4.37G, +0.2689 ppl @ Llama-3-8B", }, - { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 4.58G, +0.1754 ppl @ Llama-3-8B", }, - { "Q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", }, - { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 5.21G, +0.1049 ppl @ Llama-3-8B", }, - { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", }, - { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", }, - { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", }, + { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, + { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, + { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, + { "IQ5_K", LLAMA_FTYPE_MOSTLY_IQ5_K, " 5.5 bpw non-linear quantization", }, + { "IQ6_K", LLAMA_FTYPE_MOSTLY_IQ6_K, " 6.6 bpw non-linear quantization", }, + { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, + { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", }, + { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", }, + { "Q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", }, + { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", }, + { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", }, + { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, + { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, @@ -316,6 +324,8 @@ int main(int argc, char ** argv) { for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], 
"--leave-output-tensor") == 0) { params.quantize_output_tensor = false; + } else if (strcmp(argv[arg_idx], "--ignore-imatrix-rules") == 0) { + params.ignore_imatrix_rules = true; } else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) { if (arg_idx < argc-1) { params.output_tensor_type = parse_ggml_type(argv[++arg_idx]); @@ -470,11 +480,12 @@ int main(int argc, char ** argv) { } } - if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || + if (!params.ignore_imatrix_rules && imatrix_data.empty() && + (params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || - params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) { + params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M)) { fprintf(stderr, "\n==========================================================================================================\n"); fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); fprintf(stderr, "==========================================================================================================\n\n\n"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a89887232a370..e2c5b277b4692 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -401,6 +401,15 @@ extern "C" { GGML_TYPE_TQ2_0 = 35, GGML_TYPE_Q2_2 = 36, GGML_TYPE_Q1_3 = 37, + GGML_TYPE_IQ1_BN = 134, + GGML_TYPE_IQ2_BN = 135, + GGML_TYPE_Q8_K64 = 136, + GGML_TYPE_IQ2_K = 137, + GGML_TYPE_IQ3_K = 138, + GGML_TYPE_IQ4_K = 139, + GGML_TYPE_IQ5_K = 140, + GGML_TYPE_IQ6_K = 141, + GGML_TYPE_IQ2_TN = 142, GGML_TYPE_COUNT, }; @@ -445,6 +454,14 @@ extern "C" { GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_K = 30, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ6_K = 34, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_TN = 35, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 3be4dd4ca783d..73cf6d2e2f105 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #ifndef GGML_COMMON_DECL #if defined(GGML_COMMON_DECL_C) @@ -136,9 +143,18 @@ typedef sycl::half2 ggml_half2; #define QI4_XS (QK_K / (4*QR4_XS)) #define QR4_XS 2 +#define QI5_XS (QK_K / (4*QR5_XS)) +#define QR5_XS 2 + +#define QI6_XS (QK_K / (4*QR6_XS)) +#define QR6_XS 2 + #define QI3_S (QK_K / (4*QR3_S)) #define QR3_S 4 +#define QI1_BN (QK_IQ1BN / (4*QR1_BN)) +#define QR1_BN 8 + #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP // 1.625 bpw for BitNet b1.58 models @@ -217,6 +233,17 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); +typedef struct { + ggml_half d[8]; + int8_t qs[4*QK8_1]; +} block_q8_1_x4; +static_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), "wrong q8_1_x4 block 
size/padding"); +typedef struct { + ggml_half d[4]; + int8_t qs[4*QK8_0]; +} block_q8_0_x4; +static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding"); + typedef struct { ggml_half d[4]; // deltas for 4 q4_0 blocks uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks @@ -347,6 +374,16 @@ typedef struct { int16_t bsums[QK_K/16]; // sum of quants in groups of 16 } block_q8_K; static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); +typedef struct { + float d; // delta + int8_t qs[64]; // quants +} block_q8_K64; +static_assert(sizeof(block_q8_K64) == sizeof(float) + 64, "wrong q8_K64 block size/padding"); +typedef struct { + float d; // delta + int8_t qs[128]; // quants +} block_q8_K128; +static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding"); // (Almost) "true" 2-bit quantization. // Due to the need to use blocks as per ggml design, it ends up using @@ -410,6 +447,34 @@ typedef struct { } block_iq1_m; static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); +// +// Bitnet - implemented as 1.625 bpw +// The block scale is a waste, but it allows us to plug it in without any additional +// changes to ggml. +// +#define QK_IQ1BN 64 +typedef struct { + uint8_t ql[12]; + uint8_t extra; +} block_iq1_bn; +static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding"); +// +// Bitnet - implemented as 2.0 bpw +// +#define QK_IQ2BN 64 +typedef struct { + uint8_t qs[QK_IQ2BN/4]; +} block_iq2_bn; +static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding"); +// +// TriLM - implemented as 2.0625 bpw +// +typedef struct { + ggml_half d; + uint8_t qs[QK_K/4]; +} block_iq2_tn; +static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding"); + // Used by IQ1_M quants typedef union { ggml_half f16; @@ -432,6 +497,53 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +typedef struct { + ggml_half d; + uint16_t extra; + uint8_t scales[QK_K/32]; + uint8_t qs[QK_K/4]; +} block_iq2_k; +static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding"); + +typedef struct { + ggml_half d; + uint16_t extra; + uint16_t scales_h; + uint8_t scales_l[QK_K/32]; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/8]; +} block_iq3_k; +static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding"); + +typedef struct { + ggml_half d; + uint16_t extra; + uint8_t scales_h[QK_K/64]; + uint8_t scales_l[QK_K/32]; + uint8_t qs[QK_K/2]; +} block_iq4_k; +static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding"); + +typedef struct { + ggml_half d; + uint16_t extra; + uint8_t scales_h[QK_K/64]; + uint8_t scales_l[QK_K/32]; + uint8_t qs[QK_K/2]; + uint8_t qh[QK_K/8]; +} block_iq5_k; +static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding"); + +typedef struct { + ggml_half d; + uint16_t extra; + int8_t scales[QK_K/16]; + uint8_t qs[QK_K/2]; + uint8_t qh[QK_K/4]; +} block_iq6_k; +static_assert(sizeof(block_iq6_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/4 + QK_K/16, "wrong 
iq6_k block size/padding"); + + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL @@ -1898,5 +2010,35 @@ GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S) GGML_TABLE_END() #endif +GGML_TABLE_BEGIN(int8_t, iq2nl_values, 8) + -31, -13, 1, 17, -26, -8, 6, 22 +GGML_TABLE_END() + +GGML_TABLE_BEGIN(int8_t, iq3nl_values, 16) + -63, -40, -23, -10, 1, 13, 28, 47, + -59, -36, -19, -6, 5, 17, 32, 51, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(int8_t, iq4k_values, 32) + -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, + -123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117 +GGML_TABLE_END() + +GGML_TABLE_BEGIN(int8_t, iq5nl_values, 64) + -126, -114, -103, -92, -83, -74, -65, -57, -50, -43, -36, -30, -24, -18, -12, -6, -1, 5, 11, 17, 23, 29, 36, 43, 51, 59, 68, 77, 87, 97, 109, 121, + -124, -112, -101, -90, -81, -72, -63, -55, -48, -41, -34, -28, -22, -16, -10, -4, 1, 7, 13, 19, 25, 31, 38, 45, 53, 61, 70, 79, 89, 99, 111, 123, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(int8_t, iq6nl_values, 128) + -127, -121, -115, -109, -104, -98, -93, -88, -84, -79, -74, -70, -66, -62, -58, -54, + -51, -47, -44, -40, -37, -34, -31, -28, -25, -22, -19, -16, -13, -11, -8, -5, + -2, 0, 3, 6, 9, 12, 14, 17, 20, 23, 27, 30, 33, 36, 40, 44, + 47, 51, 55, 59, 63, 68, 72, 77, 82, 87, 92, 98, 103, 109, 115, 121, + -126, -120, -114, -108, -103, -97, -92, -87, -83, -78, -73, -69, -65, -61, -57, -53, + -50, -46, -43, -39, -36, -33, -30, -27, -24, -21, -18, -15, -12, -10, -7, -4, + -1, 1, 4, 7, 10, 13, 15, 18, 21, 24, 28, 31, 34, 37, 41, 45, + 48, 52, 56, 60, 64, 69, 73, 78, 83, 88, 93, 99, 104, 110, 116, 122, +GGML_TABLE_END() + #endif // GGML_COMMON_IMPL #endif // GGML_COMMON_IMPL diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index fe5212a20af0a..55e852afa3bce 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2778,6 +2778,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: return true; default: return false; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 34bc67acdd890..62d115f1ead46 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -271,7 +271,43 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } +static __global__ void scale_f32_l(const float * x, float * dst, const void * data, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float * scale = (const float *)data; + dst[i] = scale[0] * x[i]; +} + +static void scale_f32_cuda_l(const float * x, float * dst, const void * data, const int k, cudaStream_t stream) { + constexpr int CUDA_SCALE_BLOCK_SIZE = 512; //256; + const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32_l<<>>(x, dst, data, k); +} + +void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + 
GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float scale; + memcpy(&scale, dst->src[1]->data, sizeof(float)); + + scale_f32_cuda_l(src0_d, dst_d, dst->src[1]->data, ggml_nelements(src0), stream); +} + void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + if (ggml_nelements(dst->src[1]) == 1 && dst->src[1]->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32) { + ggml_cuda_op_scale_tensor(ctx, dst); + return; + } ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eb39b6d23a6b3..9aff6c135b83c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -332,6 +332,12 @@ static __device__ __forceinline__ float get_alibi_slope( return powf(base, exph); } +static __device__ __forceinline__ float iq1bn_fp8_to_float(uint8_t fp8) { + typedef union { float f; uint32_t i; } scale_t; + scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); + return s.f; +} + template struct ggml_cuda_type_traits; @@ -453,6 +459,27 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI1_M; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_IQ1BN; + static constexpr int qr = QR1_BN; + static constexpr int qi = QI1_BN; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_IQ1BN; + static constexpr int qr = QR1_BN; + static constexpr int qi = QI1_BN; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK4_NL; @@ -467,6 +494,41 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_XS; + static constexpr int qi = QI5_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_XS; + static constexpr int qi = QI6_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c0a4447075c6e..70305404c303e 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "convert.cuh" #include "dequantize.cuh" @@ -146,6 +153,26 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); } +template +static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq2_tn * x = (const block_iq2_tn *) vx; + + const int64_t tid = threadIdx.x; + const int64_t n = 
tid/32; + const int64_t l = tid - 32*n; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float d = __half2float(x[i].d); + y[l+ 0] = d * ((q >> 0) & 3) - d; + y[l+32] = d * ((q >> 2) & 3) - d; + y[l+64] = d * ((q >> 4) & 3) - d; + y[l+96] = d * ((q >> 6) & 3) - d; +} + template static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -419,6 +446,66 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ } } +template +static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) { + + const int64_t ii = blockIdx.x; + const block_iq1_bn * x = (const block_iq1_bn *) vx; + + static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + +//#define COMPUTE_VS(v) 3*v >> 8 +#define COMPUTE_VS(v) (v + (v >> 1)) >> 7 + + const int tid = threadIdx.x; + const int il = tid/4; // 0...7 + const int ib = tid%4; // 0...3 + dst_t * y = yy + ii*QK_K + 64*ib + 8*il; + int64_t i = QK_K/QK_IQ1BN * ii + ib; + if (i >= nb64) return; + const int i16 = il/2; + uint8_t q = x[i].ql[3*i16+2*(il%2)]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = COMPUTE_VS(v); + y[2*(il%2)+j] = vs - 1; + } + q = x[i].ql[3*i16+1]; + for (int j = 0; j < 2; ++j) { + uint8_t v = k_mult[3*(il%2)+j]*q; + int8_t vs = COMPUTE_VS(v); + y[5*(1-(il%2))+j] = vs-1; + } + uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q; + int8_t vs = COMPUTE_VS(v); + y[7] = vs - 1; + +#undef COMPUTE_VS +} + +template +static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) { + + const int64_t ii = blockIdx.x; + const block_iq2_bn * x = (const block_iq2_bn *) vx; + + const int64_t tid = threadIdx.x; + int64_t ib64 = tid%4; // 0...3 + int64_t il = tid/4; // 0...7 + dst_t * y = yy + 256*ii + 64*ib64 + 2*il; + int64_t i = 256/QK_IQ1BN * ii + ib64; + if (i >= nb64) return; + const float m = -1; + auto qs = x[i].qs + 2*il; + for (int j = 0; j < 2; ++j) { + y[j+ 0] = ((qs[j] >> 0) & 3) + m; + y[j+16] = ((qs[j] >> 2) & 3) + m; + y[j+32] = ((qs[j] >> 4) & 3) + m; + y[j+48] = ((qs[j] >> 6) & 3) + m; + } +} + + template static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -454,6 +541,139 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst } } +template +static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const int64_t i = blockIdx.x; + const block_iq4_k * x = (const block_iq4_k *)vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = (float)x[i].d; + const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2); + const float d1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32); + const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); + const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1); + const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d1 * values1[q4[j] & 0xf]; + y[j+16] = d2 * values2[q4[j] >> 4]; + } +} + +template +static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq5_k * x = (const block_iq5_k *) vx; + + const int tid = 
threadIdx.x; + int ib64 = tid/8; // 0...3 + int il = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 64*ib64 + 2*il; + const float d = (float)x[i].d; + const float dl1 = d * (((x[i].scales_l[2*ib64+0] & 0xf) | ((x[i].scales_h[ib64] << 4) & 0x30)) - 32); + const float dl2 = d * (((x[i].scales_l[2*ib64+0] >> 4) | ((x[i].scales_h[ib64] << 2) & 0x30)) - 32); + const float dl3 = d * (((x[i].scales_l[2*ib64+1] & 0xf) | ((x[i].scales_h[ib64] >> 0) & 0x30)) - 32); + const float dl4 = d * (((x[i].scales_l[2*ib64+1] >> 4) | ((x[i].scales_h[ib64] >> 2) & 0x30)) - 32); + const uint8_t * qs = x[i].qs + 32*ib64 + 2*il; + const uint8_t * qh = x[i].qh + 2*il; + const uint8_t extra = x[i].extra >> 4*(ib64%4); + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; + y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; + y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; + y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + } +} + +template +static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq6_k * x = (const block_iq6_k *) vx; + + const int tid = threadIdx.x; + int ib64 = tid/8; // 0...3 + int il = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 64*ib64 + 2*il; + const float d = (float)x[i].d; + const float dl1 = d * x[i].scales[4*ib64+0]; + const float dl2 = d * x[i].scales[4*ib64+1]; + const float dl3 = d * x[i].scales[4*ib64+2]; + const float dl4 = d * x[i].scales[4*ib64+3]; + const uint8_t * qs = x[i].qs + 32*ib64 + 2*il; + const uint8_t * qh = x[i].qh + 32*(ib64/2) + 2*il; + const uint8_t extra = x[i].extra >> 4*(ib64%4); + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 4*(ib64%2), h2 = qh[j+16] >> 4*(ib64%2); + uint8_t q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4); + uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4); + uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2); + uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2); + y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); + y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); + y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); + y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 
1 : 0)); + } +} + +template +static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_k * x = (const block_iq2_k *) vx; + + const int tid = threadIdx.x; + int ib128 = tid/16; // 0 or 1 + int il = tid%16; // 0...15 + dst_t * y = yy + i*QK_K + 128*ib128 + 2*il; + const float d = (float)x[i].d * 1.025f; //1.0325f; + const float dl1 = d * (2*((x[i].scales[4*ib128+0] >> 4*(il/8)) & 0xf) - 15); + const float dl2 = d * (2*((x[i].scales[4*ib128+1] >> 4*(il/8)) & 0xf) - 15); + const float dl3 = d * (2*((x[i].scales[4*ib128+2] >> 4*(il/8)) & 0xf) - 15); + const float dl4 = d * (2*((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 15); + const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; + const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + } +} + +template +static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq3_k * x = (const block_iq3_k *) vx; + + const int tid = threadIdx.x; + int ib128 = tid/16; // 0 or 1 + int il = tid%16; // 0...15 + dst_t * y = yy + i*QK_K + 128*ib128 + 2*il; + const float d = (float)x[i].d * 1.01f; //1.0125f; + const uint16_t sh = x[i].scales_h >> (8*ib128 + (il/8)); + const float dl1 = d * ((2*((x[i].scales_l[4*ib128+0] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x01) ? -1 : 1)); + const float dl2 = d * ((2*((x[i].scales_l[4*ib128+1] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x04) ? -1 : 1)); + const float dl3 = d * ((2*((x[i].scales_l[4*ib128+2] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x10) ? -1 : 1)); + const float dl4 = d * ((2*((x[i].scales_l[4*ib128+3] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x40) ? 
-1 : 1)); + const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; + const uint8_t * qh = x[i].qh + 2*il; + const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; + y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; + y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; + y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + } +} + template static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); @@ -477,6 +697,12 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k dequantize_block_q2_K<<>>(vx, y); } +template +static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_tn<<>>(vx, y); +} + template static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; @@ -563,12 +789,56 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq1_m<<>>(vx, y); } +template +static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb64 = k / QK_IQ1BN; + const int nb = (k + 255) / 256; + dequantize_block_iq1_bn<<>>(vx, y, nb64); +} + +template +static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb64 = k / QK_IQ1BN; + const int nb = (k + 255) / 256; + dequantize_block_iq2_bn<<>>(vx, y, nb64); +} + template static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq4_xs<<>>(vx, y); } +template +static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq2_k<<>>(vx, y); +} + +template +static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq3_k<<>>(vx, y); +} + +template +static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_k<<>>(vx, y); +} + +template +static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq5_k<<>>(vx, y); +} + +template +static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq6_k<<>>(vx, y); +} + template static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -605,6 +875,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; + case GGML_TYPE_IQ2_TN: + return dequantize_row_iq2_tn_cuda; case GGML_TYPE_Q3_K: 
return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: @@ -625,10 +897,24 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ1_M: return dequantize_row_iq1_m_cuda; + case GGML_TYPE_IQ1_BN: + return dequantize_row_iq1_bn_cuda; + case GGML_TYPE_IQ2_BN: + return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; + case GGML_TYPE_IQ5_K: + return dequantize_row_iq5_k_cuda; + case GGML_TYPE_IQ6_K: + return dequantize_row_iq6_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -652,6 +938,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; + case GGML_TYPE_IQ2_TN: + return dequantize_row_iq2_tn_cuda; case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: @@ -672,10 +960,24 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ1_M: return dequantize_row_iq1_m_cuda; + case GGML_TYPE_IQ1_BN: + return dequantize_row_iq1_bn_cuda; + case GGML_TYPE_IQ2_BN: + return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; + case GGML_TYPE_IQ5_K: + return dequantize_row_iq5_k_cuda; + case GGML_TYPE_IQ6_K: + return dequantize_row_iq6_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu new file mode 100644 index 0000000000000..c567ad1ae6e07 --- /dev/null +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -0,0 +1,585 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#include "iqk_mmvq.cuh" + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); + +namespace { +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__global__ void iqk_mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int nwarps = 1; + constexpr int rows_per_cuda_block = 1; +#else + constexpr int nwarps = ncols_y <= 4 ? 4 : 2; + constexpr int rows_per_cuda_block = ncols_y == 1 ? 
1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; + +// partial sum for each thread + float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; + + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx + + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE]; + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + // sum up partial sums and write back result +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + } + tmp[j][i] = warp_reduce_sum(tmp[j][i]); + } + + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + } + } +} + +template +void iqk_mul_mat_vec_q_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); + //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); + + int id = ggml_cuda_get_device(); + + int64_t nwarps = 1; + int64_t rows_per_cuda_block = 1; + + if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + switch(ncols_y) { + case 1: + nwarps = 4; + rows_per_cuda_block = 1; + break; + case 2: + case 3: + case 4: + nwarps = 4; + rows_per_cuda_block = 2; + break; + case 5: + case 6: + case 7: + case 8: + nwarps = 2; + rows_per_cuda_block = 2; + break; + default: + GGML_ASSERT(false); + break; + } + } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; + const dim3 block_nums(nblocks, 1, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + switch (ncols_y) { + case 1: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 2: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 3: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 4: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 5: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 6: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 7: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + 
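+            // (each ncols_y value dispatches to its own template instantiation of
+            //  iqk_mul_mat_vec_q, so the per-column accumulator array and the unrolled
+            //  j-loops over columns have compile-time bounds)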
break; + case 8: + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + default: + GGML_ASSERT(false); + break; + } +} + +__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + const uint8_t * values = all_values + 16*(shift & 1); + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + values = all_values + 8*(shift & 2); + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +#define VDR_IQ4_K_Q8_1_MMVQ 4 +#define VDR_IQ4_K_Q8_1_MMQ 4 + +__device__ __forceinline__ float vec_dot_iq4_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq4k_values; + + // iqs is 0...28 + const int ib32 = iqs/4; + // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; + const uint16_t extra = bq4->extra >> 2*ib32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); + get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); + sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); + } + const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); + const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); + const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; + return d * (sumi1 * ls1 + sumi2 * ls2); +} + +#define VDR_IQ5_K_Q8_1_MMVQ 4 +#define VDR_IQ5_K_Q8_1_MMQ 4 + +__device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t * values) { + uint16_t v1 = values[a8[0]] | (values[a8[1]] << 8); + uint16_t v2 = values[a8[2]] | (values[a8[3]] << 8); + return v1 | (v2 << 16); +} + +__device__ __forceinline__ float vec_dot_iq5_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + + const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq5nl_values; + + int i4 = iqs/4; // 0...7. 
Blocks of 16 index is 4*(i4/2) + (i4%2) + (0 and 2) + + const int32_t * q8_1 = (const int *)bq8_1[2*(i4/2)+0].qs + 4*(i4%2); + const int32_t * q8_2 = (const int *)bq8_1[2*(i4/2)+1].qs + 4*(i4%2); + const uint32_t * q4 = (const uint32_t *)bq5->qs + 8*(i4/2) + 4*(i4%2); + const uint32_t * qh = (const uint32_t *)bq5->qh + 4*(i4%2); + const uint16_t extra = bq5->extra >> (4*(i4/2) + (i4%2)); + const uint8_t * values1 = all_values + 32*(extra & 1); + const uint8_t * values2 = all_values + 8*(extra & 4); + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)aux32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + uint32_t h = qh[j] >> 2*(i4/2); + aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + v1 = int_from_table(a8+0, values1); + v2 = int_from_table(a8+4, values2); + sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2); + } + const float d5 = __half2float(bq5->d); + const uint8_t sh = bq5->scales_h[i4/2] >> 2*(i4%2); + const int ls1 = (((bq5->scales_l[2*(i4/2)+0] >> 4*(i4%2)) & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = (((bq5->scales_l[2*(i4/2)+1] >> 4*(i4%2)) & 0xf) | ((sh << 0) & 0x30)) - 32; + return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2); +} + +#define VDR_IQ6_K_Q8_1_MMVQ 4 +#define VDR_IQ6_K_Q8_1_MMQ 4 + +__device__ __forceinline__ float vec_dot_iq6_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + + const block_iq6_k * bq6 = (const block_iq6_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq6nl_values; + + int i4 = iqs/4; // 0...7. Blocks of 16 index is 4*(i4/2) + (i4%2) + (0 and 2) + // Blocks of 32 index is 2*(i4/2) + 0 or 1 + + const int32_t * q8_1 = (const int *)bq8_1[2*(i4/2)+0].qs + 4*(i4%2); + const int32_t * q8_2 = (const int *)bq8_1[2*(i4/2)+1].qs + 4*(i4%2); + const uint32_t * q4 = (const uint32_t *)bq6->qs + 8*(i4/2) + 4*(i4%2); + const uint32_t * qh = (const uint32_t *)bq6->qh + 8*(i4/4) + 4*(i4%2); + const uint16_t extra = bq6->extra >> (4*(i4/2) + (i4%2)); + const uint8_t * values1 = all_values + 64*(extra & 1); + const uint8_t * values2 = all_values + 16*(extra & 4); + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)aux32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + uint32_t h = qh[j] >> 4*((i4/2)%2); + aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x30303030); + aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 2) & 0x30303030); + v1 = int_from_table(a8+0, values1); + v2 = int_from_table(a8+4, values2); + sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2); + } + const float d6 = __half2float(bq6->d); + return d6 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * bq6->scales[4*(i4/2)+(i4%2)] + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * bq6->scales[4*(i4/2)+(i4%2)+2]); +} + +static const __device__ uint32_t iq2k_table[512] = { + 0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311, + 0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111, + 0xe1f3e1e1, 0xe1f3e1f3, 0xe1f3e101, 0xe1f3e111, 0xe1f3f3e1, 0xe1f3f3f3, 0xe1f3f301, 0xe1f3f311, + 0xe1f301e1, 0xe1f301f3, 0xe1f30101, 0xe1f30111, 0xe1f311e1, 0xe1f311f3, 0xe1f31101, 0xe1f31111, + 0xe101e1e1, 0xe101e1f3, 0xe101e101, 0xe101e111, 0xe101f3e1, 
0xe101f3f3, 0xe101f301, 0xe101f311, + 0xe10101e1, 0xe10101f3, 0xe1010101, 0xe1010111, 0xe10111e1, 0xe10111f3, 0xe1011101, 0xe1011111, + 0xe111e1e1, 0xe111e1f3, 0xe111e101, 0xe111e111, 0xe111f3e1, 0xe111f3f3, 0xe111f301, 0xe111f311, + 0xe11101e1, 0xe11101f3, 0xe1110101, 0xe1110111, 0xe11111e1, 0xe11111f3, 0xe1111101, 0xe1111111, + 0xf3e1e1e1, 0xf3e1e1f3, 0xf3e1e101, 0xf3e1e111, 0xf3e1f3e1, 0xf3e1f3f3, 0xf3e1f301, 0xf3e1f311, + 0xf3e101e1, 0xf3e101f3, 0xf3e10101, 0xf3e10111, 0xf3e111e1, 0xf3e111f3, 0xf3e11101, 0xf3e11111, + 0xf3f3e1e1, 0xf3f3e1f3, 0xf3f3e101, 0xf3f3e111, 0xf3f3f3e1, 0xf3f3f3f3, 0xf3f3f301, 0xf3f3f311, + 0xf3f301e1, 0xf3f301f3, 0xf3f30101, 0xf3f30111, 0xf3f311e1, 0xf3f311f3, 0xf3f31101, 0xf3f31111, + 0xf301e1e1, 0xf301e1f3, 0xf301e101, 0xf301e111, 0xf301f3e1, 0xf301f3f3, 0xf301f301, 0xf301f311, + 0xf30101e1, 0xf30101f3, 0xf3010101, 0xf3010111, 0xf30111e1, 0xf30111f3, 0xf3011101, 0xf3011111, + 0xf311e1e1, 0xf311e1f3, 0xf311e101, 0xf311e111, 0xf311f3e1, 0xf311f3f3, 0xf311f301, 0xf311f311, + 0xf31101e1, 0xf31101f3, 0xf3110101, 0xf3110111, 0xf31111e1, 0xf31111f3, 0xf3111101, 0xf3111111, + 0x01e1e1e1, 0x01e1e1f3, 0x01e1e101, 0x01e1e111, 0x01e1f3e1, 0x01e1f3f3, 0x01e1f301, 0x01e1f311, + 0x01e101e1, 0x01e101f3, 0x01e10101, 0x01e10111, 0x01e111e1, 0x01e111f3, 0x01e11101, 0x01e11111, + 0x01f3e1e1, 0x01f3e1f3, 0x01f3e101, 0x01f3e111, 0x01f3f3e1, 0x01f3f3f3, 0x01f3f301, 0x01f3f311, + 0x01f301e1, 0x01f301f3, 0x01f30101, 0x01f30111, 0x01f311e1, 0x01f311f3, 0x01f31101, 0x01f31111, + 0x0101e1e1, 0x0101e1f3, 0x0101e101, 0x0101e111, 0x0101f3e1, 0x0101f3f3, 0x0101f301, 0x0101f311, + 0x010101e1, 0x010101f3, 0x01010101, 0x01010111, 0x010111e1, 0x010111f3, 0x01011101, 0x01011111, + 0x0111e1e1, 0x0111e1f3, 0x0111e101, 0x0111e111, 0x0111f3e1, 0x0111f3f3, 0x0111f301, 0x0111f311, + 0x011101e1, 0x011101f3, 0x01110101, 0x01110111, 0x011111e1, 0x011111f3, 0x01111101, 0x01111111, + 0x11e1e1e1, 0x11e1e1f3, 0x11e1e101, 0x11e1e111, 0x11e1f3e1, 0x11e1f3f3, 0x11e1f301, 0x11e1f311, + 0x11e101e1, 0x11e101f3, 0x11e10101, 0x11e10111, 0x11e111e1, 0x11e111f3, 0x11e11101, 0x11e11111, + 0x11f3e1e1, 0x11f3e1f3, 0x11f3e101, 0x11f3e111, 0x11f3f3e1, 0x11f3f3f3, 0x11f3f301, 0x11f3f311, + 0x11f301e1, 0x11f301f3, 0x11f30101, 0x11f30111, 0x11f311e1, 0x11f311f3, 0x11f31101, 0x11f31111, + 0x1101e1e1, 0x1101e1f3, 0x1101e101, 0x1101e111, 0x1101f3e1, 0x1101f3f3, 0x1101f301, 0x1101f311, + 0x110101e1, 0x110101f3, 0x11010101, 0x11010111, 0x110111e1, 0x110111f3, 0x11011101, 0x11011111, + 0x1111e1e1, 0x1111e1f3, 0x1111e101, 0x1111e111, 0x1111f3e1, 0x1111f3f3, 0x1111f301, 0x1111f311, + 0x111101e1, 0x111101f3, 0x11110101, 0x11110111, 0x111111e1, 0x111111f3, 0x11111101, 0x11111111, + 0xe6e6e6e6, 0xe6e6e6f8, 0xe6e6e606, 0xe6e6e616, 0xe6e6f8e6, 0xe6e6f8f8, 0xe6e6f806, 0xe6e6f816, + 0xe6e606e6, 0xe6e606f8, 0xe6e60606, 0xe6e60616, 0xe6e616e6, 0xe6e616f8, 0xe6e61606, 0xe6e61616, + 0xe6f8e6e6, 0xe6f8e6f8, 0xe6f8e606, 0xe6f8e616, 0xe6f8f8e6, 0xe6f8f8f8, 0xe6f8f806, 0xe6f8f816, + 0xe6f806e6, 0xe6f806f8, 0xe6f80606, 0xe6f80616, 0xe6f816e6, 0xe6f816f8, 0xe6f81606, 0xe6f81616, + 0xe606e6e6, 0xe606e6f8, 0xe606e606, 0xe606e616, 0xe606f8e6, 0xe606f8f8, 0xe606f806, 0xe606f816, + 0xe60606e6, 0xe60606f8, 0xe6060606, 0xe6060616, 0xe60616e6, 0xe60616f8, 0xe6061606, 0xe6061616, + 0xe616e6e6, 0xe616e6f8, 0xe616e606, 0xe616e616, 0xe616f8e6, 0xe616f8f8, 0xe616f806, 0xe616f816, + 0xe61606e6, 0xe61606f8, 0xe6160606, 0xe6160616, 0xe61616e6, 0xe61616f8, 0xe6161606, 0xe6161616, + 0xf8e6e6e6, 0xf8e6e6f8, 0xf8e6e606, 0xf8e6e616, 0xf8e6f8e6, 0xf8e6f8f8, 0xf8e6f806, 
0xf8e6f816, + 0xf8e606e6, 0xf8e606f8, 0xf8e60606, 0xf8e60616, 0xf8e616e6, 0xf8e616f8, 0xf8e61606, 0xf8e61616, + 0xf8f8e6e6, 0xf8f8e6f8, 0xf8f8e606, 0xf8f8e616, 0xf8f8f8e6, 0xf8f8f8f8, 0xf8f8f806, 0xf8f8f816, + 0xf8f806e6, 0xf8f806f8, 0xf8f80606, 0xf8f80616, 0xf8f816e6, 0xf8f816f8, 0xf8f81606, 0xf8f81616, + 0xf806e6e6, 0xf806e6f8, 0xf806e606, 0xf806e616, 0xf806f8e6, 0xf806f8f8, 0xf806f806, 0xf806f816, + 0xf80606e6, 0xf80606f8, 0xf8060606, 0xf8060616, 0xf80616e6, 0xf80616f8, 0xf8061606, 0xf8061616, + 0xf816e6e6, 0xf816e6f8, 0xf816e606, 0xf816e616, 0xf816f8e6, 0xf816f8f8, 0xf816f806, 0xf816f816, + 0xf81606e6, 0xf81606f8, 0xf8160606, 0xf8160616, 0xf81616e6, 0xf81616f8, 0xf8161606, 0xf8161616, + 0x06e6e6e6, 0x06e6e6f8, 0x06e6e606, 0x06e6e616, 0x06e6f8e6, 0x06e6f8f8, 0x06e6f806, 0x06e6f816, + 0x06e606e6, 0x06e606f8, 0x06e60606, 0x06e60616, 0x06e616e6, 0x06e616f8, 0x06e61606, 0x06e61616, + 0x06f8e6e6, 0x06f8e6f8, 0x06f8e606, 0x06f8e616, 0x06f8f8e6, 0x06f8f8f8, 0x06f8f806, 0x06f8f816, + 0x06f806e6, 0x06f806f8, 0x06f80606, 0x06f80616, 0x06f816e6, 0x06f816f8, 0x06f81606, 0x06f81616, + 0x0606e6e6, 0x0606e6f8, 0x0606e606, 0x0606e616, 0x0606f8e6, 0x0606f8f8, 0x0606f806, 0x0606f816, + 0x060606e6, 0x060606f8, 0x06060606, 0x06060616, 0x060616e6, 0x060616f8, 0x06061606, 0x06061616, + 0x0616e6e6, 0x0616e6f8, 0x0616e606, 0x0616e616, 0x0616f8e6, 0x0616f8f8, 0x0616f806, 0x0616f816, + 0x061606e6, 0x061606f8, 0x06160606, 0x06160616, 0x061616e6, 0x061616f8, 0x06161606, 0x06161616, + 0x16e6e6e6, 0x16e6e6f8, 0x16e6e606, 0x16e6e616, 0x16e6f8e6, 0x16e6f8f8, 0x16e6f806, 0x16e6f816, + 0x16e606e6, 0x16e606f8, 0x16e60606, 0x16e60616, 0x16e616e6, 0x16e616f8, 0x16e61606, 0x16e61616, + 0x16f8e6e6, 0x16f8e6f8, 0x16f8e606, 0x16f8e616, 0x16f8f8e6, 0x16f8f8f8, 0x16f8f806, 0x16f8f816, + 0x16f806e6, 0x16f806f8, 0x16f80606, 0x16f80616, 0x16f816e6, 0x16f816f8, 0x16f81606, 0x16f81616, + 0x1606e6e6, 0x1606e6f8, 0x1606e606, 0x1606e616, 0x1606f8e6, 0x1606f8f8, 0x1606f806, 0x1606f816, + 0x160606e6, 0x160606f8, 0x16060606, 0x16060616, 0x160616e6, 0x160616f8, 0x16061606, 0x16061616, + 0x1616e6e6, 0x1616e6f8, 0x1616e606, 0x1616e616, 0x1616f8e6, 0x1616f8f8, 0x1616f806, 0x1616f816, + 0x161606e6, 0x161606f8, 0x16160606, 0x16160616, 0x161616e6, 0x161616f8, 0x16161606, 0x16161616, +}; + +__device__ __forceinline__ int int_from_table_4(const uint8_t * a8, const int * values) { + return values[a8[0] | (a8[1] << 2) | (a8[2] << 4) | (a8[3] << 6)]; +} + +#define VDR_IQ2_K_Q8_1_MMVQ 4 +#define VDR_IQ2_K_Q8_1_MMQ 4 + +__device__ __forceinline__ float vec_dot_iq2_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + // iqs is 0, 4, 8, 12, 16, 20, 24, 28 + // we have 16 packed quants (when cast to int) + + int i4 = iqs/4; // 0...7. 
We will process q8 blocks 4*(i4/4), 4*(i4/4)+1, 4*(i4/4)+2, 4*(i4/4)+3 + const int32_t * q8_1 = (const int *)bq8_1[4*(i4/4)+0].qs + 2*(i4%4); + const int32_t * q8_2 = (const int *)bq8_1[4*(i4/4)+1].qs + 2*(i4%4); + const int32_t * q8_3 = (const int *)bq8_1[4*(i4/4)+2].qs + 2*(i4%4); + const int32_t * q8_4 = (const int *)bq8_1[4*(i4/4)+3].qs + 2*(i4%4); + + const block_iq2_k * bq2 = (const block_iq2_k *) vbq + kbx; + const uint32_t * q2 = (const uint32_t *)bq2->qs + 8*(i4/4) + 2*(i4%4); + const uint16_t extra = bq2->extra >> (8*(i4/4) + (i4%4)/2); + + const int * all_values = (const int *)iq2k_table; + const int * values; + + uint32_t val1 = q2[0], val2 = q2[1]; + + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)&aux32; + int v1, v2; + + // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2 + // -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2) + + const uint32_t * scales = (const uint32_t *)bq2->scales; + uint32_t s32 = __vsub4(((scales[i4/4] >> 4*(((i4%4)/2)%2)) & 0x0f0f0f0f) << 1, 0x0f0f0f0f); + const int8_t * s8 = (const int8_t *)&s32; + + aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8); + v1 = int_from_table_4(a8 + 0, values); + v2 = int_from_table_4(a8 + 4, values); + int sumi1 = ggml_cuda_dp4a(v2, q8_1[1], ggml_cuda_dp4a(v1, q8_1[0], 0)) * s8[0]; + + aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x04) << 6); + v1 = int_from_table_4(a8 + 0, values); + v2 = int_from_table_4(a8 + 4, values); + int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[1]; + + aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x10) << 4); + v1 = int_from_table_4(a8 + 0, values); + v2 = int_from_table_4(a8 + 4, values); + int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[2]; + + aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x40) << 2); + v1 = int_from_table_4(a8 + 0, values); + v2 = int_from_table_4(a8 + 4, values); + int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3]; + + return __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1 + + __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2 + + __low2float(bq8_1[4*(i4/4)+2].ds) * sumi3 + + __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4); + +} + +#define VDR_IQ3_K_Q8_1_MMVQ 4 +#define VDR_IQ3_K_Q8_1_MMQ 4 + +static const __device__ uint16_t iq3k_table[128] = { + 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f, + 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f, + 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f, + 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f, + 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33, + 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33, + 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133, + 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 
0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333, +}; + +__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) { + return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16); +} + +__device__ __forceinline__ float vec_dot_iq3_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs) { + const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx; + + int iqs = iiqs/4; + const int ib128 = iqs/4; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255 + // Each thread processes 8 quants in each of the 4 32-blocks + const int il8 = iqs%4; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31 + const int shift = 4*(il8/2); + + const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8; + const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8; + + uint32_t aux32; + const uint8_t * aux8 = (const uint8_t *)&aux32; + + const int hshift = 4*(1-ib128); + const uint16_t sh = bq3->scales_h >> (8*ib128 + il8/2); + + const uint8_t extra = bq3->extra >> (8*ib128 + il8/2); + const uint16_t * values1 = iq3k_table + ((extra << 6) & 0x40); + const uint16_t * values2 = iq3k_table + ((extra << 5) & 0x40); + const uint16_t * values3 = iq3k_table + ((extra << 4) & 0x40); + const uint16_t * values4 = iq3k_table + ((extra << 3) & 0x40); + + const int * q8; + int sumi[4] = {0, 0, 0, 0}; + int v; + for (int i = 0; i < 2; ++i) { + uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16); + uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift) >> 2; + + q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8; + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values1); + sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]); + vl >>= 2; vh >>= 1; + + q8 += sizeof(block_q8_1)/4; + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values2); + sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]); + vl >>= 2; vh >>= 1; + + q8 += sizeof(block_q8_1)/4; + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values3); + sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]); + vl >>= 2; vh >>= 1; + + q8 += sizeof(block_q8_1)/4; + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values4); + sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]); + + } + const float d = __half2float(bq3->d); + const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2*ib128; + aux32 = ((((sl16[0] | (sl16[1] << 16)) >> shift) & 0x0f0f0f0f) << 1) | 0x01010101; + return d * (__low2float(bq8_1[4*ib128+0].ds) * aux8[0] * (sh & 0x01 ? -1 : 1) * sumi[0] + + __low2float(bq8_1[4*ib128+1].ds) * aux8[1] * (sh & 0x04 ? -1 : 1) * sumi[1] + + __low2float(bq8_1[4*ib128+2].ds) * aux8[2] * (sh & 0x10 ? -1 : 1) * sumi[2] + + __low2float(bq8_1[4*ib128+3].ds) * aux8[3] * (sh & 0x40 ? 
-1 : 1) * sumi[3]); + +} + +#define VDR_IQ2_TN_Q8_1_MMVQ 1 +#define VDR_IQ2_TN_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + + const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs; + int v = q16[0] | (q16[1] << 16); + + float sumf = 0; + for (int i = 0; i < QR2_K; ++ i) { + int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); + float d8 = __low2float(bq8_1[bq8_offset + i].ds); + sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0)); + v >>= 2; + } + return __half2float(bq2->d) * sumf; + + //float sumf_d = 0; + //float sumf_m = 0; + //for (int i = 0; i < QR2_K; ++ i) { + // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); + // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds); + // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0); + // sumf_m += d8.y; + // v >>= 2; + //} + //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m); + +} + +} // namespace + +void mul_mat_vec_iq2_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq3_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq4_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq5_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq6_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq2_tn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh new file mode 100644 index 0000000000000..7af8e570986aa --- /dev/null +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -0,0 +1,26 @@ +#include "common.cuh" + +void mul_mat_vec_iq2_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq3_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, 
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq4_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq5_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq6_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq2_tn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 7dbbc993903c3..2586ab7ed77c7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,5 +1,6 @@ #include "mmvq.cuh" #include "vecdotq.cuh" +#include "iqk_mmvq.cuh" typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); @@ -20,6 +21,8 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : + type == GGML_TYPE_IQ1_BN ? vec_dot_iq1_bn_q8_1 : + type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 : type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : type == GGML_TYPE_IQ3_S ? 
vec_dot_iq3_s_q8_1 : @@ -313,6 +316,20 @@ static void mul_mat_vec_iq1_m_q8_1_cuda( mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_iq1_bn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +static void mul_mat_vec_iq2_bn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_iq4_nl_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -403,12 +420,36 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ1_M: mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ1_BN: + mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ2_BN: + mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ2_TN: + mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ4_NL: mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_K: + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ3_K: + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ4_K: + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ5_K: + mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ6_K: + mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40091a0ef07b4..b1b465a3b4f80 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include @@ -1066,6 +1073,95 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + 
sumf[1]) * sc1); } +static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx; + + static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + + // iqs is 0 or 1 + + int sumi = 0; +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const int * q8 = (const int *)bq8_1[iqs].qs; + int val[4]; + for (int l = 0; l < 2; ++l) { + int8_t * a = (int8_t *)val; + const int i16 = 2*iqs + l; + for (int k = 0; k < 3; ++k) { + uint8_t q = bq1->ql[3*i16+k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; + *a++ = vs-1; + } + } + uint8_t v = k_mult[i16]*bq1->extra; + int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; + *a++ = vs-1; + sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi)))); + } +#else + const int8_t * q8 = bq8_1[iqs].qs; + for (int l = 0; l < 2; ++l) { + const int i16 = 2*iqs + l; + for (int k = 0; k < 3; ++k) { + uint8_t q = bq1->ql[3*i16+k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = (v + (v >> 1)) >> 7; + sumi += q8[j]*(vs - 1); + } + q8 += 5; + } + uint8_t v = k_mult[i16]*bq1->extra; + int8_t vs = (v + (v >> 1)) >> 7; + sumi += q8[0]*(vs - 1); + q8++; + } +#endif + return __low2float(bq8_1[iqs].ds) * sumi; +} + +static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx; + + // iqs is 0 or 1 + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + auto qs = (const uint16_t *)bq2->qs + 4*iqs; + auto q8l = (const int *)bq8_1[0].qs + 2*iqs; + auto q8h = (const int *)bq8_1[1].qs + 2*iqs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + for (int j = 0; j < 2; ++j) { + int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16); + int vh = vl >> 4; + sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1); + sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2); + sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3); + sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4); + } + auto d8l = __half22float2(bq8_1[0].ds); + auto d8h = __half22float2(bq8_1[1].ds); + return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y; +#else + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + auto q8l = bq8_1[0].qs + 8*iqs; + auto q8h = bq8_1[1].qs + 8*iqs; + auto qs = bq2->qs + 8*iqs; + for (int j = 0; j < 8; ++j) { + sumi1 += q8l[j+ 0] * (qs[j] & 0x03); + sumi2 += q8l[j+16] * (qs[j] & 0x0c); + sumi3 += q8h[j+ 0] * (qs[j] & 0x30); + sumi4 += q8h[j+16] * (qs[j] & 0xc0); + } + auto d8l = __half22float2(bq8_1[0].ds); + auto d8h = __half22float2(bq8_1[1].ds); + return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y; +#endif +} + static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) { const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; const int8_t * q0_8 = (const int8_t *) &q0_32; @@ -1131,3 +1227,28 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds); return d * sumi; } + +static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values, + int & val1, 
int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + const uint8_t * values = all_values + 16*(shift & 1); + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + values = all_values + 8*(shift & 2); + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +#define VDR_IQ4_K_Q8_1_MMVQ 4 +#define VDR_IQ4_K_Q8_1_MMQ 4 + +#define VDR_IQ5_K_Q8_1_MMVQ 4 +#define VDR_IQ5_K_Q8_1_MMQ 4 + +#define VDR_IQ2_K_Q8_1_MMVQ 4 +#define VDR_IQ2_K_Q8_1_MMQ 4 diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index d825b6b69e37f..2b9ae2207df10 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #import "ggml-metal.h" #import "ggml-backend-impl.h" @@ -30,10 +37,13 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, + GGML_METAL_KERNEL_TYPE_ADD_4, GGML_METAL_KERNEL_TYPE_ADD_ROW, GGML_METAL_KERNEL_TYPE_MUL, + GGML_METAL_KERNEL_TYPE_MUL_4, GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_DIV, + GGML_METAL_KERNEL_TYPE_DIV_4, GGML_METAL_KERNEL_TYPE_DIV_ROW, GGML_METAL_KERNEL_TYPE_REPEAT_F32, GGML_METAL_KERNEL_TYPE_REPEAT_F16, @@ -76,8 +86,16 @@ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, @@ -104,8 +122,16 @@ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, @@ -128,8 +154,16 @@ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, @@ -149,8 +183,16 @@ 
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, @@ -170,8 +212,16 @@ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, @@ -490,10 +540,13 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_4, div_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); @@ -536,8 +589,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, get_rows_iq2_tn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, get_rows_iq2_k, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K, get_rows_iq3_k, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K, get_rows_iq4_k, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K, get_rows_iq5_k, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K, get_rows_iq6_k, true); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); @@ -564,8 +625,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, mul_mv_iq2_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, mul_mv_iq2_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32, mul_mv_iq3_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32, mul_mv_iq4_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32, mul_mv_iq5_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32, mul_mv_iq6_k_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); @@ -588,8 +657,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, mul_mv_id_iq2_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, mul_mv_id_iq2_k_f32, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32, mul_mv_id_iq3_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32, mul_mv_id_iq4_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32, mul_mv_id_iq5_k_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32, mul_mv_id_iq6_k_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); @@ -609,8 +686,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, mul_mm_iq2_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, mul_mm_iq2_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32, mul_mm_iq3_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, mul_mm_iq5_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); @@ -630,8 +715,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, mul_mm_id_iq2_tn_f32, ctx->support_simdgroup_mm); 
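// The new IQ4_K/IQ5_K/IQ6_K mul_mv kernels added further down stage their non-linear value
// tables (kvalues_iq4k_f[32], kvalues_iq5k_f[64], kvalues_iq6k_f[128]) in threadgroup memory,
// which is why the dispatch code below reserves 32, 64 or 128 floats per threadgroup for them
// (IQ2_K/IQ3_K keep their small tables in constant memory and need no such reservation).
// A hypothetical helper expressing that sizing rule (illustration only, not part of the patch):
static size_t iqk_mul_mv_tg_lut_bytes(enum ggml_type t) {
    return t == GGML_TYPE_IQ6_K ? 128*sizeof(float)   // kvalues_iq6k_f
         : t == GGML_TYPE_IQ5_K ?  64*sizeof(float)   // kvalues_iq5k_f
         :                         32*sizeof(float);  // IQ4_NL, IQ4_XS, IQ4_K
}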
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, mul_mm_id_iq2_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32, mul_mm_id_iq3_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32, mul_mm_id_iq4_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32, mul_mm_id_iq5_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32, mul_mm_id_iq6_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); @@ -1073,7 +1166,53 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { + float scale; + memcpy(&scale, src1->data, sizeof(float)); + //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + + int64_t n = ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && + dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && + dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && + dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && + dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { + + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; + default: GGML_ASSERT(false); + } + + int64_t n = ggml_nelements(dst)/4; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(src0)); // src1 is a row @@ -1623,8 +1762,16 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 
].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break; default: GGML_ABORT("MUL MAT-MAT not implemented"); } @@ -1783,6 +1930,24 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_TN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32].pipeline; + } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -1795,6 +1960,36 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_IQ2_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32].pipeline; + } break; + case GGML_TYPE_IQ3_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32].pipeline; + } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32].pipeline; + } break; + case GGML_TYPE_IQ5_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32].pipeline; + } break; + case GGML_TYPE_IQ6_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1825,7 +2020,9 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| + src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, 
ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -1838,8 +2035,9 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K) { + const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : src0t == GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1925,8 +2123,16 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; default: GGML_ABORT("MUL_MAT_ID not implemented"); } @@ -2079,6 +2285,24 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_TN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32].pipeline; + } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -2091,6 +2315,36 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_IQ2_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; + } break; + case GGML_TYPE_IQ3_K: + { + nth0 = 4; + nth1 = 16; + pipeline =
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; + } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; + } break; + case GGML_TYPE_IQ5_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; + } break; + case GGML_TYPE_IQ6_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -2132,7 +2386,9 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| + src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2145,8 +2401,9 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K) { + const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : src0t == GGML_TYPE_IQ5_K ?
64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -2191,8 +2448,16 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; + case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; default: GGML_ABORT("not implemented"); } diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3bb37d32aced0..904639a59c1a9 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #define GGML_COMMON_DECL_METAL #define GGML_COMMON_IMPL_METAL #include "ggml-common.h" @@ -225,6 +232,13 @@ kernel void kernel_add_row( uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } +kernel void kernel_add_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig]; +} kernel void kernel_mul_row( device const float4 * src0, @@ -235,6 +249,14 @@ kernel void kernel_mul_row( dst[tpig] = src0[tpig] * src1[tpig % nb]; } +kernel void kernel_mul_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig]; +} + kernel void kernel_div_row( device const float4 * src0, device const float4 * src1, @@ -243,6 +265,13 @@ kernel void kernel_div_row( uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] / src1[tpig % nb]; } +kernel void kernel_div_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] / src1[tpig]; +} kernel void kernel_scale( device const float * src0, @@ -1480,7 +1509,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // 
MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( +static inline void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation @@ -1634,35 +1663,69 @@ kernel void kernel_rope_neox( const float theta_base = (float) pos[i2]; const float inv_ndims = -1.f/n_dims; + float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims); + const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims); + float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + int64_t i0 = 2*tiitg; + for ( ; i0 < n_dims; i0 += 2*tptg.x) { + const int64_t ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - const float x0 = src[0]; - const float x1 = src[n_dims/2]; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + theta *= theta_multiplier; + } + for ( ; i0 < ne0; i0 += 2*tptg.x) { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } + + // Original version + //for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + // if (i0 < n_dims) { + // const int64_t ic = i0/2; + + // // Who thought that having a pow() evaluation in a loop is a good idea? + // //const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + // const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + // rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + // const float x0 = src[0]; + // const float x1 = src[n_dims/2]; + + // dst_data[0] = x0*cos_theta - x1*sin_theta; + // dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + + // theta *= theta_multiplier; + // } else { + // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + // dst_data[0] = src[0]; + // dst_data[1] = src[1]; + // } + //} } typedef decltype(kernel_rope_norm) kernel_rope_norm_t; @@ -2993,6 +3056,32 @@ constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; +constexpr constant static float kvalues_iq4k_f[32] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f, + -123.f, -100.f, -79.f, -61.f, -45.f, -31.f, -18.f, -6.f, 5.f, 17.f, 29.f, 42.f, 57.f, 73.f, 93.f, 117.f, +}; + +constexpr constant static float kvalues_iq5k_f[64] = { + -126.f, -114.f, -103.f, -92.f, -83.f, -74.f, -65.f, -57.f, -50.f, -43.f, -36.f, -30.f, -24.f, -18.f, -12.f, -6.f, -1.f, 5.f, 11.f, 17.f, 23.f, 29.f, 36.f, 43.f, 51.f, 59.f, 68.f, 77.f, 87.f, 97.f, 109.f, 121.f, + -124.f, -112.f, -101.f, -90.f, -81.f, -72.f, -63.f, -55.f, -48.f, -41.f, -34.f, -28.f, -22.f, -16.f, -10.f, -4.f, 1.f, 7.f, 13.f, 19.f, 25.f, 31.f, 38.f, 45.f, 53.f, 61.f, 70.f, 79.f, 89.f, 99.f, 111.f, 123.f, +}; + +constexpr constant static float kvalues_iq6k_f[128] = { + -127.f, -121.f, -115.f, -109.f, -104.f, -98.f, -93.f, -88.f, -84.f, -79.f, -74.f, -70.f, -66.f, -62.f, -58.f, -54.f, + -51.f, -47.f, -44.f, -40.f, -37.f, -34.f, -31.f, -28.f, -25.f, -22.f, -19.f, -16.f, -13.f, -11.f, -8.f, -5.f, + -2.f, 0.f, 3.f, 6.f, 9.f, 12.f, 14.f, 17.f, 20.f, 23.f, 27.f, 30.f, 33.f, 36.f, 40.f, 44.f, + 47.f, 51.f, 55.f, 59.f, 63.f, 68.f, 72.f, 77.f, 82.f, 87.f, 92.f, 98.f, 103.f, 109.f, 115.f, 121.f, + -126.f, -120.f, -114.f, -108.f, -103.f, -97.f, -92.f, -87.f, -83.f, -78.f, -73.f, -69.f, -65.f, -61.f, -57.f, -53.f, + -50.f, -46.f, -43.f, -39.f, -36.f, -33.f, -30.f, -27.f, -24.f, -21.f, -18.f, -15.f, -12.f, -10.f, -7.f, -4.f, + -1.f, 1.f, 4.f, 7.f, 10.f, 13.f, 15.f, 18.f, 21.f, 24.f, 28.f, 31.f, 34.f, 37.f, 41.f, 45.f, + 48.f, 52.f, 56.f, 60.f, 64.f, 69.f, 73.f, 78.f, 83.f, 88.f, 93.f, 99.f, 104.f, 110.f, 116.f, 122.f, +}; + +constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f }; + +constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f }; +constexpr constant static half kvalues_iq3k_h[16] = { -63.h, -40.h, -23.h, -10.h, 1.h, 13.h, 28.h, 47.h, -59.h, -36.h, -19.h, -6.h, 5.h, 17.h, 32.h, 51.h }; + kernel void kernel_cpy_f32_iq4_nl( device const float * src0, device void * dst, @@ -3251,6 +3340,129 @@ kernel void kernel_mul_mv_q2_K_f32( kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } +void kernel_mul_mv_iq2_tn_f32_impl( + device const void * src0, + device const float * src1, + 
device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_iq2_tn) * nb / 2; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float sumy = 0.f; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy += yl[i+16]; + yl[i+24] = y4[i+96]; sumy += yl[i+24]; + } + + device const half * dh = &x[ib].d; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy); + + qs += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq2_tn_f32")]] +kernel void kernel_mul_mv_iq2_tn_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_tn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, @@ -4703,7 +4915,56 @@ void kernel_mul_mv_iq1_m_f32_impl( } } -void kernel_mul_mv_iq4_nl_f32_impl( +static inline float iq1bn_fp8_to_float(uint8_t fp8) { + 
typedef union { float f; uint32_t i; } scale_t; + scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); + return s.f; +} + +//static constant int8_t iq1bn_values[256*5] = { +// -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0, +// -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1, +// -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1, +// -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1, +// -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1, +// -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1, +// 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0, +// -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1, +// 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1, +// 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1, +// -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0, +// 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1, +// -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1, +// 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1, +// -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0, +// 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1, +// 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1, +// 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0, +// -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1, +// 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0, +// 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1, +// 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0, +// 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0, +// 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0, +// 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0, +// 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0, +// 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1, +// 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0, +// 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0, +// -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1, +// 1, -1, 1, -1, 1, -1, 
0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1, +// 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0, +// 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1, +// -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, +// 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1, +// 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1, +// -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1, +// 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, +// 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1, +// 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, +//}; + +void kernel_mul_mv_iq1_bn_f32_impl( device const void * src0, device const float * src1, device float * dst, @@ -4716,88 +4977,93 @@ void kernel_mul_mv_iq4_nl_f32_impl( int64_t ne1, uint r2, uint r3, - threadgroup int8_t * shared_values_i8, + threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; + const int nb = ne00/QK_IQ1BN; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 + float yl[16]; + float sumf[N_DST]={0.f}; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); + const int nb32 = nb * (QK_IQ1BN / 32); - float4 yl[4]; - float sumf[2]={0.f}, all_sum; + const int ix = tiisg/2; + const int ir = tiisg%2; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * y4 = (device const float *)y + 32 * ix + 16 * ir; uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + const float values[3] = {-1.f, 0.f, 1.f}; - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { - device const block_iq4_nl & xb = x[row*nb + ib]; - device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + for (int j = 0; j < 16; ++j) yl[j] = y4[j]; - float4 acc1 = {0.f}, acc2 = {0.f}; + const int ibl = ib32 / (QK_IQ1BN / 32); + const int ib = ib32 % (QK_IQ1BN / 32); + const int i16 = 2*ib + ir; - aux32[0] = q4[0] | (q4[1] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], 
shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; + device const block_iq1_bn * xr = x + ibl; + device const uint8_t * ql = xr->ql + 3*i16; + device const uint8_t * extra = (device const uint8_t *)&xr->extra; - aux32[0] = q4[2] | (q4[3] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; + for (int row = 0; row < N_DST; row++) { - acc1 += acc2; + float acc = 0; + int i = 0; + for (int k = 0; k < 3; ++k) { + //constant int8_t * vs = iq1bn_values + 5*ql[k]; + //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j]; + uint8_t q = ql[k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + v = 3*v >> 8; //(v + (v >> 1)) >> 7; + acc += yl[i++] * values[v]; + } + } + //constant int8_t * vs = iq1bn_values + 5*extra[0]; + //acc += yl[15] * vs[i16]; + uint8_t v = k_mult[i16]*extra[0]; + v = 3*v >> 8; //(v + (v >> 1)) >> 7; + acc += yl[15] * values[v]; - sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + sumf[row] += acc; + extra += nb*sizeof(block_iq1_bn); + ql += nb*sizeof(block_iq1_bn); } - yb += 16 * QK4_NL; + y4 += 32 * 16; } - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + for (int row = 0; row < N_DST; row += 2) { + half2 r = {(half)sumf[row], (half)sumf[row+1]}; + r = simd_sum(r); + if (tiisg < 2) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg]; } } } -void kernel_mul_mv_iq4_xs_f32_impl( +void kernel_mul_mv_iq2_bn_f32_impl( device const void * src0, device const float * src1, device float * dst, @@ -4810,62 +5076,236 @@ void kernel_mul_mv_iq4_xs_f32_impl( int64_t ne1, uint r2, uint r3, - threadgroup int8_t * shared_values_i8, + threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; + const int nb = ne00/QK_IQ1BN; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const block_iq2_bn * x = (device const block_iq2_bn *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; + float yl[16]; + float sumf[N_DST]={0.f}; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); + const int ix = tiisg/4; // 0...7 + const int ir = tiisg%4; // 0...3 - float4 yl[4]; - float sumf[2]={0.f}, all_sum; + device const float * y4 = y + 64 * ix + 4 * ir; - device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + for (int ib = ix; ib < nb; ib += 8) { - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; + float sumy = 0.f; + for (int 
i = 0; i < 4; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; + yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4]; + yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; + yl[i+12] = y4[i+48]; sumy += yl[i+12]; + } - float4 qf1, qf2; + device const uint8_t * qs = x[ib].qs + 4*ir; - for (int ibl = ix; ibl < nb; ibl += 2) { + for (int row = 0; row < N_DST; row++) { - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + float4 acc = {0.f}; + for (int j = 0; j < 4; ++j) { + acc[0] += yl[j+ 0] * (qs[j] & 0x03); + acc[1] += yl[j+ 4] * (qs[j] & 0x0c); + acc[2] += yl[j+ 8] * (qs[j] & 0x30); + acc[3] += yl[j+12] * (qs[j] & 0xc0); + } - for (int row = 0; row < 2; ++row) { + sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy; - device const block_iq4_xs & xb = x[row*nb + ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + qs += nb*sizeof(block_iq2_bn); + } - float4 acc1 = {0.f}, acc2 = {0.f}; + y4 += 64 * 8; + } - aux32[0] = q4[0] & 0x0f0f0f0f; - aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; + for (int row = 0; row < N_DST; row += 2) { + half2 r = {(half)sumf[row], (half)sumf[row+1]}; + r = simd_sum(r); + if (tiisg < 2) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg]; + } + } +} + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + 
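// (annotation) Above, aux32[0]/aux32[1] hold the low and high nibbles of eight packed 4-bit
// indices; reading them back through the thread-address-space byte alias q8 lets each index
// fetch its dequantized value from the threadgroup copy of kvalues_iq4nl_f without any
// per-element shifts.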
acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; aux32[0] = q4[1] & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; @@ -4874,10 +5314,580 @@ void kernel_mul_mv_iq4_xs_f32_impl( acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; - acc1 += acc2; - - const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; - sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq2_k_f32_impl( + device const void * src0, + device const 
float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_k * x = (device const block_iq2_k *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + uint32_t aux32[2]; + thread const uint8_t * aux8 = (thread const uint8_t *)aux32; + uint16_t shift[4]; + + for (int ib = ix; ib < nb; ib += 4) { + + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; + yl[i+ 8] = y4[i+32]; + yl[i+16] = y4[i+64]; + yl[i+24] = y4[i+96]; + } + + for (int row = 0; row < N_DST; row++) { + + device const block_iq2_k & xb = x[row*nb + ib]; + device const uint32_t * q32 = (device const uint32_t *)xb.qs + 8*iq + 2*ir; + device const uint32_t * sc = (device const uint32_t *)xb.scales; + + const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1; + thread const int8_t * s8 = (thread const int8_t *)&scales32; + uint16_t extra = xb.extra >> (8*iq + is); + + shift[0] = (extra << 2) & 4; + shift[1] = (extra << 1) & 4; + shift[2] = (extra >> 0) & 4; + shift[3] = (extra >> 1) & 4; + + float4 acc = {0.f}; + for (int l = 0; l < 4; ++l) { + constant float * values = kvalues_iq2k_f + shift[l]; + aux32[0] = (q32[0] >> 2*l) & 0x03030303; + aux32[1] = (q32[1] >> 2*l) & 0x03030303; + for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; + } + sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) + acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15)); + + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq2_k_f32")]] +kernel void kernel_mul_mv_iq2_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_k_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint 
r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + uint32_t vl[2], vh[2]; + uint32_t aux32[2]; + thread const uint8_t * aux8 = (thread const uint8_t *)aux32; + uint16_t shift[4]; + + for (int ib = ix; ib < nb; ib += 4) { + + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; + yl[i+ 8] = y4[i+32]; + yl[i+16] = y4[i+64]; + yl[i+24] = y4[i+96]; + } + + for (int row = 0; row < N_DST; row++) { + + device const block_iq3_k & xb = x[row*nb + ib]; + device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir; + device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir; + device const uint32_t * sc = (device const uint32_t *)xb.scales_l; + + const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1; + thread const int8_t * s8 = (thread const int8_t *)&scales32; + uint16_t extra = xb.extra >> (8*iq + is); + uint16_t signs = xb.scales_h >> (8*iq + is); + + shift[0] = (extra << 3) & 8; + shift[1] = (extra << 2) & 8; + shift[2] = (extra << 1) & 8; + shift[3] = (extra << 0) & 8; + + vl[0] = ql16[0] | ql16[1] << 16; + vl[1] = ql16[2] | ql16[3] << 16; + vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2; + vh[1] = ((qh16[2] | (qh16[3] << 16)) << 4*(1-iq)) >> 2; + + float4 acc = {0.f}; + for (int l = 0; l < 4; ++l) { + constant float * values = kvalues_iq3k_f + shift[l]; + aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404); + aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404); + for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; + vl[0] >>= 2; vl[1] >>= 2; + vh[0] >>= 1; vh[1] >>= 1; + } + + sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) + + acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? 
-s8[3] : s8[3])); + + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq3_k_f32")]] +kernel void kernel_mul_mv_iq3_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq4_k_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_k * x = (device const block_iq4_k *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4k_f[tiisg]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + //float2 sumy; + //sumy[0] = -4.f*(yl[0][0] + yl[0][1] + yl[0][2] + yl[0][3] + yl[2][0] + yl[2][1] + yl[2][2] + yl[2][3]); + //sumy[1] = -4.f*(yl[1][0] + yl[1][1] + yl[1][2] + yl[1][3] + yl[3][0] + yl[3][1] + yl[3][2] + yl[3][3]); + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_k & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + uint16_t extra = xb.extra >> 2*ib; + threadgroup const float * values1 = shared_values + 16*(extra & 1); + threadgroup const float * values2 = shared_values + 8*(extra & 2); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = 
{values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + const uint8_t h = xb.scales_h[ib/2] >> 4*(ib%2); + const int ls1 = ((xb.scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((xb.scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3])); + //uint16_t extra = xb.extra >> 2*ib; + //sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3] + (extra & 1 ? sumy[0] : 0)) + + // ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3] + (extra & 2 ? sumy[1] : 0))); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq5_k_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq5_k * x = (device const block_iq5_k *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib64 = it/4; + const int il64 = it%4; + + shared_values[2*tiisg+0] = kvalues_iq5k_f[2*tiisg+0]; + shared_values[2*tiisg+1] = kvalues_iq5k_f[2*tiisg+1]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib64 * 64 + il64 * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[8]; yl[2] = y4[1]; yl[3] = y4[9]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq5_k & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 32*ib64 + 8*il64); + device const uint32_t * qh = (device const uint32_t *)(xb.qh + 8*il64); + + uint16_t extra = xb.extra >> (4*ib64 + il64/2); + threadgroup const float * values1 = shared_values + 32*(extra & 1); + threadgroup const float * values2 = shared_values + 8*(extra & 4); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + uint32_t h = qh[0] >> 2*ib64; + aux32[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + h = qh[1] >> 2*ib64; + aux32[0] = ((q4[1] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[1] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + qf1 = {values1[q8[0]], values1[q8[1]], 
values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + const uint8_t sh = xb.scales_h[ib64] >> 2*(il64/2); + const int ls1 = (((xb.scales_l[2*ib64 + 0 + il64/2] >> 4*(il64/2)) & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = (((xb.scales_l[2*ib64 + 1 + il64/2] >> 4*(il64/2)) & 0xf) | ((sh << 0) & 0x30)) - 32; + sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3])); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq6_k_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq6_k * x = (device const block_iq6_k *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib64 = it/4; // 0...3 + const int il64 = it%4; // 0...3 + + shared_values[4*tiisg+0] = kvalues_iq6k_f[4*tiisg+0]; + shared_values[4*tiisg+1] = kvalues_iq6k_f[4*tiisg+1]; + shared_values[4*tiisg+2] = kvalues_iq6k_f[4*tiisg+2]; + shared_values[4*tiisg+3] = kvalues_iq6k_f[4*tiisg+3]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib64 * 64 + il64 * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[8]; yl[2] = y4[1]; yl[3] = y4[9]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq6_k & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 32*ib64 + 8*il64); + device const uint32_t * qh = (device const uint32_t *)(xb.qh + 32*(ib64/2) + 8*il64); + + uint16_t extra = xb.extra >> (4*ib64 + il64/2); + threadgroup const float * values1 = shared_values + 64*(extra & 1); + threadgroup const float * values2 = shared_values + 16*(extra & 4); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + uint32_t h = qh[0] >> 4*(ib64%2); + aux32[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x30303030); + aux32[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | ((h << 2) & 0x30303030); + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + h = qh[1] >> 4*(ib64%2); + aux32[0] = ((q4[1] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x30303030); + aux32[1] = ((q4[1] >> 4) & 0x0f0f0f0f) | ((h << 2) & 0x30303030); + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = 
{values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + const int ls1 = xb.scales[4*ib64 + 0 + il64/2]; + const int ls2 = xb.scales[4*ib64 + 2 + il64/2]; + sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3])); } @@ -4948,6 +5958,62 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq1_bn_f32")]] +kernel void kernel_mul_mv_iq1_bn_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq2_bn_f32")]] +kernel void kernel_mul_mv_iq2_bn_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( device const void * src0, @@ -5006,6 +6072,93 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq4_k_f32")]] +kernel void kernel_mul_mv_iq4_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq5_k_f32")]] +kernel void 
kernel_mul_mv_iq5_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq5_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq6_k_f32")]] +kernel void kernel_mul_mv_iq6_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq6_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -5122,7 +6275,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg } template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { +void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) { const float d = xb->d; const float min = xb->dmin; device const uint8_t * q = (device const uint8_t *)xb->qs; @@ -5140,6 +6293,21 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg } } +template +void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) { + const half d = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1); + + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
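/* iq2_tn dequantization, as far as one can tell from this block: il picks one of the four
   2-bit groups packed per byte of qs; mask (0x03/0x0c/0x30/0xc0) isolates it and coef
   (1, 1/4, 1/16, 1/64) shifts it back down, so dl*(q & mask) equals d*{0,1,2}; subtracting d
   then yields the ternary values {-d, 0, +d}. */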
12 : 3); + const half dl = d * coef; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - d; + } +} + template void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { const half d_all = xb->d; @@ -5405,6 +6573,42 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & } } + +template +void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { + // il is in 0...3 + + constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + + int i = 0; + for (int k = 0; k < 3; ++k) { + uint8_t q = xb->ql[3*il + k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = 3*v >> 8; + //int8_t vs = (v + (v >> 1)) >> 7; + reg[i/4][i%4] = vs - 1; + ++i; + } + } + uint8_t v = k_mult[il]*xb->extra; + int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; + reg[3][3] = vs - 1; +} + +template +void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4 & reg) { + // il is in 0...3 + constexpr float k_scale[4] = {1.f, 0.25f, 0.0625f, 0.015625f}; + constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0}; + const float d = k_scale[il]; + uint8_t mask = k_mask[il]; + + for (int j = 0; j < 16; ++j) { + reg[j/4][j%4] = d * (xb->qs[j] & mask) - 1; + } +} + template void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { device const uint16_t * q4 = (device const uint16_t *)xb->qs; @@ -5440,6 +6644,108 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 } } +template +void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 + device const uint32_t * q32 = (device const uint32_t *)xb->qs + 8*(il/8) + 4*(il&1); + half d = xb->d * (2*((xb->scales[il/2] >> 4*(il&1)) & 0xf) - 15); + + constant int8_t * int_values = iq2nl_values + 4*((xb->extra >> il) & 1); + half4 values = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3] }; + + const int shift = 2*((il%8)/2); + uint32_t aux32; + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q32[i] >> shift) & 0x03030303; + for (int j = 0; j < 4; ++j) reg[i][j] = values[aux8[j]]; + } +} + +template +void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 + device const uint16_t * q16l = (device const uint16_t *)xb->qs + 16*(il/8) + 8*(il&1); + device const uint16_t * q16h = (device const uint16_t *)xb->qh + 8*(il&1); + half d = xb->d * (2*((xb->scales_l[il/2] >> 4*(il&1)) & 0xf) + 1) * (xb->scales_h & (1 << il) ? -1 : 1); + + constant half * values = kvalues_iq3k_h + 8*((xb->extra >> il) & 1); + + const int shift = 2*((il%8)/2); + uint32_t aux32; + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + uint32_t vl = q16l[2*i+0] | (q16l[2*i+1] << 16); + uint32_t vh = q16h[2*i+0] | (q16h[2*i+1] << 16); + aux32 = ((vl >> shift) & 0x03030303) | (((vh >> ((il/2)%8)) << 2) & 0x04040404); + for (int j = 0; j < 4; ++j) reg[i][j] = d * values[aux8[j]]; + } +} + +template +void dequantize_iq4_k(device const block_iq4_k * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + const int l = il%2; + // l = 0 or 1. 
l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + constant float * values = kvalues_iq4k_f + 16*((xb->extra >> il) & 1); + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*l) & 0x0f0f0f0f; + reg[i][0] = d * values[q8[0]]; + reg[i][1] = d * values[q8[1]]; + reg[i][2] = d * values[q8[2]]; + reg[i][3] = d * values[q8[3]]; + } +} + +template +void dequantize_iq5_k(device const block_iq5_k * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + const int l = il%2; + // l = 0 or 1. l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 8*(ib32/2) + 4*l; + device const uint32_t * qh = (device const uint32_t *)xb->qh + 4*l; + const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + constant float * values = kvalues_iq5k_f + 32*((xb->extra >> il) & 1); + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[i] >> 4*(ib32%2)) & 0x0f0f0f0f) | (((qh[i] >> ib32) & 0x01010101) << 4); + reg[i][0] = d * values[q8[0]]; + reg[i][1] = d * values[q8[1]]; + reg[i][2] = d * values[q8[2]]; + reg[i][3] = d * values[q8[3]]; + } +} + +template +void dequantize_iq6_k(device const block_iq6_k * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + const int l = il%2; + // l = 0 or 1. 
l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 8*(ib32/2) + 4*l; + device const uint32_t * qh = (device const uint32_t *)xb->qh + 8*(ib32/4) + 4*l; + const float d = (float)xb->d * xb->scales[2*ib32+l]; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + constant float * values = kvalues_iq6k_f + 64*((xb->extra >> il) & 1); + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[i] >> 4*(ib32%2)) & 0x0f0f0f0f) | (((qh[i] >> 2*(ib32%4)) & 0x03030303) << 4); + reg[i][0] = d * values[q8[0]]; + reg[i][1] = d * values[q8[1]]; + reg[i][2] = d * values[q8[2]]; + reg[i][3] = d * values[q8[3]]; + } +} + template kernel void kernel_get_rows_q( device const void * src0, @@ -5888,6 +7194,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -5901,6 +7208,13 @@ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_k")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_k")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q; // // matrix-matrix multiplication @@ -5916,6 +7230,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -5929,6 +7244,13 @@ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template 
[[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -5944,6 +7266,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; @@ -5955,8 +7278,15 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -6153,12 +7483,15 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template 
[[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; @@ -6166,3 +7499,8 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq5_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq6_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 334992ca14d71..23a3dae7cc1ae 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1,8 +1,18 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #define GGML_COMMON_IMPL_C #include "ggml-common.h" #include "ggml-quants.h" #include "ggml-impl.h" +#if GGML_USE_IQK_MULMAT +#include "iqk/iqk_mul_mat.h" +#endif #include @@ -904,8 +914,15 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) block_q8_0 * restrict y = vy; +#if GGML_USE_IQK_MULMAT + const int nb4 = 4*(nb/4); +#else + const int nb4 = -1; +#endif #if defined(__ARM_NEON) + block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; float32x4_t srcv [8]; float32x4_t asrcv[8]; float32x4_t amaxv[8]; @@ -922,16 +939,27 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) const float d = amax / ((1 << 7) - 1); const float id = d ? 
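/* q8_0 NEON path: d = amax/127 is the per-block scale and id its guarded inverse used to map
   floats to int8. The i < nb4 branches below appear to pack four consecutive q8_0 blocks into
   one block_q8_0_x4 (scales in d[0..3], 4x32 quants contiguous in qs); this packing is only
   active when GGML_USE_IQK_MULMAT is defined, otherwise nb4 stays at -1 and the plain
   per-block layout is kept. */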
1.0f/d : 0.0f; - y[i].d = GGML_FP32_TO_FP16(d); + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } for (int j = 0; j < 8; j++) { const float32x4_t v = vmulq_n_f32(srcv[j], id); const int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + if (i < nb4) { + y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } } } #elif defined(__wasm_simd128__) @@ -968,7 +996,14 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) } } #elif defined(__AVX2__) || defined(__AVX__) + block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; +#ifdef __AVX2__ + const bool pack = true; +#else + const bool pack = false; +#endif for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); __m256 v1 = _mm256_loadu_ps( x + 8 ); @@ -990,7 +1025,11 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) // Quantize these floats const float d = maxScalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); + if (pack && i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1025,7 +1064,11 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - _mm256_storeu_si256((__m256i *)y[i].qs, i0); + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } #else // Since we don't have in AVX some necessary functions, // we split the registers in half and call AVX2 analogs from SSE @@ -1224,8 +1267,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) block_q8_1 * restrict y = vy; +#if GGML_USE_IQK_MULMAT + const int nb4 = 4*(nb/4); +#else + const int nb4 = -1; +#endif #if defined(__ARM_NEON) + block_q8_1_x4 * restrict y4 = vy; for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; float32x4_t srcv [8]; float32x4_t asrcv[8]; float32x4_t amaxv[8]; @@ -1242,7 +1292,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const float d = amax / ((1 << 7) - 1); const float id = d ? 
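/* q8_1 NEON path, with the same x4 packing as q8_0 above: block_q8_1_x4 appears to keep the
   four block scales in d[0..3] and the corresponding d*sum terms in d[4..7] (see the
   d[ir+4] store below), i.e. what a plain block_q8_1 stores in .d and .s. */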
1.0f/d : 0.0f; - y[i].d = GGML_FP32_TO_FP16(d); + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } int32x4_t accv = vdupq_n_s32(0); @@ -1250,15 +1304,26 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const float32x4_t v = vmulq_n_f32(srcv[j], id); const int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + if (i < nb4) { + y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } accv = vaddq_s32(accv, vi); } - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } else { + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } } #elif defined(__wasm_simd128__) for (int i = 0; i < nb; i++) { @@ -1304,7 +1369,14 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) wasm_i32x4_extract_lane(accv, 3))); } #elif defined(__AVX2__) || defined(__AVX__) + block_q8_1_x4 * restrict y4 = vy; +#ifdef __AVX2__ + const bool pack = true; +#else + const bool pack = false; +#endif for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); __m256 v1 = _mm256_loadu_ps( x + 8 ); @@ -1326,7 +1398,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) // Quantize these floats const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); + if (pack && i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1350,7 +1426,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #if defined(__AVX2__) // Compute the sum of the quants and set y[i].s - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } else { + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -1364,7 +1444,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - _mm256_storeu_si256((__m256i *)y[i].qs, i0); + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } #else // Since we don't have in AVX some necessary functions, // we split the registers in half and call AVX2 analogs from SSE @@ -1965,7 +2049,52 @@ void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, in const float q4scale = 15.f; + // Detect TriNet + { + int n = k; + float max = 0; + for (int j = 0; j < n; ++j) { + float ax = fabsf(x[j]); + max = MAX(max, ax); + } + float mse0 = 0, mse = 0; + for (int j = 0; j < n; ++j) { + int l = x[j] < -0.5f*max ? -1 : x[j] < 0.5f*max ? 0 : 1; + mse0 += x[j]*x[j]; + float diff = x[j] - max*l; + mse += diff*diff; + } + if (mse < 0.1f*mse0) { + // yes, most likely trinet + for (int ibl = 0; ibl < nb; ++ibl) { + y[ibl].d = GGML_FP32_TO_FP16(max); + y[ibl].dmin = GGML_FP32_TO_FP16(max); + for (int ib = 0; ib < QK_K/16; ++ib) y[ibl].scales[ib] = 1 | (1 << 4); + const float * xb = x + QK_K * ibl; + for (int j = 0; j < QK_K; ++j) { + L[j] = xb[j] < -0.5f*max ? 0 : xb[j] < 0.5f*max ? 
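/* Ternary ("TriNet") fast path in quantize_row_q2_K_ref: each weight is snapped to level 0,
   1 or 2 depending on whether it lies below -max/2, between -max/2 and max/2, or above.
   With d = dmin = max and every sub-block scale/min set to 1, the usual q2_K reconstruction
   d*L - dmin gives back {-max, 0, +max}; the loop that follows packs four 2-bit levels per
   byte of qs. */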
1 : 2; + } + uint8_t * qs = y[ibl].qs; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + qs[l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + qs += 32; + } + } + return; + } + } + for (int i = 0; i < nb; i++) { + //{ + // float max = x[0], min = x[0]; + // for (int j = 1; j < 256; ++j) { + // max = MAX(x[j], max); + // min = MIN(x[j], min); + // } + // printf("%s: max = %g, min = %g\n", __func__, (double)max, (double)min); + //} float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/16; ++j) { @@ -4171,6 +4300,11 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r } void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + return; + } +#endif const int qk = QK8_0; const int nb = n / qk; @@ -4653,6 +4787,11 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) { + return; + } +#endif const int qk = QK8_1; const int nb = n / qk; @@ -4940,6 +5079,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r } void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + return; + } +#endif const int qk = QK8_0; const int nb = n / qk; @@ -5295,6 +5439,11 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) { + return; + } +#endif const int qk = QK8_1; const int nb = n / qk; @@ -5669,6 +5818,11 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + return; + } +#endif const int qk = QK8_0; const int nb = n / qk; @@ -12501,6 +12655,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void } void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + return; + } +#endif assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -15760,6 +15919,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_IQ2_K: break; + case GGML_TYPE_IQ3_K: break; + case 
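/* Row data of the new iq2_k .. iq6_k and iq2_tn types is accepted here without any
   per-row validation (each case simply breaks). */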
GGML_TYPE_IQ4_K: break; + case GGML_TYPE_IQ5_K: break; + case GGML_TYPE_IQ6_K: break; + case GGML_TYPE_IQ2_TN: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { @@ -15775,6 +15940,8 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_I64: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: // nothing to validate break; default: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index b3979f560d6ab..6721eea895e22 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #define GGML_COMMON_DECL_C @@ -27,6 +34,7 @@ void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_REST void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); @@ -36,6 +44,8 @@ void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGM void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_bn_ref (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_bn_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k); void quantize_row_q1_3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -52,6 +62,7 @@ void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -61,6 +72,8 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization void dequantize_row_q1_3(const 
block_q1_3 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -91,6 +104,8 @@ void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_bn (const block_iq1_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product void ggml_vec_dot_q1_3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -119,6 +134,8 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. 
"Activation aWare Quantization") size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -130,6 +147,8 @@ size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6bc936631e8d3..a5e181e3df127 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1,3 +1,9 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #define _USE_MATH_DEFINES // For M_PI on MSVC @@ -5,6 +11,10 @@ #include "ggml-quants.h" #include "ggml.h" #include "ggml-aarch64.h" +#include "iqk/iqk_quantize.h" +#if GGML_USE_IQK_MULMAT +#include "iqk/iqk_mul_mat.h" +#endif #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -971,6 +981,42 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ1_BN] = { + .type_name = "iq1_bn", + .blck_size = QK_IQ1BN, + .type_size = sizeof(block_iq1_bn), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_bn, + .from_float = quantize_row_iq1_bn, + .from_float_ref = (ggml_from_float_t)quantize_row_iq1_bn_ref, + .vec_dot = ggml_vec_dot_iq1_bn_q8_K64, + .vec_dot_type = GGML_TYPE_Q8_K64, + .nrows = 1, + }, + [GGML_TYPE_IQ2_BN] = { + .type_name = "iq2_bn", + .blck_size = QK_IQ1BN, + .type_size = sizeof(block_iq2_bn), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_bn, + .from_float = quantize_row_iq2_bn, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_bn_ref, + .vec_dot = ggml_vec_dot_iq2_bn_q8_K64, + .vec_dot_type = GGML_TYPE_Q8_K64, + .nrows = 1, + }, + [GGML_TYPE_IQ2_TN] = { + .type_name = "iq2_tn", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_tn), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_tn, + .from_float = quantize_row_iq2_tn, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_tn_ref, + .vec_dot = vec_dot_iq2_tn_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", .blck_size = QK4_NL, @@ -1002,6 +1048,13 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .from_float = quantize_row_q8_K, }, + [GGML_TYPE_Q8_K64] = { + .type_name = "q8_K64", + .blck_size = 64, + .type_size = sizeof(block_q8_K64), + 
.is_quantized = true, + .from_float = quantize_row_q8_K64, + }, [GGML_TYPE_BF16] = { .type_name = "bf16", .blck_size = 1, @@ -1086,6 +1139,66 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ2_K] = { + .type_name = "iq2_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_k, + .from_float = quantize_row_iq2_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_k_ref, + .vec_dot = vec_dot_iq2_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ3_K] = { + .type_name = "iq3_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_k, + .from_float = quantize_row_iq3_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_k_ref, + .vec_dot = vec_dot_iq3_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ4_K] = { + .type_name = "iq4_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_k, + .from_float = quantize_row_iq4_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_k_ref, + .vec_dot = vec_dot_iq4_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ5_K] = { + .type_name = "iq5_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq5_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq5_k, + .from_float = quantize_row_iq5_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq5_k_ref, + .vec_dot = vec_dot_iq5_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ6_K] = { + .type_name = "iq6_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq6_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq6_k, + .from_float = quantize_row_iq6_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq6_k_ref, + .vec_dot = vec_dot_iq6_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, }; // For internal test use @@ -2385,7 +2498,7 @@ inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } -inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +//inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? 
x[i] : 0.f); } @@ -2799,6 +2912,13 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) { return vdivq_f32(x, one_plus_exp_neg_x); } +inline static float32x4_t ggml_v_tanh(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f)); + const float32x4_t exp_two_x = ggml_v_expf(two_x); + return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); +} + #elif defined(__AVX512F__) && defined(__AVX512DQ__) // adapted from arm limited optimized routine @@ -2842,6 +2962,12 @@ inline static __m512 ggml_v_silu(__m512 x) { return _mm512_div_ps(x, one_plus_exp_neg_x); } +inline static __m512 ggml_v_tanh(__m512 x) { + const __m512 one = _mm512_set1_ps(1.0f); + const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f))); + return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); +} + #elif defined(__AVX2__) && defined(__FMA__) // adapted from arm limited optimized routine @@ -2897,6 +3023,12 @@ inline static __m256 ggml_v_silu(__m256 x) { return _mm256_div_ps(x, one_plus_exp_neg_x); } +inline static __m256 ggml_v_tanh(__m256 x) { + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f))); + return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one)); +} + #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON // adapted from arm limited optimized routine @@ -2943,6 +3075,12 @@ inline static __m128 ggml_v_silu(__m128 x) { return _mm_div_ps(x, one_plus_exp_neg_x); } +inline static __m128 ggml_v_tanh(__m128 x) { + const __m128 one = _mm_set1_ps(1.0f); + const __m128 exp_two_x = ggml_v_expf(_mm_mul_ps(x, _mm_set1_ps(2.f))); + return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one)); +} + #endif // __ARM_NEON / __AVX2__ / __SSE2__ static void ggml_vec_silu_f32(const int n, float * y, const float * x) { @@ -2978,6 +3116,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +static void ggml_vec_tanh_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, ggml_v_tanh(_mm512_loadu_ps(x + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, ggml_v_tanh(_mm256_loadu_ps(x + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, ggml_v_tanh(_mm_loadu_ps(x + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, ggml_v_tanh(vld1q_f32(x + i))); + } +#endif + for (; i < n; ++i) { + y[i] = tanhf(x[i]); + } +} + static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; @@ -3390,6 +3552,8 @@ static void ggml_barrier(struct ggml_compute_state_shared * shared) { } #if defined(__SSE3__) _mm_pause(); + #elif defined __ARM_NEON + __asm__ __volatile__("isb\n"); #endif } sched_yield(); @@ -3687,8 +3851,16 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; + case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break; + case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break; + case GGML_FTYPE_MOSTLY_IQ2_TN: wtype = GGML_TYPE_IQ2_TN; break; case 
GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; + case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; + case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; + case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; + case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break; + case GGML_FTYPE_MOSTLY_IQ6_K: wtype = GGML_TYPE_IQ6_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; @@ -9960,8 +10132,16 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10338,8 +10518,16 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10466,8 +10654,16 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10575,6 +10771,7 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; + #if defined(GGML_USE_CLBLAST) if (src1->backend == GGML_BACKEND_TYPE_GPU) { // TODO: OpenCL kernel support full broadcast @@ -10586,6 +10783,23 @@ static void ggml_compute_forward_mul_f32( } #endif + if (ggml_nelements(dst->src[1]) == 1 && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst) && + dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + int64_t nelements = ggml_nelements(dst->src[0]); + int64_t n_per_thread = (nelements + nth - 1)/nth; + n_per_thread = MAX(1024, n_per_thread); + int64_t start = n_per_thread*ith; + if (start >= nelements) return; + int64_t end = MIN(nelements, start + n_per_thread); + const float * src = (const float *)dst->src[0]->data + start; + float * res = (float *)dst->data + start; + if (res != src) { + memcpy(res, src, (end - start)*sizeof(float)); + } + ggml_vec_scale_f32(end - start, res, *(const float *)dst->src[1]->data); + return; + } + const int64_t nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -11637,9 +11851,8 @@ static void ggml_compute_forward_tanh_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } + const int ith = params->ith; + const int nth = params->nth; assert(ggml_is_contiguous_1(src0)); assert(ggml_is_contiguous_1(dst)); @@ -11648,7 +11861,7 @@ static void ggml_compute_forward_tanh_f32( const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - for (int i = 0; i < n; i++) { + for 
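/* tanh forward is now multi-threaded: instead of thread 0 processing every row, rows are
   striped round-robin over the nth threads (i = ith, ith+nth, ...), which matches the
   GGML_UNARY_OP_TANH n_tasks change in ggml_get_n_tasks further down in this patch. */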
(int i = ith; i < n; i += nth) { ggml_vec_tanh_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), (float *) ((char *) src0->data + i*(src0->nb[1]))); @@ -12729,9 +12942,43 @@ static void ggml_compute_forward_mul_mat( } #endif #if GGML_USE_LLAMAFILE + +#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE // broadcast factors const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; +#endif + +#if GGML_USE_IQK_MULMAT + if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) { + int counter = 0; + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + if (counter++ % nth == ith) { + if (!iqk_mul_mat(ne01, ne11, ne00, + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), + src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), + (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + 0, 1)) goto IQK_MulMat_Not_Available1; + } + } + } + return; + } + if (dst->type == GGML_TYPE_F32) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!iqk_mul_mat(ne01, ne11, ne00, + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), + src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), + (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + ith, nth)) goto IQK_MulMat_Not_Available1; + return; + } +IQK_MulMat_Not_Available1:; +#endif + +#if GGML_USE_LLAMAFILE const bool src1_cont = ggml_is_contiguous(src1); @@ -12792,9 +13039,27 @@ UseGgmlGemm1:; ggml_barrier(params->shared); + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + +#if GGML_USE_IQK_MULMAT + if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) { + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!iqk_mul_mat(ne01, ne11, ne00, + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), + (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + ith, nth)) goto IQK_MulMat_Not_Available2; + return; + } +IQK_MulMat_Not_Available2:; +#endif + + ggml_barrier(params->shared); + #if GGML_USE_LLAMAFILE if (src1->type != vec_dot_type) { - const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); for (int64_t i13 = 0; i13 < ne13; i13++) @@ -13010,6 +13275,46 @@ static void ggml_compute_forward_mul_mat_id( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows + // +#if GGML_USE_IQK_MULMAT + if (ne13 == 1 && dst->type == GGML_TYPE_F32) { + if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, + src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type), + (float *)dst->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available; + continue; + } +IQK_MulMat_Not_Available:; +#endif + + if (((ggml_n_dims(src0) - 1) == 2) && gemv) { + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + src0_cur_start = (src0_cur_start % matmul_num_cols) ? 
src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; + src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; + if (src0_cur_start >= src0_cur_end) return; + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12)); + + gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); + } + continue; + } if (((ggml_n_dims(src0) - 1) == 2) && gemv) { int64_t src0_cur_start = (ith * ne01) / nth; @@ -13336,8 +13641,16 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13524,8 +13837,16 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13788,8 +14109,16 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -14381,11 +14710,20 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: + case GGML_TYPE_Q8_K64: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -18951,7 +19289,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_SIGMOID: @@ -18967,6 +19304,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_UNARY_OP_TANH: + { + n_tasks = MIN(ggml_nrows(node), n_threads); + } break; default: GGML_ABORT("fatal error"); } @@ -18992,7 +19333,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, 
int n_threads) { //n_tasks = n_threads; n_tasks = 1; } break; - case GGML_OP_SCALE: case GGML_OP_SET: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -19016,6 +19356,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = 1; //TODO } break; + case GGML_OP_SCALE: case GGML_OP_SOFT_MAX: { n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); @@ -21101,8 +21442,16 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_TN: result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ6_K: result = quantize_iq6_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -21687,7 +22036,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ctx->infos[i].ne[3], }; - struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); + int n_dims = ctx->infos[i].n_dims; + if (n_dims == 0 || n_dims > 4) { + n_dims = 4; + for (; n_dims > 1; --n_dims) if (ne[n_dims-1] > 1) break; + } + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, n_dims, ne); ok = ok && cur != NULL; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp new file mode 100644 index 0000000000000..32ddb3ff9e710 --- /dev/null +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -0,0 +1,5929 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp fenc=utf-8 :vi +// +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#if defined IQK_IMPLEMENT +#undef IQK_IMPLEMENT +#endif + +#if defined 
__AVX2__ || defined __ARM_FEATURE_DOTPROD +#define IQK_IMPLEMENT +#endif + +#include <cstring> +#include <type_traits> + +#if defined IQK_IMPLEMENT + +#include "ggml-impl.h" +#include "ggml-quants.h" +#include "iqk_mul_mat.h" + +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +// clang-format off + +// This matrix - vector and matrix - matrix multiplication implementation +// for k-quants, i-quants, and legacy quants, makes prompt processing +// 150-350% faster (depending on quantization type) compared to mainline llama.cpp. +// It is AVX2 and ARM_NEON only for now. +// There are also implementations for fp16/32 x fp16/32 matrix multiplications +// on AVX2 and fp16 x fp16 on ARM_NEON. +// +// Main idea is that unpacking the quants and the block scales to +// be ready for dot products with the corresponding Q8_X quants +// takes time. Hence, if we are performing a QX x Q8_X matrix matrix +// multiplication (as needed for prompt processing), we can get +// a significant speedup by reusing the unpacked QX quants and scales +// for multiplication with several Q8_X columns. +// +// For fp16/fp32 matrix multiplications tiling is used to improve +// performance. + +#include <utility> +#include <array> + +#ifdef _MSC_VER +#define IQK_NOINLINE __declspec(noinline) +#define IQK_ALWAYS_INLINE inline +#else +#define IQK_NOINLINE __attribute__((__noinline__)) +#define IQK_ALWAYS_INLINE __attribute__((__always_inline__)) +#endif + +namespace { + +typedef struct { + int32_t i1; + int32_t i2; +} mmid_row_mapping; + +struct DataInfo { + float * s; + const char * cy; + size_t bs; + size_t by; + int cur_y = 0; + int ne11; + const mmid_row_mapping * row_mapping = nullptr; + size_t bs2 = 0; + + inline const char * src1_row(int iy) const { + if (!row_mapping) return cy + (cur_y + iy)*by; + int i11 = row_mapping[cur_y + iy].i1 % ne11; + int i12 = row_mapping[cur_y + iy].i2; + return cy + (i11 + i12*ne11)*by; + } + + inline void store(int ix, int iy, float result) const { + *(dst_row(iy) + ix) = result; + } + inline float * dst_row(int iy) const { + if (!row_mapping) return s + (cur_y + iy)*bs; + int i12 = row_mapping[cur_y + iy].i2; + int i1 = row_mapping[cur_y + iy].i1; + int i2 = i12; + return s + i1*bs + i2*bs2; + } +}; + +typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); + +struct MulMat { + std::array<mul_mat_t, 8> funcs = {}; + inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { +#ifdef __aarch64__ + constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) +#else + constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) +#endif + int ny = funcs.size(); + while (!funcs[ny-1] && ny > 0) --ny; + int n_step = (nrc_y - info.cur_y)/ny; + if (n_step > 0) { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); + this_info.cur_y += ny; + } + } + info.cur_y += ny * n_step; + } + int n_left = nrc_y - info.cur_y; + if (n_left > 0) { + funcs[n_left-1](n, vx, bx, info, nrc_x); + } + } + static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); +private: + template static void set_functions(MulMat& m); +}; + +} + +bool iqk_mul_mat(long Nx, long Ny, long ne00, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long stride_C, int ith, int nth) { + + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + + auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); + auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); + + auto nrc_x = (Nx + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + + DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; + + mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); + + return true; +} + +bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { + const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; + assert(row_mapping != nullptr); + + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); + auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); + int nrc_x = (Nx + nth - 1)/nth; + int first_x = ith*nrc_x; + if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), + row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; + mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); + return true; +} + +namespace { + +inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { + const uint16_t * scales = (const uint16_t *)scales8; + const uint32_t a0 = scales[0] | (scales[1] << 16); + const uint32_t a1 = scales[2] | (scales[3] << 16); + const uint32_t a2 = scales[4] | (scales[5] << 16); + aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030); + aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030); + aux32[2] = a1 & 0x3f3f3f3f; + aux32[0] = a0 & 0x3f3f3f3f; +} + +const uint64_t keven_signs[128] = { + 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, + 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, + 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, + 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff, + 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, + 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff, + 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, + 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff, + 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, + 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff, + 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, + 
0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff, + 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, + 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff, + 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, + 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff, + 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, + 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff, + 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, + 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff, + 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, + 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff, + 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, + 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff, + 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, + 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff, + 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, + 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff, + 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, + 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff, + 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, + 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff, +}; + +} + +#if defined __x86_64__ + +#if defined HAVE_FANCY_SIMD + #undef HAVE_FANCY_SIMD +#endif +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + #define HAVE_FANCY_SIMD +#endif + +namespace { + +inline float hsum_float_4(__m128 x) { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +inline float hsum_float_8(__m256 x) { + return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); +} +inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +template struct Q8 { + + constexpr static int nrc_y = nrc; + + Q8(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); + } + +#ifdef HAVE_FANCY_SIMD + inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); } +#endif + inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); } + inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); } + inline float scale(int iy, int i) const { return y[iy][i].d; } + + const block_q8 * y[nrc_y]; +}; + +struct Scales8KBase { + template + inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * 
accd) const { + const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0])); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i q8s = q8.load_bsums(iy, i); + const __m256i prod = _mm256_madd_epi16(mins, q8s); + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); + } + } + inline __m256i shuffle(__m128i mins) const { + return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0])); + } + const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100), + _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)}; +}; + +// Handles q4_K and q5_K scales/mins +struct Scales8K { + template + inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { + make_q4_scales(data, utmp); + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1); + accum_mins(mins128, q8, i, c, accd); + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + return MM256_SET_M128I(sc128, sc128); + } +#ifdef HAVE_FANCY_SIMD + template + inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { + auto scales = process_mins_and_scales(data, c, i, q8, accd); + return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1); + } +#endif + template + inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { + base.accum_mins(mins128, q8, i, c, accd); + } +#ifdef HAVE_FANCY_SIMD + const __m512i shuffles512[2] = { + _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302, + 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100), + _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, + 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) + }; +#endif + Scales8KBase base; + + uint32_t utmp[4]; +}; + +template +inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } +} +inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) { + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + scales[0] = MM256_SET_M128I(l_scales, l_scales); + scales[1] = MM256_SET_M128I(h_scales, h_scales); +} + +struct ScaleQ3 { + inline __m128i make_scales(const uint16_t * s8) const { + const uint16_t * scales16 = (const uint16_t *)s8; + uint32_t aux0 = scales16[0] | (scales16[1] << 16); + uint32_t aux1 = scales16[2] | (scales16[3] << 16); + uint32_t aux2 = scales16[4] | (scales16[5] << 16); + __m128i scales128 = _mm_set_epi32( + ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030), + ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030), + (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030), + (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030)); + return _mm_add_epi8(scales128, m32); + } + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct ScaleIQ4XS { + inline __m128i make_scales(const uint32_t 
scales_l, const uint16_t scales_h) { + uint32_t tmp32 = scales_h | (scales_h << 14); + const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4); + const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask); + return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32); + } + const __m128i hshift = _mm_set_epi32(12, 8, 4, 0); + const __m128i lshift = _mm_set_epi32(4, 0, 4, 0); + const __m128i hmask = _mm_set1_epi16(0x03); + const __m128i lmask = _mm_set1_epi8(0xf); + const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400); + const __m128i m32 = _mm_set1_epi16(-32); +}; + +template +struct BaseDequantizer { + BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {} + inline void new_row(int ix) { + x = (const Block *)((const char *)vx + bx*ix); + } + + const void * vx; + const size_t bx; + const Block * x; + + float d; +}; + +inline __m256i get_scale_shuffle_8(int i) { + return _mm256_set1_epi16((2*i) | ((2*i+1) << 8)); +} + +inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) { + scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0)); + scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1)); + scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2)); + scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3)); +} + +inline __m256i get_scale_shuffle_16(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} + +inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { + scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0)); + scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1)); + scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2)); + scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3)); +} + +template +inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { + if (j == 0) { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); + } +#else + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); + const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); + const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); + const __m256i p4 = 
_mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); + sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); + } +#endif + } else { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); + } +#else + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); + const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); + const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); + const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); + } +#endif + } +} + +struct SignHelper { + inline __m256i make_signs(uint32_t sign_bits) const { + auto aux256 = _mm256_set1_epi32(sign_bits); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2); + return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone); + } +// inline __m256i make_signs(const uint16_t * sign_bits) const { +//#ifdef HAVE_FANCY_SIMD +//#else +// return make_signs(sign_bits[0] | (sign_bits[1] << 16)); +//#endif +// } + inline __m256i sign_value(const uint16_t * sign_bits, const __m256i& value) const { +#ifdef HAVE_FANCY_SIMD + const __mmask32 * mask = (const __mmask32 *)sign_bits; + return _mm256_mask_sub_epi8(value, mask[0], _mm256_setzero_si256(), value); +#else + return _mm256_sign_epi8(value, make_signs(sign_bits[0] | (sign_bits[1] << 16))); +#endif + } + inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const { +#ifdef HAVE_FANCY_SIMD + const __mmask32 * mask = (const __mmask32 *)sign_bits; + values[0] = _mm256_mask_sub_epi8(values[0], mask[0], _mm256_setzero_si256(), values[0]); + values[1] = _mm256_mask_sub_epi8(values[1], mask[1], _mm256_setzero_si256(), values[1]); + values[2] = _mm256_mask_sub_epi8(values[2], mask[2], _mm256_setzero_si256(), values[2]); + values[3] = _mm256_mask_sub_epi8(values[3], mask[3], _mm256_setzero_si256(), values[3]); +#else + auto s128 = _mm_loadu_si128((const __m128i *)sign_bits); + auto s256 = MM256_SET_M128I(s128, s128); + __m256i aux256; + auto shuffle = mask1; + auto step = _mm256_set1_epi8(4); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[0] = _mm256_sign_epi8(values[0], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[1] = _mm256_sign_epi8(values[1], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[2] = _mm256_sign_epi8(values[2], 
_mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[3] = _mm256_sign_epi8(values[3], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); +#endif + } + const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull); + const __m256i mone = _mm256_set1_epi8(1); +}; + +struct SimpleBits { + __m256i values[4]; +}; + +#ifdef HAVE_FANCY_SIMD +//====================================== Zen4 ================================================== + +struct BlockPermuter { + const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); +}; + +struct Q4Bits { + inline void prepare(const uint8_t * q4) { + auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0); + auto tmp1 = _mm512_and_si512(q4bits, ml); + auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); + values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2); + values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2); + q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1); + tmp1 = _mm512_and_si512(q4bits, ml); + tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); + values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2); + values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2); + } + inline void prepare64(const uint8_t * q4) { + auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0); + values[0] = _mm512_and_si512(q4bits, ml); + values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); + q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1); + values[2] = _mm512_and_si512(q4bits, ml); + values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); + } + __m512i values[4]; + const __m512i ml = _mm512_set1_epi8(0xf); + BlockPermuter perm; +}; + +struct Q2Bits { + inline void prepare(const uint8_t * q2) { + + auto q2bits = _mm512_loadu_si512((const __m512i*)q2); + auto tmp = _mm512_srli_epi16(q2bits, 2); + + values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp); + values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp); + values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml); + values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml); + values[0] = _mm512_and_si512(values[0], ml); + values[2] = _mm512_and_si512(values[2], ml); + } + __m512i values[4]; + const __m512i ml = _mm512_set1_epi8(0x03); + BlockPermuter perm; +}; + +struct DequantizerQ4K final : public BaseDequantizer { + DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare(x[i].qs); + auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); + scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + } + + Q4Bits bits; + Scales8K s8k; +}; + +__m512i load_iq4nl_values_512() { + static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); +} + + 
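The 512-bit constant built by load_iq4nl_values_512() holds four copies of the 16-entry IQ4_NL codebook, stored with a +128 offset so that the bytes produced by the table lookup stay unsigned for the maddubs/dpbusd dot products; the offset is compensated through the -128.f*d term that DequantizerIQ4XS below passes to accum_mins(). As a point of reference, this is the scalar equivalent of the _mm512_shuffle_epi8(values, bits) lookup the dequantizers perform (an illustrative sketch only, not part of the patch; the function and table names are made up, and the real code additionally reorders nibbles with the permute1/permute2 masks):

static inline uint8_t iq4nl_lookup_ref(uint8_t nibble) {
    // same 16 values as load_iq4nl_values_512(), i.e. the signed IQ4_NL table shifted by +128
    static const uint8_t kvalues_iq4nl_biased[16] = {
        1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241
    };
    return kvalues_iq4nl_biased[nibble & 0xf];
}

A dequantized weight is then d * scale * (looked-up value - 128), with the -128 never applied per element but folded into the block-sum correction.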
+struct DequantizerIQ4XS final : public BaseDequantizer { + DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h); + s8k.accum_mins(scales128, q8, i, -128.f*d, accd); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + } + inline void prepare(const uint8_t * q4) { + bits.prepare64(q4); + // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 + // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + } + + Q4Bits bits; + Scales8K s8k; + ScaleIQ4XS siq4; + const __m512i values; + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); +}; + +struct HighBit5 { + inline void apply(const uint8_t * h, Q4Bits& bits) { + auto hbits256 = _mm256_loadu_si256((const __m256i *)h); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1); + bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh)); + bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh)); + bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh)); + bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh)); + } + const __m512i mh = _mm512_set1_epi8(0x10); +}; + +struct HighBit3 { + inline void apply(const uint8_t * h, Q2Bits& bits) { + auto hbits256 = _mm256_loadu_si256((const __m256i *)h); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1); + bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh)); + bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh)); + bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh)); + bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh)); + } + const __m512i mh = _mm512_set1_epi8(0x04); +}; + +struct DequantizerQ5K final : public BaseDequantizer { + DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare(x[i].qs); + hbits.apply(x[i].qh, bits); + auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); + scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); + scales[1] = 
_mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + } + + Q4Bits bits; + HighBit5 hbits; + Scales8K s8k; +}; + +struct Scale16 { + inline void make_scales(const __m128i& scales8, __m512i * scales) const { + auto all_scales8 = MM256_SET_M128I(scales8, scales8); + auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1); + auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2); + scales[0] = _mm512_cvtepi8_epi16(scales1); + scales[1] = _mm512_cvtepi8_epi16(scales2); + } + template + inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8, + const Q8& q8, __m256 * accm, __m512i * scales) const { + process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm); + make_scales(scales8, scales); + } + const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202, + 0x05050505, 0x01010101, 0x04040404, 0x00000000); + const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a, + 0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808); +}; + +struct DequantizerQ2K final : public BaseDequantizer { + DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare(x[i].qs); + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales); + } + + Q2Bits bits; + Scale16 sc16; + const __m128i m4 = _mm_set1_epi8(0xf); + +}; + +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare(x[i].qs); + } + Q2Bits bits; +}; + +struct DequantizerQ3K final : public BaseDequantizer { + DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare(x[i].qs); + hbits.apply(x[i].hmask, bits); + auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales); + sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales); + } + + Q2Bits bits; + HighBit3 hbits; + ScaleQ3 sc3; + Scale16 sc16; + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerQ6K final : public BaseDequantizer { + DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare64(x[i].ql); + add_high_bits(x[i].qh, bits); + auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales); + sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales); + } + + inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const { + auto hbits = _mm512_loadu_si512((const __m512i *)qh); + auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh); + auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh); + bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2)); + 
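+        // Each byte of qh carries the two high bits of four of the 6-bit quants; the shifts combined
+        // with the 0x30 mask (mh) move every such pair into bit positions 4..5 before it is OR-ed into
+        // the 4-bit values already loaded from ql, so bits.values[] ends up holding unsigned quants in
+        // 0..63. The usual Q6_K offset of -32 is not applied here; it is folded into the block-sum
+        // correction via the -32.f*d that new_block() passes to process_mins_and_scales().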
bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2)); + tmp1 = _mm512_and_si512(hbits, mh); + tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh); + bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2)); + bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2)); + } + + Q4Bits bits; + HighBit3 hbits; + Scale16 sc16; + + const __m512i mh = _mm512_set1_epi8(0x30); + +}; + +struct IQXKScales { + IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {} + template + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)); + scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift)); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + scales16 = MM256_SET_M128I(scales8, scales8); + scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1)); + scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2)); + } + const __m256i eshift; + const __m256i min; + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000); + const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404); +}; +struct IQXKScales2 { + IQXKScales2(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {} + template + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { + process(i, d, extra, _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), q8, accm, scales); + } + template + inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m512i * scales) const { + auto scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift)); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16)); + auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1)); + auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1); + auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1); + scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]); + scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]); + } + const __m256i eshift; + const __m256i min; + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 
0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + const __m512i shuffles[2] = { + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(), + _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3), + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(), + _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3) + }; +}; + +struct DequantizerIQ2K final : public BaseDequantizer { + DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales); + } + inline void prepare(const uint8_t * q2) { + bits.prepare(q2); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + return _mm_add_epi8(_mm_slli_epi16(scl, 1), m15); + } + Q2Bits bits; + const IQXKScales iqxk; + + const __m512i values; + const __m128i m15 = _mm_set1_epi8(-15); +}; + +struct DequantizerIQ3K final : public BaseDequantizer { + DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs, x[i].qh); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales); + } + inline void prepare(const uint8_t * q2, const uint8_t * qh) { + bits.prepare(q2); + auto h256 = _mm256_loadu_si256((const __m256i *)qh); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1); + bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), hmask)); + bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, hmask)); + bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), hmask)); + bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), hmask)); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 
28, 45, 58, 69, 81, 96, 115}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1); + const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask); + const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff); + return _mm_sign_epi8(scl, sch); + } + Q2Bits bits; + const IQXKScales2 iqxk; + + const __m512i values; + const __m512i hmask = _mm512_set1_epi8(4); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101); + const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff); + constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; +}; + +struct DequantizerIQ4K final : public BaseDequantizer { + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_512()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + } + inline void prepare(const uint8_t * q4) { + bits.prepare64(q4); + // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 + // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + + Q4Bits bits; + const IQXKScales2 iqxk; + const __m512i values; + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerIQ5K final : public BaseDequantizer { + DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); } + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs, x[i].qh); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + } + inline void prepare(const 
uint8_t * q4, const uint8_t * qh) { + bits.prepare64(q4); + auto h256 = _mm256_loadu_si256((const __m256i *)qh); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1); + auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); + auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); + bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]); + bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]); + hbits = _mm512_srli_epi16(hbits, 4); + m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); + m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); + bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]); + bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]); + // We now have in bits.valuse[0]: 0...31, 64...95 + // bits.valuse[1]: 32..63, 96..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]); + bits.values[0] = tmp; + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]); + bits.values[2] = tmp; + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + static void load_values(__m512i * values) { + static const uint8_t kvalues_iq5nl[32] = { + 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127, + 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249, + }; + auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1); + auto values256_1 = MM256_SET_M128I(values128_1, values128_1); + auto values256_2 = MM256_SET_M128I(values128_2, values128_2); + values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1); + values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1); + } + + Q4Bits bits; + const IQXKScales2 iqxk; + __m512i values[2]; + const __m512i hmask1 = _mm512_set1_epi8(1); + const __m512i hmask2 = _mm512_set1_epi8(2); + const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerIQ6K final : public BaseDequantizer { + DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); } + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs, x[i].qh); + 
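+        // The 16 per-sub-block scales of IQ6_K are loaded below as int8 and widened to 16 bit;
+        // iqxk.process() then folds the per-sub-block minimum (-128, or -127 for sub-blocks flagged
+        // in x[i].extra) into the q8 block-sum correction and broadcasts the scales into the four
+        // __m512i registers consumed by the dot-product loop in mul_mat_iqX_k_q8_K_AVX512().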
auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales); + iqxk.process(i, d, x[i].extra, _mm256_cvtepi8_epi16(scales8), q8, accm, scales); + } + inline __m512i make_one(__m512i l, __m512i h) const { + auto p = _mm512_shuffle_epi8(values[0], l); + p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[0]), masks[0]), values[1], l); + p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[1]), masks[1]), values[2], l); + p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[2]), masks[2]), values[3], l); + return p; + } + inline void prepare(const uint8_t * q4, const uint8_t * qh) { + bits.prepare64(q4); + auto h256_1 = _mm256_loadu_si256((const __m256i *)qh + 0); + auto h256_2 = _mm256_loadu_si256((const __m256i *)qh + 1); + auto h1 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_1), _mm256_srli_epi16(h256_1, 4), 1); + auto h2 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_2), _mm256_srli_epi16(h256_2, 4), 1); + bits.values[0] = make_one(bits.values[0], h1); + bits.values[1] = make_one(bits.values[1], _mm512_srli_epi16(h1, 2)); + bits.values[2] = make_one(bits.values[2], h2); + bits.values[3] = make_one(bits.values[3], _mm512_srli_epi16(h2, 2)); + // We now have in bits.valuse[0]: 0...31, 64...95 + // bits.valuse[1]: 32..63, 96..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]); + bits.values[0] = tmp; + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]); + bits.values[2] = tmp; + } + static void load_values(__m512i * values) { + static const uint8_t kvalues_iq6nl[64] = { + 1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74, + 77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123, + 126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172, + 175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249, + }; + for (int k = 0; k < 4; ++k) { + auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k); + auto values256 = MM256_SET_M128I(values128, values128); + values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(values256), values256, 1); + } + } + + Q4Bits bits; + IQXKScales2 iqxk; + __m512i values[4]; + __m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) }; + const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); +}; + +template +inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) { + const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); + sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); +} + +template +inline void 
compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) { + auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); + auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32( + _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0), + values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)), + values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); +} + +template +static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx); + + __m256 accm[nrc_y]; + __m512 accd[nrc_y]; + __m512i scales[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); + for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.new_block(i, q8, accm, scales); + + for (int iy = 0; iy < nrc_y; ++iy) { + if constexpr (std::is_same_v) { + auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); + auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32( + _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0), + deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)), + deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } else { + const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); + sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + } + + } +} + +template +static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx); + + __m256 accm[nrc_y]; + __m512 accd[nrc_y]; + __m512i scales[4]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); + for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.new_block(i, q8, accm, scales); + + for (int iy = 0; iy < nrc_y; ++iy) { + const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, 
i, 1)); + const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(), + p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + } + + } +} + +template +static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + constexpr int k_nx = 2; + + Q8<1> q8(info); + + Dequantizer deq1(vx, bx); + Dequantizer deq2(vx, bx); + + Dequantizer * deq[k_nx]; + deq[0] = &deq1; + deq[1] = &deq2; + + __m512i scales[2*k_nx]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + auto accd = _mm512_setzero_ps(); + auto accm = _mm256_setzero_ps(); + + for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix); + + for (int i = 0; i < nb/k_nx; ++i) { + + for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx); + + if constexpr (std::is_same_v) { + for (int kx = 0; kx < k_nx; ++kx) { + compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd); + } + } else { + for (int kx = 0; kx < k_nx; ++kx) { + compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); + } + } + + } + if (2*(nb/2) < nb) { + int i0 = 2*(nb/2); + deq[0]->new_block(i0, q8, &accm, scales); + if constexpr (std::is_same_v) { + compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd); + } else { + compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); + } + } + + if constexpr (std::is_same_v) { + info.store(ix, 0, _mm512_reduce_add_ps(accd)); + } else { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); + info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); + } + } +} + +#else +// ===================================== Vanilla AVX2 ===================================== + +struct Q4Bits { + inline void prepare(const uint8_t * q4, int j) { + auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); + values[0] = _mm256_and_si256(q4bits, ml); + values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); + q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); + values[2] = _mm256_and_si256(q4bits, ml); + values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); + } + inline void prepare64(const uint8_t * q4, int j) { + auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); + values[0] = _mm256_and_si256(q4bits, ml); + values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); + q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); + values[1] = _mm256_and_si256(q4bits, ml); + values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); + } + inline void prepare16(const uint8_t * q4, int j) { + values[0] = dequant16(q4 + 64*j + 0); + values[1] = dequant16(q4 + 64*j + 16); + values[2] = dequant16(q4 + 64*j + 32); + values[3] = dequant16(q4 + 64*j + 48); + } + inline __m256i dequant16(const uint8_t * qs) const { + const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); + const __m256i aux256 = 
MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128); + return _mm256_and_si256(ml, aux256); + } + __m256i values[4]; + const __m256i ml = _mm256_set1_epi8(0xf); +}; + +struct Q2Bits { + inline void prepare(const uint8_t * q2, int j) { + auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j); + values[0] = _mm256_and_si256(q2bits, ml); + values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); + values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); + values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); + } + __m256i values[4]; + const __m256i ml = _mm256_set1_epi8(0x03); +}; + +struct HighBit5 { + inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); } + inline void apply(Q4Bits& bits, bool do_shift) { + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + if (do_shift) { + hbits = _mm256_srli_epi16(hbits, 4); + } + } + const __m256i mh = _mm256_set1_epi8(0x10); + __m256i hbits; +}; + +struct HighBit3 { + inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); } + inline void apply(Q2Bits& bits, bool do_shift) { + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); + if (do_shift) { + hbits = _mm256_srli_epi16(hbits, 4); + } + } + const __m256i mh = _mm256_set1_epi8(0x04); + __m256i hbits; +}; + +struct DequantizerQ4K final : public BaseDequantizer { + DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + d = GGML_FP16_TO_FP32(x[i].d); + return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + } + + Q4Bits bits; + Scales8K s8k; +}; + +__m256i load_iq4nl_values() { + static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); + return MM256_SET_M128I(val128, val128); +} + +struct DequantizerIQ4XS final : public BaseDequantizer { + DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {} + template + inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h); + s8k.accum_mins(scales128, q8, i, -128.f*d, accd); + return MM256_SET_M128I(scales128, scales128); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + + Q4Bits bits; + Scales8K s8k; + ScaleIQ4XS siq4; + 
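+    // 16-entry IQ4NL codebook, broadcast to both 128-bit lanes by load_iq4nl_values();
+    // prepare() maps the 4-bit indices through it with _mm256_shuffle_epi8.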
const __m256i values; +}; + +struct IQXKScales { + IQXKScales(int8_t shift, int8_t min_val) : min(_mm256_set1_epi16(min_val)), eshift(_mm_set1_epi8(shift)) {} + template + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m256i * scales) const { + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)); + process(i, d, extra, scales16, q8, accm, scales); + //auto extra128 = _mm_set1_epi16(extra); + //extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + //extra128 = _mm_and_si128(extra128, eshift); + //extra128 = _mm_shuffle_epi8(extra128, eshuffle); + //auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(min, _mm256_cvtepi8_epi16(extra128))); + //for (int iy = 0; iy < Q8::nrc_y; ++iy) { + // const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i)); + // accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + //} + //prepare_scales_16(scales16, scales); + } + template + inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m256i * scales) const { + auto extra128 = _mm_set1_epi16(extra); + extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + extra128 = _mm_and_si128(extra128, eshift); + extra128 = _mm_shuffle_epi8(extra128, eshuffle); + auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(min, _mm256_cvtepi8_epi16(extra128))); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + prepare_scales_16(scales16, scales); + } + + const __m256i min; + const __m128i eshift; + const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); +}; + +struct DequantizerIQ2K final : public BaseDequantizer { + DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(5, -32), values(load_values()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + static inline __m256i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + return MM256_SET_M128I(val128, val128); + } + inline __m128i make_scales(const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + return _mm_add_epi8(_mm_slli_epi16(scl, 1), m15); + } + + Q2Bits bits; + const IQXKScales iqxk; + const __m256i values; + const __m128i m15 = _mm_set1_epi8(-15); + const __m128i maskl = _mm_set1_epi8(0xf); +}; + +struct DequantizerIQ3K final : public BaseDequantizer { + DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), 
values(load_values()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales); + hbits = _mm256_loadu_si256((const __m256i *)x[i].qh); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto h256 = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4); + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(h256, 2), hmask)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(h256, 1), hmask)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(h256, hmask)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(h256, 1), hmask)); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + static inline __m256i load_values() { + static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl); + return MM256_SET_M128I(val128, val128); + } + inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1); + const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask); + const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff); + return _mm_sign_epi8(scl, sch); + } + + Q2Bits bits; + const IQXKScales iqxk; + const __m256i values; + __m256i hbits; + const __m256i hmask = _mm256_set1_epi8(4); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101); + const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff); + constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; +}; + +struct DequantizerIQ4K final : public BaseDequantizer { + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + + Q4Bits bits; + const IQXKScales iqxk; + const __m256i values; + const __m128i 
maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerIQ5K final : public BaseDequantizer { + DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); } + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + hbits = _mm256_loadu_si256((const __m256i *)x[i].qh); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4); + for (int k = 0; k < 4; ++k) { + auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh); + auto q5vl = _mm256_or_si256(bits.values[k], qh); + auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh)); + bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + } + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + static void load_values(__m256i * values) { + static const uint8_t kvalues_iq5nl[32] = { + 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127, + 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249, + }; + auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1); + values[0] = MM256_SET_M128I(values128_1, values128_1); + values[1] = MM256_SET_M128I(values128_2, values128_2); + } + + Q4Bits bits; + const IQXKScales iqxk; + __m256i hbits; + __m256i values[2]; + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); + const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing +}; + +struct DequantizerIQ6K final : public BaseDequantizer { + DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); } + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales); + auto scales16 = _mm256_cvtepi8_epi16(scales8); + iqxk.process(i, d, x[i].extra, scales16, q8, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j); + for (int k = 0; k < 4; ++k) { + bits.values[k] = make_one(bits.values[k], hbits); + hbits = _mm256_srli_epi16(hbits, 2); + } + } + inline __m256i make_one(__m256i l, __m256i hbits) const { + auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3); + auto h1 = _mm256_andnot_si256(mask4, hbits); + auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1); + auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2); + auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff)); + return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, 
_mm256_shuffle_epi8(values[0], l)), + _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))), + _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)), + _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l)))); + } + static void load_values(__m256i * values) { + static const uint8_t kvalues_iq6nl[64] = { + 1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74, + 77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123, + 126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172, + 175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249, + }; + for (int k = 0; k < 4; ++k) { + auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k); + values[k] = MM256_SET_M128I(values128, values128); + } + } + + Q4Bits bits; + const IQXKScales iqxk; + __m256i values[4]; + const __m256i mh1 = _mm256_set1_epi8(1); + const __m256i mh2 = _mm256_set1_epi8(2); + const __m256i mh3 = _mm256_set1_epi8(3); + const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing +}; + +struct DequantizerQ5K final : public BaseDequantizer { + DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + d = GGML_FP16_TO_FP32(x[i].d); + hbits.load(x[i].qh); + return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + hbits.apply(bits, j == 0); + } + + Q4Bits bits; + HighBit5 hbits; + Scales8K s8k; +}; + +template +inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d, + __m256 * accm, __m256i * scales) { + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + process_mins_16(all_scales, q8, i, d, accm); + prepare_scales_16(all_scales, scales); +} + +struct DequantizerQ3K final : public BaseDequantizer { + DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + hbits.load(x[i].hmask); + process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + hbits.apply(bits, j == 0); + } + + Q2Bits bits; + HighBit3 hbits; + ScaleQ3 sc3; + + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerQ2K final : public BaseDequantizer { + DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm); + prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + } + + Q2Bits bits; + + const __m128i m4 = _mm_set1_epi8(0xf); +}; + +struct DequantizerQ6K final : public BaseDequantizer { + DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + 
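+        // Q6_K stores its quants with a +32 offset (raw values are 0..63, the real value is q-32).
+        // Passing -32.f*d here lets process_mins_and_scales_16() fold that offset into the q8
+        // block-sum correction instead of subtracting 32 from every quant.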
process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare64(x[i].ql, j); + auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j); + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh)); + } + + Q4Bits bits; + const __m256i mh = _mm256_set1_epi8(0x30); +}; + +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + inline void new_block(int i) { + d = GGML_FP16_TO_FP32(x[i].d); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + } + + Q2Bits bits; +}; + + +template +IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Q8 q8(info); + DequantizerIQ2TN deq(vx, bx); + + __m256 accd[nrc_y]; + const auto m1 = _mm256_set1_epi16(1); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + __m256i sumi[nrc_y]; + deq.new_block(i); + + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) { + sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), + _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))); + sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)), + _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 3))), sumi[iy]); + } + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) { + sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 4)), + _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 5))), sumi[iy]); + sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 6)), + _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]); + sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i)); + } + if (i > 0) { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy]))); + } + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } +} + +template +static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Q8 q8(info); + + __m256i all_scales[2]; + __m256i scales[4]; + __m256 accd[nrc_y]; + + Dequantizer deq(vx, bx); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + deq.new_block(i, q8, accd, all_scales); + + __m256i sumi[nrc_y]; + + for (int j = 0; j < QK_K/128; ++j) { + deq.prepare(i, j); + set_scales_16(all_scales[j], scales); + multiply_add(deq.bits, scales, j, i, q8, sumi); + } + + for (int iy = 0; iy < 
nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } + +} + +template +static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx); + + __m256 accd[nrc_y]; + __m256i scales[4]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + auto all_scales = deq.new_block(i, q8, accd); + + __m256i sumi[nrc_y]; + + for (int j = 0; j < QK_K/128; ++j) { + + deq.prepare(i, j); + + set_scales_8(all_scales, j, scales); + + multiply_add(deq.bits, scales, j, i, q8, sumi); + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); + accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } +} + +#endif // Zen4 or vanilla AVX2 + +template +inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { + if (j == 0) { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); + auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); + auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); + auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]); + sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2)); + sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4)); +#else + const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0])); + const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1])); + const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2])); + const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3])); + sumi[0] = _mm256_add_epi32(p1, p3); + sumi[1] = _mm256_add_epi32(p2, p4); +#endif + } else { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); + auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); + auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); + auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]); + sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2)); + sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4)); +#else + const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0])); + const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1])); + const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2])); + const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3])); + sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3)); + sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4)); +#endif + } +} + +inline void set_scales_8_iq(int j, const __m256i& all_scales, 
__m256i * scales) { +#ifdef HAVE_FANCY_SIMD + auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100) + : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908); + scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); + scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4))); +#else + set_scales_8(all_scales, j, scales); +#endif +} + +inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) { +#ifdef HAVE_FANCY_SIMD + auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100); + scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); + scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8))); +#else + set_scales_16(all_scales, scales); +#endif +} + +template +static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_K; + Q8<1> q8(info); + Dequantizer deq(vx, bx); + __m256i scales[2]; + __m256i q8_quants[4]; + for (int ix = 0; ix < nrc_x; ++ix) { + + __m256 accd = _mm256_setzero_ps(); + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + __m256i sumi[2], all_scales[Dequantizer::num_blocks/8]; + deq.new_block(i, all_scales); + + for (int j = 0; j < QK_K/128; ++j) { + deq.prepare(i, j, q8, q8_quants); + if constexpr (Dequantizer::num_blocks == 8) { + set_scales_8_iq(j, all_scales[0], scales); + } else { + set_scales_16_iq(all_scales[j], scales); + } + multiply_add_1(j, deq.bits, scales, q8_quants, sumi); + } + accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd); + } + + info.store(ix, 0, hsum_float_8(accd)); + } +} + +// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below, +// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__). +// But TG performance for iq3_xs drops by 35%. Seriously? I mean, c'mon, +// what does the compilation of mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG) +// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)? 
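+// For orientation only: a hypothetical scalar reference (not part of the build) of the quantity the
+// fused multiply-add kernels in this file accumulate for one super-block, assuming the 16 x 16
+// sub-block layout used by the set_scales_16() path; the SIMD code additionally folds the
+// min/offset terms back in via the q8 block sums.
+//
+//   static float block_dot_reference(float d, float d8, const int8_t * scales, const uint8_t * q, const int8_t * y) {
+//       int sum = 0;
+//       for (int j = 0; j < 16; ++j) {           // 16 sub-blocks per 256-weight super-block
+//           int s = 0;
+//           for (int k = 0; k < 16; ++k) s += q[16*j + k] * y[16*j + k];
+//           sum += scales[j] * s;                // per-sub-block scale
+//       }
+//       return d * d8 * sum;                     // row block scale times the q8_K block scale
+//   }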
+//template +//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { +//#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +// for (int iy = 0; iy < Q8::nrc_y; ++iy) { +// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0))); +// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1))); +// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2))); +// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3))); +// } +//#else +// for (int iy = 0; iy < Q8::nrc_y; ++iy) { +// const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0))); +// const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1))); +// const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2))); +// const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3))); +// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); +// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); +// } +//#endif +//} + +template +static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_K; + Q8 q8(info); + Dequantizer deq(vx, bx); + __m256i scales[4]; + __m256 accd[nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8]; + //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256(); + __m256i mins; + float dmin = deq.new_block(i, all_scales, mins); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, bsums); + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); + } + + for (int j = 0; j < QK_K/128; ++j) { + deq.prepare(i, j); + if constexpr (Dequantizer::num_blocks == 8) { + set_scales_8(all_scales[0], j, scales); + } else { + set_scales_16(all_scales[j], scales); + } + //multiply_add_iq(deq.bits, scales, j, i, q8, sumi); + multiply_add(deq.bits, scales, j, i, q8, sumi); + } + for (int iy = 0; iy < nrc_y; ++iy) { + const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); + accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + +template struct Q8_K64 { + + constexpr static int nrc_y = nrc; + + Q8_K64(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) { + const float * dptr = (const float *)info.src1_row(iy); + std::memcpy(d + 4*iy, dptr, 4*sizeof(float)); + y[iy] = (const int8_t *)(dptr + 4); + } + } + + inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); } + inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); } + + float d[4*nrc_y]; + const int8_t * y[nrc_y]; +}; + +struct DequantizerIQ1BN { + const __m256i m1_8 = _mm256_set1_epi8(1); + static __m256i load_shuffle(int i) { + static const uint8_t data[128] = { + 0, 
255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255, + 3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255, + 6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255, + 9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255, + }; + return _mm256_loadu_si256((const __m256i*)data + i); + } + const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) }; + const __m256i mult[4] = { + _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + }; + const __m256i m3 = _mm256_set1_epi16(3); +#ifdef HAVE_FANCY_SIMD + const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); +#endif + + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const { + auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes! + auto data = MM256_SET_M128I(data128, data128); + auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3); + auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3); + auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3); + auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3); +#ifdef HAVE_FANCY_SIMD + v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8); + v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8); +#else + v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); + v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); +#endif + } + +}; + +template +IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_IQ1BN; + Q8_K64 q8(info); + DequantizerIQ1BN deq; + __m256i accd[nrc_y]; + __m256i val[4]; + +#if !(defined __AVX512VNNI__ && defined __AVX512VL__) + const auto m1_16 = _mm256_set1_epi16(1); +#endif + + const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); + + for (int ix = 0; ix < nrc_x; ++ix) { + + x = (const block_iq1_bn *)((const char *)vx + ix*bx); + + if constexpr (nrc_y == 1) { + __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256(); + for (int i = 0; i < nb/2; ++i) { + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]); + auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]); + auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]); + acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, 
deq.m1_8, dot1), deq.m1_8, dot2); + acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4); +#else + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]))); + acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1)); + acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2)); +#endif + } + accd[0] = _mm256_add_epi32(acc1, acc2); + } + else { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256(); + + for (int i = 0; i < nb/2; ++i) { + + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); + + for (int iy = 0; iy < nrc_y; ++iy) { +#if defined __AVX512VNNI__ && defined __AVX512VL__ + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]); + auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]); + auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]); + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32( + accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4); +#else + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]))); + dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2)); + accd[iy] = _mm256_add_epi32(dot1, accd[iy]); +#endif + } + } + } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare_iq1bn_quants(x + i, val[0], val[1]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2); +#else + auto dot = _mm256_madd_epi16(m1_16, + _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2))); + accd[iy] = _mm256_add_epi32(dot, accd[iy]); +#endif + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto vd = q8.scale(iy); + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); + auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi)); + info.store(ix, iy, hsum_float_4(sumf)); + } + + } +} + +struct DequantizeIQ2BN final : public BaseDequantizer { + DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const { + auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs); + auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2); + make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0); + make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2); + } + IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const { + val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8); + val[1] = 
_mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8); + } + IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const { + auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val); + } + const __m256i m1_8 = _mm256_set1_epi8(1); + const __m256i mf_8 = _mm256_set1_epi8(16); + const __m256i mask2 = _mm256_set1_epi8(0x03); + const __m256i mask3 = _mm256_set1_epi8(0x30); +}; + +template +IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_IQ1BN; + Q8_K64 q8(info); + DequantizeIQ2BN deq(vx, bx); + __m256i accd[nrc_y]; + __m256i val[4]; + +#if !(defined __AVX512VNNI__ && defined __AVX512VL__) + const auto m1_16 = _mm256_set1_epi16(1); +#endif + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + if constexpr (nrc_y == 1) { + __m256i acc[2] = {}; + for (int i = 0; i < nb/2; ++i) { + deq.prepare4(i, val); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])), + deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])); + acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])), + deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])); +#else + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]))); + acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1)); + acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2)); +#endif + } + accd[0] = _mm256_add_epi32(acc[0], acc[1]); + } + else { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256(); + + for (int i = 0; i < nb/2; ++i) { + deq.prepare4(i, val); + for (int iy = 0; iy < nrc_y; ++iy) { + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]); + auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]); + auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32( + accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4); +#else + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( + _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)), + _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4)))); + accd[iy] = i > 0 ? 
_mm256_add_epi32(dot, accd[iy]) : dot; +#endif + } + } + } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare2(i, val); + for (int iy = 0; iy < nrc_y; ++iy) { + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2); +#else + dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2))); + accd[iy] = _mm256_add_epi32(dot1, accd[iy]); +#endif + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto vd = q8.scale(iy); + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); + auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi)); + info.store(ix, iy, hsum_float_4(sumf)); + } + } +} + +template +static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + if constexpr (nrc_y == 1) { + mul_mat_qX_K_q8_K_IQ_1(n, vx, bx, info, nrc_x); + } else { + mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); + } +} + +//#ifdef HAVE_FANCY_SIMD +// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster +// compared to the vanilla AVX2 version below. +//struct IndexHelperIQ3S { +// union index_t { +// __m256i vec; +// uint16_t val[16]; +// }; +// inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const { +// auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); +// const __mmask16 * m16 = (const __mmask16 *)qh; +// index_t idx; +// idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset); +// values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]], +// iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]); +// values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]], +// iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]); +// } +// const __m256i offset = _mm256_set1_epi16(256); +//}; +//#else +struct IndexHelperIQ3S { + union index_t { + __m256i vec; + uint32_t val[8]; + }; + inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const { + index_t idx; + auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs)); + auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask); + idx.vec = _mm256_or_si256(idx_h, idx_l); + values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], + iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); + idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8))); + idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask); + idx.vec = _mm256_or_si256(idx_h, idx_l); + values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], + iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); + } + const __m256i idx_mask = _mm256_set1_epi32(256); + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); +}; +//#endif + +struct DequantizerIQ3S final : public BaseDequantizer { + 
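+    // IQ3_S: each qs byte plus one high bit from qh forms a 9-bit index into the iq3s_grid table,
+    // the stored sign bits are then applied, and min_value = 16 is added so the quants are
+    // non-negative for the unsigned*signed madd path; new_block() returns -16*d so the offset is
+    // undone through the q8 block sums.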
DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + constexpr static int num_blocks = 8; + + inline __m128i make_scales(int i, float& dd) const { + dd = GGML_FP16_TO_FP32(x[i].d); + uint32_t aux32[2]; + std::memcpy(aux32, x[i].scales, 4); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400)); + auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8)); + return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1)); + } + inline void new_block(int i, __m256i * scales) { + auto scales16 = make_scales(i, d); + scales[0] = MM256_SET_M128I(scales16, scales16); + } + inline float new_block(int i, __m256i * scales, __m256i& mins) { + auto scales16 = make_scales(i, d); + mins = scb.shuffle(scales16); + scales[0] = MM256_SET_M128I(scales16, scales16); + return -minv*d; + } + + inline void prepare(int i, int j) { + prepare_unsigned(i, j); + sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, bits.values); + for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value); + } + inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { + prepare_unsigned(i, j); + for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); + sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants); + } + + inline void prepare_unsigned(int i, int j) { + auto qs = x[i].qs + 32*j; + auto qh = x[i].qh + 4*j; + helper.make2(qs+ 0, qh+0, bits.values+0); + helper.make2(qs+16, qh+2, bits.values+2); + } + + constexpr static int minv = 16; + + SimpleBits bits; + SignHelper sh; + Scales8KBase scb; + IndexHelperIQ3S helper; + const __m256i min_value = _mm256_set1_epi8(minv); + +}; + +struct EvenSignHelper { +#ifdef HAVE_FANCY_SIMD + union sbits_t { + __m128i vec; + __mmask32 mask[4]; + }; + IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const { + aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask); + auto pcnt = _mm256_popcnt_epi32(aux); + sbits_t sbits; + sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); + values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]); + values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]); + //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); + //const __mmask32 * m32 = (const __mmask32 *)&sign_bits; + //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]); + //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]); + } + const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0); + const __m256i mask = _mm256_set1_epi32(127); + const __m256i mone = _mm256_set1_epi32(1); +#else + inline void sign_value(uint32_t aux32, __m256i& value) const { + auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], + keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]); + value = _mm256_sign_epi8(value, signs); + } +#endif +}; + +struct DequantizerIQ3XXS final : public BaseDequantizer { + DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + constexpr static int num_blocks = 8; + + inline __m128i prepare_scales(int i) { + d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + auto tmp = _mm256_loadu_si256((const __m256i 
*)(x[i].qs + QK_K/4)); + auto scales32 = _mm256_srli_epi32(tmp, 28); + scales32 = _mm256_or_si256(_mm256_slli_epi32(scales32, 1), _mm256_set1_epi32(1)); + return _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1)); + } + + inline void new_block(int i, __m256i * scales) { + auto scales16 = prepare_scales(i); + scales[0] = MM256_SET_M128I(scales16, scales16); + } + inline float new_block(int i, __m256i * scales, __m256i& mins) { + auto scales16 = prepare_scales(i); + mins = scb.shuffle(scales16); + scales[0] = MM256_SET_M128I(scales16, scales16); + return -d*minv; + } + + inline static __m256i make_quants(const uint8_t * qs) { + return _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]], + iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]); + } + inline static void make4_unsigned(const uint8_t * qs, __m256i * values) { + values[0] = make_quants(qs+ 0); + values[1] = make_quants(qs+ 8); + values[2] = make_quants(qs+16); + values[3] = make_quants(qs+24); + } + + IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * signs, __m256i * values) const { +#ifdef HAVE_FANCY_SIMD + esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(signs[2] | (signs[3] << 16)), _mm_set1_epi32(signs[0] | (signs[1] << 16))), values); +#else + esh.sign_value(signs[0] | (signs[1] << 16), values[0]); + esh.sign_value(signs[2] | (signs[3] << 16), values[1]); +#endif + } + + inline void prepare(int i, int j) { + auto qs = x[i].qs + 32*j; + const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j; + make4_unsigned(qs, bits.values); + sign_2_values(signs+0, bits.values+0); + sign_2_values(signs+4, bits.values+2); + for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi32(bits.values[k], min_value); + } + inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { + for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); + auto qs = x[i].qs + 32*j; + const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j; + make4_unsigned(qs, bits.values); + sign_2_values(signs+0, q8_quants+0); + sign_2_values(signs+4, q8_quants+2); + } + + constexpr static int minv = 64; + + SimpleBits bits; + Scales8KBase scb; + EvenSignHelper esh; + const __m256i min_value = _mm256_set1_epi8(minv); + +}; + +struct DequantizerIQ2S final : public BaseDequantizer { + DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + constexpr static int num_blocks = 16; + + inline __m256i load_scales(int i) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales); + auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf)); + auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1)); + return _mm256_cvtepi8_epi16(scales8); + } + inline static void prepare_scales(const __m256i& all, __m256i * scales) { + auto scales_l = _mm256_castsi256_si128(all); + auto scales_h = _mm256_extractf128_si256(all, 1); + scales[0] = MM256_SET_M128I(scales_l, scales_l); + scales[1] = MM256_SET_M128I(scales_h, scales_h); + } + + inline void new_block(int i, __m256i * scales) { + prepare_scales(load_scales(i), scales); + } + inline float new_block(int i, __m256i * scales, __m256i& mins) { + mins = load_scales(i); + prepare_scales(mins, scales); + return -d*minv; + } + + union index_t { + __m256i vec; + uint32_t val[8]; + }; + + inline static void make2(const uint8_t * qs, const uint8_t * qh, const __m256i& idx_shift, 
const __m256i& idx_mask, __m256i * values) { + auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs)); + auto idx_h = MM256_SET_M128I(_mm_set1_epi32(qh[1]), _mm_set1_epi32(qh[0])); + index_t idx; + idx.vec = _mm256_or_si256(idx_l, _mm256_and_si256(_mm256_sllv_epi32(idx_h, idx_shift), idx_mask)); + values[0] = _mm256_set_epi64x(iq2s_grid[idx.val[3]], iq2s_grid[idx.val[2]], iq2s_grid[idx.val[1]], iq2s_grid[idx.val[0]]); + values[1] = _mm256_set_epi64x(iq2s_grid[idx.val[7]], iq2s_grid[idx.val[6]], iq2s_grid[idx.val[5]], iq2s_grid[idx.val[4]]); + } + inline static void make2_signed(const SignHelper& sh, const uint8_t * qs, const uint8_t * qh, const uint16_t * sidx, + const __m256i& idx_shift, const __m256i& idx_mask, const __m256i& min_value, __m256i * values) { + make2(qs, qh, idx_shift, idx_mask, values); + values[0] = _mm256_add_epi8(sh.sign_value(sidx+0, values[0]), min_value); + values[1] = _mm256_add_epi8(sh.sign_value(sidx+2, values[1]), min_value); + } + + inline void prepare(int i, int j) { + auto qs = x[i].qs + 16*j; + auto qh = x[i].qh + 4*j; + const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j; + make2_signed(sh, qs+0, qh+0, signs+0, idx_shift, idx_mask, min_value, bits.values+0); + make2_signed(sh, qs+8, qh+2, signs+4, idx_shift, idx_mask, min_value, bits.values+2); + } + inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { + auto qs = x[i].qs + 16*j; + auto qh = x[i].qh + 4*j; + const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j; + make2(qs+0, qh+0, idx_shift, idx_mask, bits.values+0); + make2(qs+8, qh+2, idx_shift, idx_mask, bits.values+2); + q8_quants[0] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+0), sh.make_signs(signs[0] | (signs[1] << 16))); + q8_quants[1] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+1), sh.make_signs(signs[2] | (signs[3] << 16))); + q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16))); + q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16))); + } + + constexpr static int minv = 43; + + SimpleBits bits; + SignHelper sh; + const __m256i idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8); + const __m256i idx_mask = _mm256_set1_epi32(0x300); + const __m256i min_value = _mm256_set1_epi8(minv); + +}; + +struct DequantizerIQ2XS final : public BaseDequantizer { + DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + constexpr static int num_blocks = 16; + + inline __m256i load_scales(int i) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales); + auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf)); + auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1)); + return _mm256_cvtepi8_epi16(scales8); + } + inline static void prepare_scales(const __m256i& all, __m256i * scales) { + auto scales_l = _mm256_castsi256_si128(all); + auto scales_h = _mm256_extractf128_si256(all, 1); + scales[0] = MM256_SET_M128I(scales_l, scales_l); + scales[1] = MM256_SET_M128I(scales_h, scales_h); + } + + inline void new_block(int i, __m256i * scales) { + prepare_scales(load_scales(i), scales); + } + inline float new_block(int i, __m256i * scales, __m256i& mins) { + mins = load_scales(i); + prepare_scales(mins, scales); + return -d*minv; + } + + struct Helper { + const __m256i mone = _mm256_set1_epi8(1); + const __m256i mask = _mm256_set1_epi64x(0x8040201008040201); + //const __m256i bhelper = 
_mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080); + const __m256i bhelper = load_bhelper(); + const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000); + const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808); + static __m256i load_bhelper() { + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + return _mm256_loadu_si256((const __m256i*)k_bit_helper); + } + }; + + union index_t { + __m256i vec; + uint16_t val[8]; + }; + + inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) { + index_t idx; + idx.vec = _mm256_and_si256(data, mask); + values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]); + values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]); + values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]); + values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]); + } + inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask, + const __m256i& mone, __m256i& value) { + auto signs = _mm256_shuffle_epi8(sign_bits, shuffle); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask); + value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone)); + } + inline void sign_values(const __m256i& data, __m256i * values) const { +#ifdef HAVE_FANCY_SIMD + auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9)); + auto pcnt = _mm_popcnt_epi8(partial_bits); + auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7)); + const __mmask32 * m32 = (const __mmask32 *)&full_bits; + auto zero = _mm256_setzero_si256(); + values[0] = _mm256_mask_sub_epi8(values[0], m32[0], zero, values[0]); + values[1] = _mm256_mask_sub_epi8(values[1], m32[1], zero, values[1]); + values[2] = _mm256_mask_sub_epi8(values[2], m32[2], zero, values[2]); + values[3] = _mm256_mask_sub_epi8(values[3], m32[3], zero, values[3]); +#else + auto psb1 = _mm256_srli_epi16(data, 9); + auto psb2 = _mm256_srli_epi16(data, 13); + auto psbc = _mm256_xor_si256(psb1, psb2); + auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc); + auto full = _mm256_or_si256(psb1, oddb); + auto full_l = _mm256_castsi256_si128(full); + auto full_h = _mm256_extractf128_si256(full, 1); + auto full_1 = MM256_SET_M128I(full_l, full_l); + auto full_2 = MM256_SET_M128I(full_h, full_h); + sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]); + sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]); + sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]); + sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]); +#endif + } + inline void make4_signed(const uint16_t * qs, const __m256i& m511, + const __m256i& min_value, __m256i * values) const { + auto q2 = _mm256_loadu_si256((const __m256i *)qs); + make4(q2, m511, values); + sign_values(q2, values); + for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], 
min_value); + } + inline void make4(const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) const { + auto q2 = _mm256_loadu_si256((const __m256i *)qs); + make4(q2, m511, values); + sign_values(q2, q8); + } + + inline void prepare(int i, int j) { + make4_signed(x[i].qs + 16*j, idx_mask, min_value, bits.values); + } + inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { + for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); + make4(x[i].qs + 16*j, idx_mask, bits.values, q8_quants); + } + + constexpr static int minv = 43; + + SimpleBits bits; +#ifndef HAVE_FANCY_SIMD + Helper helper; +#endif + const __m256i idx_mask = _mm256_set1_epi16(511); + const __m256i min_value = _mm256_set1_epi8(minv); + +}; + +struct DequantizerIQ2XXS final : public BaseDequantizer { + DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + constexpr static int num_blocks = 8; + + union Data { + __m256i vec; + uint32_t val[8]; + }; + + inline __m128i load_scales(int i) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + const uint16_t * a16 = (const uint16_t *)x[i].qs; + auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12); + return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1)); + } + + inline void new_block(int i, __m256i * scales) { + auto sc16 = load_scales(i); + scales[0] = MM256_SET_M128I(sc16, sc16); + } + inline float new_block(int i, __m256i * scales, __m256i& mins) { + auto sc16 = load_scales(i); + mins = scb.shuffle(sc16); + scales[0] = MM256_SET_M128I(sc16, sc16); + return -d*minv; + } + + inline static void make4(const uint32_t * aux32, __m256i * values) { + const uint8_t * aux8 = (const uint8_t *)aux32; + values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]); + values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]); + values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]); + values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]); + } + + IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const { +#ifdef HAVE_FANCY_SIMD + esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0); + esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2); +#else + esh.sign_value(aux32[1], values[0]); + esh.sign_value(aux32[3], values[1]); + esh.sign_value(aux32[5], values[2]); + esh.sign_value(aux32[7], values[3]); +#endif + } + inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const { + make4(aux32, values); + sign_values(aux32, values); + for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value); + } + inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const { + make4(aux32, values); + sign_values(aux32, q8); + } + inline void prepare(int i, int j) { + Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + make4_signed(data.val, min_value, bits.values); + } + inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { + for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); + Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + make4(data.val, 
bits.values, q8_quants); + } + + constexpr static int minv = 43; + SimpleBits bits; + Scales8KBase scb; + EvenSignHelper esh; + const __m256i min_value = _mm256_set1_epi8(minv); + const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1); +}; + +// +// ============================== Legacy quants +// + +struct DotHelper { + const __m256i m1 = _mm256_set1_epi16(1); +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + inline __m256i dot(__m256i x, __m256i y) const { + return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y); + } +#else + inline __m256i dot(__m256i x, __m256i y) const { + return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y)); + } +#endif +}; + +struct SignedDot { + DotHelper helper; + inline __m256i compute(__m256i x, __m256i y) const { + return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x)); + } +}; +struct UnsignedDot { + DotHelper helper; + inline __m256i compute(__m256i x, __m256i y) const { + return helper.dot(x, y); + } +}; + +template struct Sum4 { + Dot dot; + inline __m256i compute(const __m256i * qx, const Q8 * y) const { + const Q8x4 * y4 = (const Q8x4 *)y; + const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0 + const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1 + const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2 + const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3 + if constexpr (can_pack) { + const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1 + const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3 + return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3 + } else { + // Note to myself: this is much faster than using _mm256_hadd_epi32() + auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1 + auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3 + return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 + } + } +}; +// If I use this, it negatively impacts q4_1/q5_1 performance. 
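// (The variant below always uses the 32-bit unpack/add reduction, i.e. it drops the
// 16-bit packing shortcut taken above when can_pack is true.)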
+//template struct Sum4 { +// Dot dot; +// inline __m256i compute(const __m256i * qx, const Q8 * y) const { +// const Q8x4 * y4 = (const Q8x4 *)y; +// const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0 +// const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1 +// const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2 +// const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3 +// auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1 +// auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3 +// return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 +// } +//}; + +struct ScaleHelperQ8_0 { + inline __m128 prepare4(const block_q8_0 * y) { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y; + return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d)); + } + inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) { + return _mm_mul_ps(other_scales, prepare4(y)); + } + template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } + template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } +}; + +struct ScaleHelperQ_0 { + ggml_half scales8[4]; + template + inline __m128 prepare4(const Q * y) { + for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; + return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); + } + template + inline __m128 prepare4(__m128 other_scales, const Q * y) { + return _mm_mul_ps(other_scales, prepare4(y)); + } + template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } + template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } +}; + +struct ScaleHelperQ8_1 { + template + inline __m256 prepare4(const Q * y) { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y; + return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y4->d)); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); + } +}; + +struct ScaleHelperQ_1 { + uint32_t scales8[4]; + const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100); + + template + inline __m256 prepare4(const Q * y) { + for (int j = 0; j < 4; ++j) { + // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers + // complain that this breaks strict-aliasing rules. 
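// Each 4-byte copy below picks up the fp16 (d, m) pair of one block; the shuffle that
// follows gathers the four d values into the low half and the four m values into the
// high half, so the conversion yields d0..d3, m0..m3.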
+ memcpy(scales8 + j, &y[j].d, sizeof(uint32_t)); + } + return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle)); + } + + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); + } +}; + +struct MinusType0 { + inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } + inline float compute(float d, int) const { return d; } + inline float result(__m256 acc, int) const { return hsum_float_8(acc); } +}; + +template struct MinusType1 { + __m128 accm[nrc_y]; + MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); } + inline __m256 compute(__m256 dm, int iy) { + const __m128 d = _mm256_castps256_ps128(dm); + const __m128 m = _mm256_extractf128_ps(dm, 1); + accm[iy] = _mm_add_ps(accm[iy], m); + return _mm256_set_m128(d, d); + } + inline float compute(const std::pair& dm, int iy) { + accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f)); + return dm.first; + } + inline float result(__m256 acc, int iy) const { + const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); + return hsum_float_4(_mm_add_ps(sum, accm[iy])); + } +}; + +template struct AccumT { + __m256 acc[nrc_y]; + Minus accm; + AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); } + template + inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) { + auto qx = unp.quants(); + __m256 dall[nrc_y]; + for (int i = 0; i < nb/4; ++i) { + auto other_scales = unp.set_block_4(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); + dall[iy] = accm.compute(s12, iy); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto pall = sum.compute(qx, y[iy] + 4*i); + acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); + } + } + if (!is_multiple_of_4) { + for (int i = 4*(nb/4); i < nb; ++i) { + auto other_scales = unp.set_block(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare1(other_scales, y[iy] + i); + auto d = accm.compute(s12, iy); + const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, accm.result(acc[iy], iy)); + //s[iy*bs] = accm.result(acc[iy], iy); + } + } +}; + +template +using AccumType0 = AccumT; + +template +using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; + +using Sum4Type0 = Sum4; +using Sum4Type1 = Sum4; +using Sum4TypeQ80 = Sum4; + +template +void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + Scales scales; + for (int ix = 0; ix < nrc_x; ++ix) { + unp.set_row(ix); + AccumType accum; + accum.compute(nb, unp, scales, sum4, y, info, ix); + } +} + +template +void mul_mat_qX_0_q8_0_T(int n, const void * 
vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_0, block_q8_0, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_0, block_q8_0, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } +} + +template +void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_1, block_q8_1, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_1, block_q8_1, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } +} + +struct Dequantizer4bit { + const __m256i m4 = _mm256_set1_epi8(0xf); + inline __m256i dequant(const uint8_t * qs) const { + const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); + return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4); + } +}; + +struct Q8_0_Dequantizer { + inline __m256i dequant(const block_q8_0 * x) const { + return _mm256_loadu_si256((const __m256i *)x->qs); + } +}; + +struct Q4_0_Dequantizer { + Dequantizer4bit b4; + const __m256i m8 = _mm256_set1_epi8(-8); + inline __m256i dequant(const block_q4_0 * x) const { + return _mm256_add_epi8(b4.dequant(x->qs), m8); + } +}; + +struct IQ4_NL_Dequantizer { + Dequantizer4bit b4; + const __m256i values = load_values(); + inline __m256i dequant(const block_iq4_nl * x) const { + return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); + } + static __m256i load_values() { + static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + auto aux = _mm_loadu_si128((const __m128i *)iq4nl_values); + return MM256_SET_M128I(aux, aux); + } +}; + +struct Q4_1_Dequantizer { + Dequantizer4bit b4; + inline __m256i dequant(const block_q4_1 * x) const { + return b4.dequant(x->qs); + } +}; + +struct HBitDequantizer { + const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + const __m256i minus1 = _mm256_set1_epi64x(-1); + inline __m256i to_bytes(const uint8_t * bits) const { + // Note: Data in all ggml quants is at least 2-byte aligned. 
+ // => we can cast to uint16_t and use or on two consecutive entries + // which is faster than memcpy + const uint16_t * aux16 = (const uint16_t *)bits; + const uint32_t aux32 = aux16[0] | (aux16[1] << 16); + //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t)); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle); + bytes = _mm256_or_si256(bytes, mask); + return _mm256_cmpeq_epi8(bytes, minus1); + } +}; + +struct Q5_0_Dequantizer { + Dequantizer4bit b4; + HBitDequantizer hbit; + const __m256i mh = _mm256_set1_epi8((char)0xF0); + inline __m256i dequant(const block_q5_0 * x) const { + const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh); + return _mm256_or_si256(b4.dequant(x->qs), vqh); + } +}; + +struct Q5_1_Dequantizer { + Dequantizer4bit b4; + HBitDequantizer hbit; + const __m256i mh = _mm256_set1_epi8(0x10); + inline __m256i dequant(const block_q5_1 * x) const { + const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh); + return _mm256_or_si256(b4.dequant(x->qs), vqh); + } +}; + +template +struct Q_Unpacker { + Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {} + + const char * cx_0; + const Q * x; + size_t bx; + + Scales scales; + Dequantizer deq; + + __m256i qx[4]; + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { + for (int j = 0; j < 4; ++j) { + qx[j] = deq.dequant(x + 4*i + j); + } + return scales.prepare4(x + 4*i); + } + inline auto set_block(int i) { + qx[0] = deq.dequant(x + i); + return scales.prepare1(x + i); + } +}; + +struct Q8_0_Unpacker final : public Q_Unpacker { + Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK8_0; } +}; +struct Q4_0_Unpacker final : public Q_Unpacker { + Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK4_0; } +}; +struct IQ4_NL_Unpacker final : public Q_Unpacker { + IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK4_NL; } +}; +struct Q5_0_Unpacker final : public Q_Unpacker { + Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK5_0; } +}; +struct Q4_1_Unpacker final : public Q_Unpacker { + Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4Type1; + inline static int block_size() { return QK4_1; } +}; +struct Q5_1_Unpacker final : public Q_Unpacker { + Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4Type1; + inline static int block_size() { return QK4_1; } +}; + +// float matrices - we handle f16 and f32, but only to f32 result + +struct QFBase { +#ifdef __AVX512F__ + constexpr static int k_step = 16; + using Data = __m512; + using Acc = __m512; + static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); } + static inline Data load(const float * x) { return _mm512_loadu_ps(x); } + static inline Acc acc(Acc prev, const Data& y, const Data& x) { + return _mm512_fmadd_ps(y, x, prev); + } + static inline Acc acc_first(const Data& y, const Data& x) { + return _mm512_mul_ps(y, x); + } + static inline float hsum(Acc acc) { + return _mm512_reduce_add_ps(acc); + } + template + static inline 
Data load4Floats(const Float * x) { + return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0); + } +#else + constexpr static int k_step = 8; + using Data = __m256; + using Acc = __m256; + static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); } + static inline Data load(const float * x) { return _mm256_loadu_ps(x); } + static inline Acc acc(Acc prev, const Data& y, const Data& x) { + return _mm256_fmadd_ps(y, x, prev); + } + static inline Acc acc_first(const Data& y, const Data& x) { + return _mm256_mul_ps(y, x); + } + static inline float hsum(Acc acc) { + return hsum_float_8(acc); + } + template + static inline Data load4Floats(const Float * x) { + return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0); + } +#endif + static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); } + static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); } +}; +template struct QFT final : public QFBase { + constexpr static int nrc = nrc_in; + QFT(const DataInfo& info) { + for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy); + } + QFT(const char * cx, size_t bx) { + for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx); + } + IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } + IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); } + const Float * y[nrc]; +}; + +template +IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + assert(n%QFBase::k_step == 0); + int nb = n/QFBase::k_step; + int nb4 = n/4; + Qy y(info); + Qx x(cx + ix0*bx, bx); + QFBase::Data xv[Qx::nrc]; + QFBase::Acc acc[Qx::nrc*Qy::nrc]; + auto yv = y.load1(0, 0); + for (int ix = 0; ix < Qx::nrc; ++ix) { + xv[ix] = x.load1(ix, 0); + acc[ix] = QFBase::acc_first(yv, xv[ix]); + } + for (int iy = 1; iy < Qy::nrc; ++iy) { + yv = y.load1(iy, 0); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]); + } + for (int i = 1; i < nb; ++i) { + yv = y.load1(0, i); + for (int ix = 0; ix < Qx::nrc; ++ix) { + xv[ix] = x.load1(ix, i); + acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < Qy::nrc; ++iy) { + yv = y.load1(iy, i); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); + } + } + for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) { + yv = y.load_tail(0, i); + for (int ix = 0; ix < Qx::nrc; ++ix) { + xv[ix] = x.load_tail(ix, i); + acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < Qy::nrc; ++iy) { + yv = y.load_tail(iy, i); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); + } + } + for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); +} + +// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done +// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in +// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. 
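As a mental model for the kernel that follows: it fills an nrc_x x nrc_y tile of dot products, walking k_nx rows of the left matrix at a time so that all k_nx*nrc_y partial sums stay in registers. Here is a scalar sketch with the same loop structure (illustration only; the names are made up and the QFBase SIMD loads/FMAs are collapsed into a plain dot product):

static float dot_ref(const float * a, const float * b, int n) {
    float sum = 0;
    for (int k = 0; k < n; ++k) sum += a[k]*b[k];
    return sum;
}
// out[ix*nrc_y + iy] = (row ix of A) . (row iy of B), both row-major with n columns.
// Not part of the kernels; just the loop structure of mul_mat_fX_fY_T in scalar form.
static void mul_mat_ref(int n, const float * A, const float * B, float * out,
                        int nrc_x, int nrc_y, int k_nx) {
    int ix = 0;
    for (; ix + k_nx <= nrc_x; ix += k_nx) {          // full k_nx-wide tiles
        for (int i = 0; i < k_nx; ++i) {
            for (int iy = 0; iy < nrc_y; ++iy) {
                out[(ix + i)*nrc_y + iy] = dot_ref(A + (ix + i)*n, B + iy*n, n);
            }
        }
    }
    for (; ix < nrc_x; ++ix) {                        // leftover rows; the real kernel
        for (int iy = 0; iy < nrc_y; ++iy) {          // handles them in one narrower pass
            out[ix*nrc_y + iy] = dot_ref(A + ix*n, B + iy*n, n);
        }
    }
}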
+template +void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QFBase::k_step == 0); +#ifdef __AVX512F__ + constexpr int k_nx = 5; +#else + constexpr int k_nx = 2; +#endif + const char * cx = (const char *)vx; + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; +#ifdef __AVX512F__ + case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + case 4: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; +#endif + } +} + +// +// Tiled Q8_0 x Q8_0 implementation. Not used as the templated legacy quant implementation +// above is faster. Left behind so we remember we tried. +// +template struct Q80 { + constexpr static int nrc_y = nrc; + Q80(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); + } + IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); } + IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); } + + const block_q8_0 * y[nrc_y]; +}; +inline __m256i mul_q80(__m256i x, __m256i y) { + auto ux = _mm256_sign_epi8(x, x); +#ifdef HAVE_FANCY_SIMD + return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x)); +#else + return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x))); +#endif +} +template +void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK8_0 == 0); + constexpr int k_nx = 4; + int nb = n/QK8_0; + Q80 q8(info); + const block_q8_0 * x[k_nx]; + float ds[k_nx]; + __m256 acc[k_nx*nrc_y]; + __m256i xv[k_nx]; + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + int ix0 = k_nx*ix; + for (int kx = 0; kx < k_nx; ++kx) { + x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx); + ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d); + xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto yv = q8.load1(iy, 0); + float d = q8.scale(iy, 0); + for (int kx = 0; kx < k_nx; ++kx) { + auto dot = mul_q80(yv, xv[kx]); + acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot)); + } + } + for (int i = 1; i < nb; ++i) { + for (int kx = 0; kx < k_nx; ++kx) { + ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d); + xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto yv = q8.load1(iy, i); + float d = q8.scale(iy, i); + for (int kx = 0; kx < k_nx; ++kx) { + auto dot = mul_q80(yv, xv[kx]); + acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx])); + } + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + // TODO: handle remaining rows +} + +template void MulMat::set_functions(MulMat& m) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_0_q8_0_T; + m.funcs[1] = mul_mat_qX_0_q8_0_T; + m.funcs[2] = mul_mat_qX_0_q8_0_T; + m.funcs[3] = mul_mat_qX_0_q8_0_T; + m.funcs[4] = mul_mat_qX_0_q8_0_T; + m.funcs[5] = mul_mat_qX_0_q8_0_T; 
+ m.funcs[6] = mul_mat_qX_0_q8_0_T; + m.funcs[7] = mul_mat_qX_0_q8_0_T; + } + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_1_q8_1_T; + m.funcs[1] = mul_mat_qX_1_q8_1_T; + m.funcs[2] = mul_mat_qX_1_q8_1_T; + m.funcs[3] = mul_mat_qX_1_q8_1_T; + m.funcs[4] = mul_mat_qX_1_q8_1_T; + m.funcs[5] = mul_mat_qX_1_q8_1_T; + m.funcs[6] = mul_mat_qX_1_q8_1_T; + m.funcs[7] = mul_mat_qX_1_q8_1_T; + } + else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + m.funcs[0] = mul_mat_qX_K_q8_K_IQ; + m.funcs[1] = mul_mat_qX_K_q8_K_IQ; + m.funcs[2] = mul_mat_qX_K_q8_K_IQ; + m.funcs[3] = mul_mat_qX_K_q8_K_IQ; + m.funcs[4] = mul_mat_qX_K_q8_K_IQ; + m.funcs[5] = mul_mat_qX_K_q8_K_IQ; + m.funcs[6] = mul_mat_qX_K_q8_K_IQ; + m.funcs[7] = mul_mat_qX_K_q8_K_IQ; + } + else { +#ifdef HAVE_FANCY_SIMD + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512; + m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512; + } else { + m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1; + m.funcs[1] = mul_mat_qX_K_q8_K_AVX512; + m.funcs[2] = mul_mat_qX_K_q8_K_AVX512; + m.funcs[3] = mul_mat_qX_K_q8_K_AVX512; + m.funcs[4] = mul_mat_qX_K_q8_K_AVX512; + m.funcs[5] = mul_mat_qX_K_q8_K_AVX512; + m.funcs[6] = mul_mat_qX_K_q8_K_AVX512; + m.funcs[7] = mul_mat_qX_K_q8_K_AVX512; + } +#else + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v|| + std::is_same_v|| + std::is_same_v|| + std::is_same_v|| + std::is_same_v) { + m.funcs[0] = mul_mat_qY_K_q8_K_T; + m.funcs[1] = mul_mat_qY_K_q8_K_T; + m.funcs[2] = mul_mat_qY_K_q8_K_T; + m.funcs[3] = mul_mat_qY_K_q8_K_T; + m.funcs[4] = mul_mat_qY_K_q8_K_T; + m.funcs[5] = mul_mat_qY_K_q8_K_T; + m.funcs[6] = mul_mat_qY_K_q8_K_T; + m.funcs[7] = mul_mat_qY_K_q8_K_T; + } else { + m.funcs[0] = mul_mat_qX_K_q8_K_T; + m.funcs[1] = mul_mat_qX_K_q8_K_T; + m.funcs[2] = mul_mat_qX_K_q8_K_T; + m.funcs[3] = mul_mat_qX_K_q8_K_T; + m.funcs[4] = mul_mat_qX_K_q8_K_T; + m.funcs[5] = mul_mat_qX_K_q8_K_T; + m.funcs[6] = mul_mat_qX_K_q8_K_T; + m.funcs[7] = mul_mat_qX_K_q8_K_T; + } +#endif + } +} + +template +void set_mul_mat_f(MulMat& mm) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>; + mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>; + mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>; + mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>; + mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>; +#ifndef __AVX512F__ + mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>; +#endif +} + +bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { + + (void)Ny; + + if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) { + if (ne00 % 4) return false; + } + if (typeA == GGML_TYPE_F16) { + switch (typeB) { + case GGML_TYPE_F16: set_mul_mat_f(mm); break; + case GGML_TYPE_F32: set_mul_mat_f(mm); break; + default: return false; + } + return true; + } + if (typeA == GGML_TYPE_F32) { + switch (typeB) { + case GGML_TYPE_F16: set_mul_mat_f(mm); break; + case GGML_TYPE_F32: set_mul_mat_f(mm); break; + default: return false; + } + return true; + } + + auto expected_typeB = GGML_TYPE_Q8_K; + + switch (typeA) { + case GGML_TYPE_Q2_K: + assert (ne00 % QK_K == 0); + 
MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ2_TN: + assert (ne00 % QK_K == 0); +#ifdef HAVE_FANCY_SIMD + MulMat::set_functions(mm); +#else + mm.funcs[0] = mul_mat_iq2tn_q8_K<1>; + mm.funcs[1] = mul_mat_iq2tn_q8_K<2>; + mm.funcs[2] = mul_mat_iq2tn_q8_K<3>; + mm.funcs[3] = mul_mat_iq2tn_q8_K<4>; + mm.funcs[4] = mul_mat_iq2tn_q8_K<5>; + //mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; + //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>; + //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>; +#endif + break; + case GGML_TYPE_Q3_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_Q4_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_Q5_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_Q6_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ4_XS: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ2_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ3_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ4_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ5_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ6_K: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ3_S: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ3_XXS: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ2_S: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ2_XS: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ2_XXS: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; + case GGML_TYPE_IQ1_BN: + assert (ne00 % QK_IQ1BN == 0); + mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>; + mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>; + mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>; + mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>; + mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>; + mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>; + mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>; + mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>; + expected_typeB = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_IQ2_BN: + assert (ne00 % QK_IQ1BN == 0); + mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>; + mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>; + mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>; + mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>; + mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>; + mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>; + mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>; + mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>; + expected_typeB = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_Q4_0: + assert (ne00 % QK4_0 == 0); + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_0; + break; + case GGML_TYPE_Q4_1: + assert (ne00 % QK4_1 == 0); + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_1; + break; + case GGML_TYPE_Q5_0: + assert (ne00 % QK5_0 == 0); + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_0; + break; + case GGML_TYPE_Q5_1: + assert (ne00 % QK5_1 == 0); + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_1; + break; + case GGML_TYPE_Q8_0: + assert (ne00 % QK8_0 == 0); + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_0; + break; + case GGML_TYPE_IQ4_NL: + assert (ne00 % QK4_NL == 0); + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_0; + break; + + default: + return false; + } + + return ggml_type(typeB) == expected_typeB; 
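// i.e. the kernels only run when B is already quantized to the row type they expect:
// Q8_K by default, Q8_K64 for IQ1_BN/IQ2_BN, and Q8_0/Q8_1 for the legacy quants.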
+} + +} // namespace + + +#else // __aarch64__ + +namespace { + +template struct Q8 { + + constexpr static int nrc_y = nrc; + + Q8(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); + } + + inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); } + inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); } + inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); } + inline int16x8_t load_bsums8(int iy, int i) const { + auto q8s = vld1q_s16_x2(y[iy][i].bsums); + return vpaddq_s16(q8s.val[0], q8s.val[1]); + } + inline float scale(int iy, int i) const { return y[iy][i].d; } + + const block_q8 * y[nrc_y]; +}; + +template +inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, + const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { + auto mzero = vdupq_n_s32(0); + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1 + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2 + auto p12 = vpaddq_s32(p1, p2); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1 + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2 + auto p34 = vpaddq_s32(p3, p4); + + auto pall = vpaddq_s32(p12, p34); + sumi = vmlaq_s32(sumi, scales.val[j], pall); +} + +template +inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, + const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { + + auto mzero = vdupq_n_s32(0); + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1, + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4, + auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3 + sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5, + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7, + auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7 + sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34); +} + +template +inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums8(iy, i); + int32x4_t b1 = 
vmull_s16(vget_low_s16(mins), vget_low_s16(q8s)); + int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s)); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2)); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } +} +template +inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0])); + int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0])); + int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1])); + int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1])); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4))); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } +} + +struct Scales8 { + uint32_t utmp[4]; + const uint8_t * sc8 = (const uint8_t *)utmp; + template + inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) { + make_q4_scales(x.scales, utmp); + int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8)); + accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin)); + + uint8x8_t scales8 = vld1_u8(sc8); + uint16x8_t scales16 = vmovl_u8(scales8); + int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))), + vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))}; + return scales; + } +}; + +struct Q4bits { + const uint8x16_t m4b = vdupq_n_u8(0xf); + uint8x16x4_t b1, b2; + inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const { + b.val[0] = vandq_u8(val[0], m4b); + b.val[2] = vshrq_n_u8(val[0], 4); + b.val[1] = vandq_u8(val[1], m4b); + b.val[3] = vshrq_n_u8(val[1], 4); + } + inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const { + b.val[0] = vandq_u8(val[0], m4b); + b.val[1] = vshrq_n_u8(val[0], 4); + b.val[2] = vandq_u8(val[1], m4b); + b.val[3] = vshrq_n_u8(val[1], 4); + } + inline void prepare(const uint8_t * qs) { + auto q4bits = vld1q_u8_x2(qs); + prepare4(b1, q4bits.val); + q4bits = vld1q_u8_x2(qs+32); + prepare4(b2, q4bits.val); + } + inline void prepare_v2(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + prepare4(b1, q4bits.val+0); + prepare4(b2, q4bits.val+2); + } + inline void prepare64(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + b1.val[0] = vandq_u8(q4bits.val[0], m4b); + b1.val[1] = vandq_u8(q4bits.val[1], m4b); + b1.val[2] = vandq_u8(q4bits.val[2], m4b); + b1.val[3] = vandq_u8(q4bits.val[3], m4b); + b2.val[0] = vshrq_n_u8(q4bits.val[0], 4); + b2.val[1] = vshrq_n_u8(q4bits.val[1], 4); + b2.val[2] = vshrq_n_u8(q4bits.val[2], 4); + b2.val[3] = vshrq_n_u8(q4bits.val[3], 4); + } + inline void prepare16(const uint8_t * qs) { + auto q4bits = vld1q_u8_x2(qs); + prepare4_16(b1, q4bits.val); + q4bits = vld1q_u8_x2(qs+32); + prepare4_16(b2, q4bits.val); + } + inline void prepare16_v2(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + prepare4_16(b1, q4bits.val+0); + prepare4_16(b2, q4bits.val+2); + } +}; + +struct Q2bits { + const uint8x16_t m4b = vdupq_n_u8(0x03); + uint8x16x4_t b1, b2; + inline void prepare(const uint8_t * qs) { + auto q2bits = vld1q_u8_x2(qs); + b1.val[0] = vandq_u8(q2bits.val[0], m4b); + b1.val[1] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b1.val[2] = vandq_u8(q2bits.val[0], m4b); + b1.val[3] = 
vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b2.val[0] = vandq_u8(q2bits.val[0], m4b); + b2.val[1] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b2.val[2] = vandq_u8(q2bits.val[0], m4b); + b2.val[3] = vandq_u8(q2bits.val[1], m4b); + } +}; + +template +struct BaseDequantizer { + BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} + inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); } + const void * vx; + const block_q * x; + const size_t bx; + const int nrc; +}; + +struct DequantizerQ4K final : public BaseDequantizer { + DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return s8.process_scales_mins(x[i], q8, i, acc); + } + inline void prepare(int i, int j) { + if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); + else bits.prepare(x[i].qs+64*j); + } + + Q4bits bits; + Scales8 s8; + + float d; +}; + +struct HighBit5 { + const uint8x16_t mhb = vdupq_n_u8(0x10); + uint8x16x2_t bits; + inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { + b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb)); + b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb)); + b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb)); + b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb)); + + b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); + b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); + b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); + b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); + + if (do_shift) { + bits.val[0] = vshrq_n_u8(bits.val[0], 4); + bits.val[1] = vshrq_n_u8(bits.val[1], 4); + } + } +}; + +struct HighBit3 { + const uint8x16_t mhb = vdupq_n_u8(0x04); + uint8x16x2_t bits; + inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { + b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); + b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); + b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); + b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); + + b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb)); + b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb)); + b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb)); + b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb)); + + if (do_shift) { + bits.val[0] = vshrq_n_u8(bits.val[0], 4); + bits.val[1] = vshrq_n_u8(bits.val[1], 4); + } + } +}; + +struct DequantizerQ5K final : public BaseDequantizer { + DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + h.bits = vld1q_u8_x2(x[i].qh); + return s8.process_scales_mins(x[i], q8, i, 
acc); + } + inline void prepare(int i, int j) { + if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); + else bits.prepare(x[i].qs+64*j); + h.apply(bits.b1, bits.b2, j == 0); + } + + Q4bits bits; + HighBit5 h; + Scales8 s8; + + uint8x16x2_t hbits; + + float d; +}; + +inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { + int32x4x4_t scales = { + vmovl_s16(vget_low_s16 (scales16.val[0])), + vmovl_s16(vget_high_s16(scales16.val[0])), + vmovl_s16(vget_low_s16 (scales16.val[1])), + vmovl_s16(vget_high_s16(scales16.val[1])), + }; + return scales; +} + +template +inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) { + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + accum_mins_16(scales16, q8, acc, i, c); + return make_wider(scales16); +} + +struct DequantizerQ6K final : public BaseDequantizer { + DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d); + } + inline void prepare(int i, int j) { + + auto hbits = vld1q_u8_x2(x[i].qh + 32*j); + + bits.prepare64(x[i].ql+64*j); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb)); + + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb)); + + } + + Q4bits bits; + + const uint8x16_t mhb = vdupq_n_u8(0x30); + + float d; +}; + +struct DequantizerQ3K final : public BaseDequantizer { + DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + h.bits = vld1q_u8_x2(x[i].hmask); + mask = vdupq_n_u8(0x01); + const uint16_t * sc16 = (const uint16_t *)x[i].scales; + uint32_t aux0 = sc16[0] | (sc16[1] << 16); + uint32_t aux1 = sc16[2] | (sc16[3] << 16); + uint32_t aux2 = sc16[4] | (sc16[5] << 16); + aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); + aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); + aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); + aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); + auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)); + if (nrc > 1) { + return process_scales_mins_16(scales8, q8, acc, i, -4.f*d); + } + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + return make_wider(scales16); + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + if (nrc > 1) { + 
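// Multi-row case: keep the quants unsigned (0..7) and fold the -4 offset in through the
// block sums (the -4.f*d term in new_block above). The single-row branch below instead
// subtracts 4 directly, by or-ing 0xfc into every lane whose high mask bit is not set.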
h.apply(bits.b1, bits.b2, j == 0); + } else { + auto minus4 = vdupq_n_u8(0xfc); + auto zero = vdupq_n_u8(0); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + } + } + + uint32_t aux32[4]; + + Q2bits bits; + + uint8x16_t mask; + HighBit3 h; + + float d; +}; + +struct DequantizerQ2K final : public BaseDequantizer { + DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + template + inline void process_scales(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales_and_mins = vld1q_u8(x[i].scales); + auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4)); + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(mins8)); + scales16.val[1] = vmovl_s8(vget_high_s8(mins8)); + accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin)); + + scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf)); + } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + process_scales(i, q8, acc); + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8))); + scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8))); + return make_wider(scales16); + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + auto m1 = vdupq_n_u8(1); + auto shuffle = vdupq_n_u8(8*j); + bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = 
q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + } + + uint32_t aux32[4]; + + uint8x16_t scales8; + + Q2bits bits; + + float d; +}; + +// ============================= i-quants + +inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { + int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))}; + return make_wider(scales16); +} + +struct Scale16Extra { + template + static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val, + const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) { + uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra)); + e8 = vceqq_u8(vandq_u8(e8, emask), emask); + e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff); + int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)), + vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))}; + accum_mins_16(extra16, q8, acc, i, d); + return make_wider_8(scales8); + } + + constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040}; + constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09}; +}; + +// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because +// the ARM_NEON only has vdotq_s32 or vdotq_u32, where both operands need to +// be signed or unsigned. As the Q8_K quants are signed, we need to have the +// iq4_s quants also signed. We can only use unsigned values in k-quants +// because they are all within the valid int8_t range. 
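// A minimal scalar sketch, not part of the patch: dot4_s8() is a hypothetical
// helper that mirrors what ggml_vdotq_s32 accumulates per 32-bit lane -- the sum
// of four products of *signed* 8-bit operands added to the previous accumulator.
// Because NEON has no mixed u8 x s8 dot product, the non-linear values looked up
// with vqtbl*q_s8 below must stay signed int8_t to pair with the signed Q8_K
// quants, while plain k-quant values (0..15, 0..63) can simply be reinterpreted
// as int8_t since they already fit the signed range. Assumes <cstdint>.
static inline int32_t dot4_s8(const int8_t * a, const int8_t * b, int32_t acc) {
    for (int k = 0; k < 4; ++k) acc += int32_t(a[k]) * int32_t(b[k]);
    return acc;
}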
+struct DequantizerIQ4K final : public BaseDequantizer { + DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_l, x[i].scales_h), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + const uint32_t * aux32 = (const uint32_t *)scales_h; + uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2}; + uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30)); + int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff)); + return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32)); + } + + Q4bits bits; + const int8x16_t values; + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + + float d; +}; + +struct DequantizerIQ5K final : public BaseDequantizer { + DequantizerIQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + hbits = vld1q_u8_x2(x[i].qh); // hbits.val[0] holds 0....15, 32...47, 64...79, 96...111, 128...143, 160...175, 192...207, 224...239 + // hbits.val[1] holds 16...31, 48...63, 80...95, 112..127, 144...159, 176...191, 208...223, 240...255 + return Scale16Extra::new_block(i, d, x[i].extra, 2, make_scales(x[i].scales_l, x[i].scales_h), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+64*j); + if (j == 1) { + for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4); + } + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + const uint32_t * aux32 = (const uint32_t *)scales_h; + uint32x4_t sch_32 = {aux32[0] 
<< 4, aux32[0] << 2, aux32[0], aux32[0] >> 2}; + uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30)); + int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff)); + return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32)); + } + + Q4bits bits; + const int8x16x2_t values; + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const uint8x16_t hm = vdupq_n_u8(0x10); + uint8x16x2_t hbits; + + float d; +}; + +struct DequantizerIQ6K final : public BaseDequantizer { + DequantizerIQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x4(iq6nl_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 1, vld1q_s8(x[i].scales), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+64*j); + auto hbits = vld1q_u8_x2(x[i].qh + 32*j); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hm)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hm)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), hm)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), hm)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]); + } + } + + Q4bits bits; + const int8x16x4_t values; + const uint8x16_t hm = vdupq_n_u8(0x30); + + float d; +}; + +struct DequantizerIQ2K final : public BaseDequantizer { + DequantizerIQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 5, make_scales(x[i].scales), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(const uint8_t * scales_l) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(-15)); + return vqtbl1q_s8(scales, hshuff); + } + + Q2bits bits; + const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1)); + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + + float d; +}; + +struct DequantizerIQ3K final : public BaseDequantizer { + DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + 
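    // The iq3_k block scales handled below are stored as a 4-bit magnitude l in
    // scales_l (expanded to the odd value 2*l+1) plus one sign bit per block in
    // scales_h; make_scales() recombines them into signed int8 scales, which
    // new_block() then widens to int32 for the accumulation.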
template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + if (j == 0) { + hbits = vld1q_u8_x2(x[i].qh); + } + else { + hbits.val[0] = vshrq_n_u8(hbits.val[0], 4); + hbits.val[1] = vshrq_n_u8(hbits.val[1], 4); + } + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1)); + uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask); + signs = vorrq_u8(signs, vdupq_n_u8(1)); + // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15 + scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle))); + return vqtbl1q_s8(scales, hshuff); + } + inline static uint8x16_t load_sign_shuffle() { + static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + return vld1q_u8(k_shuff); + } + + Q2bits bits; + uint8x16x2_t hbits; + const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1)); + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const uint8x16_t hmask = vdupq_n_u8(4); + const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010}); + const uint8x16_t sign_shuffle = load_sign_shuffle(); + + float d; +}; + +struct DequantizerIQ4XS final : public BaseDequantizer { + + static int8x16_t load_values() { + static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + return vld1q_s8(iq4nl_values); + } + + DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + (void)q8; + (void)acc; + d = GGML_FP16_TO_FP32(x[i].d); + const uint16_t scales_h = x[i].scales_h; + const uint16_t * scales_l = (const uint16_t *)x[i].scales_l; + aux32[0] = scales_l[0] | (scales_l[1] << 16); + aux32[1] = aux32[0] >> 4; + // scl is ordered as 0, 2, 4, 
6, 1, 3, 5, 7 + uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf)); + uint16_t * aux16 = (uint16_t *)aux32; + aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2; + // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7 + uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30)); + int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32)); + // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7 + scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff)); + int16x8_t scales16 = vmovl_s8(scales8); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + //if (nrc == 1) { + // bits.prepare16_v2(x[i].qs+64*j); + //} else { + // bits.prepare16(x[i].qs+64*j); + //} + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k])); + bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k])); + } + } + + Q4bits bits; + const int8x16_t values; + uint32_t aux32[2]; + + constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602}; + + float d; +}; + +struct SimpleBits { + uint8x16x4_t b1; + uint8x16x4_t b2; +}; + +inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) { + int32x4x2_t scales; + scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1))); + scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1))); + return scales; +} + +inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) { + auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127)))); + auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127)))); + b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1)); + b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2)); +} + +struct DequantizerIQ2XXS final : public BaseDequantizer { + DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + + auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); + data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 + data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 + data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7 + data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7 + + return prepare_scales_8(data.val[1], data.val[3]); + } + + static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) { + b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + apply_signs_2(b, signs, sidx); + } + + inline void prepare(int /*i*/, int j) { + const uint8_t * idx = (const uint8_t *)(data.val + 2*j); + const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1); + 
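        // data.val[2*j] holds sixteen 8-bit indices into iq2xxs_grid (four per
        // 32-quant sub-block); each 32-bit word of data.val[2*j+1] packs, for the
        // matching sub-block, four 7-bit indices into keven_signs in bits 0..27,
        // while bits 28..31 are the sub-block scale already consumed by
        // prepare_scales_8() in new_block().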
prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4; + prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4; + prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4; + prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); + } + + uint32x4x4_t data; + SimpleBits bits; + + float d; +}; + +inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { + auto aux = vld1_u8(sc); + auto scales_l = vand_u8(aux, vdup_n_u8(0xf)); + auto scales_h = vshr_n_u8(aux, 4); + auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + + auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1))); + int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; + return make_wider(scales16); +} + +struct DequantizerIQ2XS final : public BaseDequantizer { + DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + return prepare_4bit_scales16(x[i].scales); + } + + inline static uint8x16_t make1(const uint16_t * qs) { + auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); + auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); + return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s)); + } + + inline static void make4(const uint16_t * qs, uint8x16_t * b) { + b[0] = make1(qs + 0); + b[1] = make1(qs + 2); + b[2] = make1(qs + 4); + b[3] = make1(qs + 6); + } + + inline void prepare(int i, int j) { + make4(x[i].qs + 16*j + 0, bits.b1.val); + make4(x[i].qs + 16*j + 8, bits.b2.val); + } + + SimpleBits bits; + + float d; + +}; + +struct SignHelper { + + inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); } + + inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) { + auto aux = vqtbl1q_u8(signs16, shuffle); + auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); + b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); + shuffle = vaddq_u8(shuffle, step); + } + + const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + const uint8x16_t m1 = vdupq_n_u8(1); + const uint8x16_t step = vdupq_n_u8(2); + uint8x16_t shuffle; +}; + +struct DequantizerIQ2S final : public BaseDequantizer { + DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + return prepare_4bit_scales16(x[i].scales); + } + + static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + uint32_t aux32[2]; + const uint16_t * aux16 = (const uint16_t *)aux32; + for (int k = 0; k < 2; ++k) { + aux32[1] = (qh[k] << 4) | (qh[k] << 18); + aux32[0] = (aux32[1] << 4) & 0x03000300; + aux32[1] &= 0x03000300; + b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), + vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); + 
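            // The aux16[] words above lift bit pairs of qh[k] into bits 8..9 of
            // each 10-bit iq2s_grid index, with qs supplying the low 8 bits.
            // signs16 was loaded in prepare() from the sign bytes stored right
            // after the 32 index bytes (qs + QK_K/8); SignHelper expands one
            // sign byte per 8 quants into +-1 and multiplies it into the grid
            // values below.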
sh.apply_signs_1(b+2*k+0, signs16); + + b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), + vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); + sh.apply_signs_1(b+2*k+1, signs16); + } + } + + inline void prepare(int i, int j) { + + const auto * qs = x[i].qs + 16*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(qs + QK_K/8); + + sh.init(); + make4(sh, signs16, qs+0, qh+0, bits.b1.val); + make4(sh, signs16, qs+8, qh+2, bits.b2.val); + } + + SimpleBits bits; + SignHelper sh; + + float d; + +}; + +struct DequantizerIQ3XXS final : public BaseDequantizer { + DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4)); + return prepare_scales_8(gas.val[0], gas.val[1]); + } + + inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) { + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); + apply_signs_2(b, keven_signs, sidx); + } + inline void prepare(int i, int j) { + const auto * q3 = x[i].qs + 32*j; + const auto * signs = (const uint32_t *)(gas.val + j); + make2(q3, signs[0], bits.b1.val + 0); q3 += 8; + make2(q3, signs[1], bits.b1.val + 2); q3 += 8; + make2(q3, signs[2], bits.b2.val + 0); q3 += 8; + make2(q3, signs[3], bits.b2.val + 2); + } + + SimpleBits bits; + uint32x4x2_t gas; + + float d; + +}; + +struct DequantizerIQ3S final : public BaseDequantizer { + DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = GGML_FP16_TO_FP32(x[i].d); + uint32_t scales32[2]; + std::memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 + scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); + auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); + int32x4x2_t scales; + scales.val[0] = vmovl_s16(vget_low_s16(scales16)); + scales.val[1] = vmovl_s16(vget_high_s16(scales16)); + return scales; + } + + static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh, + const int8x16_t& hshift, uint8x16_t * b) { + auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); + const uint16_t * idx = (const uint16_t *)&vindex; + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); + sh.apply_signs_1(b+0, signs16); + sh.apply_signs_1(b+1, signs16); + } + static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, + const int8x16_t& hshift, 
uint8x16_t * b) { + auto idx_l = vld1q_u8(qs); + make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); + make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); + } + + inline void prepare(int i, int j) { + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + const auto hshift = vld1q_s16(k_shift); + + const auto * qs = x[i].qs + 32*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(x[i].signs + 16*j); + + sh.init(); + make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val); + make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val); + } + + SimpleBits bits; + SignHelper sh; + uint32x4x2_t gas; + + float d; + +}; + +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + //template + //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + // d = GGML_FP16_TO_FP32(x[i].d); + //} + + inline void new_block(int i) { + d = GGML_FP16_TO_FP32(x[i].d); + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + template + inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) { + auto q8b_1 = q8.load_quants(0, i, 4*j+0); + sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(0, i, 4*j+1); + sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + q8b_1 = q8.load_quants(0, i, 4*j+2); + sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]); + + q8b_2 = q8.load_quants(0, i, 4*j+3); + sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]); + } + + IQK_ALWAYS_INLINE void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + auto m1 = vdupq_n_s8(1); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1); + bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1); + } + } + + Q2bits bits; + + float d; +}; + +template +void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + DequantizerIQ2TN deq(vx, bx, nrc_y); + float32x4_t acc[nrc_y]; + + for 
(int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + deq.new_block(i); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + + if (i > 0) { + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} +void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<1, block_q8_K> q8(info); + + DequantizerIQ2TN deq(vx, bx, 1); + + auto m1 = vdup_n_s16(-1); + float32x4_t acc[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[2] = {}; + deq.new_block(i); + auto bsums = q8.load_bsums(0, i); + bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]); + sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1); + sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1); + deq.bits.prepare(deq.x[i].qs); + deq.compute1(q8, i, 0, sumi); + deq.bits.prepare(deq.x[i].qs+32); + deq.compute1(q8, i, 1, sumi); + + auto vd = vdupq_n_f32(deq.d*q8.scale(0, i)); + if (i > 0) { + acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + } else { + acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd); + } + + } + + acc[0] = vaddq_f32(acc[0], acc[1]); + info.store(ix, 0, vaddvq_f32(acc[0])); + } +} + + +template +void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { + deq.process_scales(i, q8, acc); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + } else { + if constexpr (Dequantizer::num_blocks() == 8) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else if constexpr (Dequantizer::num_blocks() == 16) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else { + GGML_ASSERT(false); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, 
iy, vaddvq_f32(acc[iy])); + } + } +} + +// =========================================== Legacy quants + +template +inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) { + for (int k = 0; k < 4; ++k) aux[k] = x[k].d; + return vld1_f16((const float16_t *)aux); +} + +template +inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) { + if constexpr (std::is_same_v) { + for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; } + } else { + for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; } + } + return vld1q_f16((const float16_t *)aux); +} + +struct Q4LegacyBits { + template + inline void prepare(const Block * x) { + for (int i = 0; i < 4; ++i) { + auto q4bits = vld1q_u8(x[i].qs); + b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); + b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); + } + } + inline void prepare1(const uint8_t * qs, int8x16_t * q) const { + auto q4bits = vld1q_u8(qs); + q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); + q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); + } + inline void prepare1(const uint8_t * qs) { + prepare1(qs, b); + } + const uint8x16_t m4b = vdupq_n_u8(0xf); + int8x16_t b[8]; +}; + +// One would think this commented out version would do better than the one below +// because it offers more opportunities to execute instructions in parallel. +// Instead, it runs significantly slower. Why? If the compiler is running out of vector registers +// cannot it just do the sequential version below on its own? +//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { +// const auto q8b_1 = vld1q_s8_x2(qs + 0); +// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]); +// const auto q8b_2 = vld1q_s8_x2(qs + 32); +// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]); +// auto p1234 = vpaddq_s32(p12, p34); +// const auto q8b_3 = vld1q_s8_x2(qs + 64); +// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]); +// const auto q8b_4 = vld1q_s8_x2(qs + 96); +// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]); +// return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); +//} + +inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { + auto q8b = vld1q_s8_x2(qs + 0); + auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 32); + auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]); + auto p1234 = vpaddq_s32(p12, p34); + q8b = vld1q_s8_x2(qs + 64); + auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 96); + auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]); + return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); +} + +template struct Q80 { + + constexpr static int nrc_y = nrc; + + Q80(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); + } + + inline const int8_t * quant_data(int iy, int i) const { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; + return y4->qs; + } + + inline float16x4_t load_scales(int iy, int i) const { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; + return vld1_f16((const float16_t *)y4->d); + } + + template + inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, 
float32x4_t * /*acc*/) const { + auto qx_scales = deq.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + sc16[iy] = vmul_f16(qx_scales, q8_scales); + } + } + + template + inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { + deq.prepare1(i); + float d = GGML_FP16_TO_FP32(deq.x[i].d); + for (int iy = 0; iy < nrc; ++iy) { + auto q8b = vld1q_s8_x2(y[iy][i].qs); + auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); + } + } + + const block_q8_0 * y[nrc_y]; +}; + +template struct Q81 { + + constexpr static int nrc_y = nrc; + + Q81(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy); + } + + inline const int8_t * quant_data(int iy, int i) const { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; + return y4->qs; + } + + inline float16x8_t load_scales(int iy, int i) const { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; + return vld1q_f16((const float16_t *)y4->d); + } + + template + inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const { + auto qx_scales = deq.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales)); + acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m)); + sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales)); + } + } + + template + inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { + deq.prepare1(i); + float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m); + for (int iy = 0; iy < nrc; ++iy) { + auto q8b = vld1q_s8_x2(y[iy][i].qs); + auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); + acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s))); + } + } + + const block_q8_1 * y[nrc_y]; +}; + +template +struct BaseLegacyDequantizer { + + BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {} + + inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); } + + Q4LegacyBits bits; + + const void * vx; + const block_q * x; + size_t bx; +}; + +struct DequantizerQ40 final : public BaseLegacyDequantizer { + + DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + q[0] = vaddq_s8(q[0], m8); + q[1] = vaddq_s8(q[1], m8); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + const int8x16_t m8 = vdupq_n_s8(-8); + //ggml_half aux[4]; +}; + +struct DequantizerIQ4NL final : public BaseLegacyDequantizer { + + DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + q[0] = vqtbl1q_s8(values, q[0]); + q[1] = vqtbl1q_s8(values, q[1]); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half 
aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + static int8x16_t load_values() { + static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + return vld1q_s8(iq4nl_values); + } + + const int8x16_t values = load_values(); +}; + +struct DequantizerQ41 : public BaseLegacyDequantizer { + + DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.prepare1(x[i].qs); + } + + inline float16x8_t new_block(int i) { + uint32_t aux32[4]; + const uint32_t * s32 = (const uint32_t *)&x[4*i].d; + for (int k = 0; k < 4; ++k) { + aux32[k] = *s32; s32 += sizeof(block_q4_1)/4; + bits.prepare1(x[4*i+k].qs, bits.b + 2*k); + } + return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); + } + // Leaving this commented out attempt to be reminded that I already tried this. + // It has basically the same performance as the version above. + //inline float16x8_t new_block(int i) { + // uint32x4_t scales = {}; + // const block_q4_1 * xi = x + 4*i; + // const uint32_t * s32 = (const uint32_t *)&xi->d; + // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[0].qs, bits.b + 0); + // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[1].qs, bits.b + 2); + // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[2].qs, bits.b + 4); + // scales = vsetq_lane_u32(*s32, scales, 3); + // bits.prepare1(xi[3].qs, bits.b + 6); + // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle))); + //} + + const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; +}; + +struct HighBit5Legacy { + inline uint8x16_t to_bytes(const uint8_t * qh) const { + uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); + return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask)); + } + inline uint8x16_t to_negated_bytes(const uint8_t * qh) const { + uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); + return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0)); + } + const uint64x2_t mask = vdupq_n_u64(0x8040201008040201); + const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); +}; + +struct DequantizerQ50 final : public BaseLegacyDequantizer { + + DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh = x[i].qh; + q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0)))); + q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2)))); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + HighBit5Legacy hbits; + + const uint8x16_t mh = vdupq_n_u8(0xf0); + +}; + +struct DequantizerQ80 final : public BaseLegacyDequantizer { + + DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.b[0] = 
vld1q_s8(x[i].qs); + bits.b[1] = vld1q_s8(x[i].qs+16); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs); + bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16); + } + return vld1_f16((const float16_t *)aux); + } + +}; + +struct DequantizerQ51 final : public BaseLegacyDequantizer { + + DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh = x[i].qh; + q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0)))); + q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2)))); + } + inline void prepare1(int i) { + bits.prepare1(x[i].qs, bits.b); + } + + inline float16x8_t new_block(int i) { + uint32_t aux32[4]; + const uint32_t * s32 = (const uint32_t *)&x[4*i].d; + for (int k = 0; k < 4; ++k) { + aux32[k] = *s32; s32 += sizeof(block_q5_1)/4; + prepare1(4*i+k, bits.b + 2*k); + } + return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); + } + + HighBit5Legacy hbits; + + const uint8x16_t mh = vdupq_n_u8(0x10); + const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; + +}; + +template +inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i)); + auto scale = vcvt_f32_f16(sc16[iy]); + acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall)); + } +} + +template +inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[Q8::nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[Q8::nrc_y]; + for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb/4; ++i) { + q8.process_scales(i, deq, sc16, acc); + sum_4(i, deq, q8, sc16, acc); + } + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq, acc); + } + + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + +template +inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq1.new_row(ix); + deq2.new_row(ix); + + float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) }; + + for (int i = 0; i < nb/8; ++i) { + q8.process_scales(2*i+0, deq1, sc16+0, acc+0); + q8.process_scales(2*i+1, deq2, sc16+1, acc+1); + sum_4(2*i+0, deq1, q8, sc16+0, acc+0); + sum_4(2*i+1, deq2, q8, sc16+1, acc+1); + } + for (int i = 2*(nb/8); i < nb/4; ++i) { + q8.process_scales(i, deq1, sc16, acc); + sum_4(i, deq1, q8, sc16, acc); + } + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq1, acc); + } + + info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1]))); + } +} + +template +static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Q81 q8(info); + if constexpr (nrc_y == 1) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); + } else { + Dequantizer deq(vx, bx); + mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); + } +} + +template +static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const 
DataInfo& info, int nrc_x) { + Q80 q8(info); + if constexpr (nrc_y == 1) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); + } else { + Dequantizer deq(vx, bx); + mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); + } +} + +template +static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + Q81<1> q8(info); + mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); +} + +template +static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + Q80<1> q8(info); + mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x); +} + +struct QF16Base { + constexpr static int k_step = 8; + using Data = float16x8_t; + using Acc = float16x8_t; + static inline Data load(const __fp16 * x) { return vld1q_f16(x); } + static inline Data load4(const __fp16 * x) { return vcombine_f16(vld1_f16(x), vdup_n_f16(0)); } + static inline Acc acc(Acc prev, const Data& y, const Data& x) { + return vfmaq_f16(prev, y, x); + } + static inline Acc acc_first(const Data& y, const Data& x) { + return vmulq_f16(y, x); + } + //constexpr static int k_step = 16; + //using Data = float16x8x2_t; + //static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); } + //static inline Acc acc(Acc prev, const Data& y, const Data& x) { + // return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]); + //} + //static inline Acc acc_first(const Data& y, const Data& x) { + // return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]); + //} + static inline float hsum(Acc acc) { + float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc))); + return vaddvq_f32(sum); + } +}; +template struct QF16 final : public QF16Base { + constexpr static int nrc_y = nrc; + QF16(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy); + } + QF16(const char * cx, size_t bx) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx); + } + IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } + IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); } + IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); } + const __fp16 * y[nrc_y]; +}; + +template +IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + assert(n%QF16Base::k_step == 0); + int nb = n/QF16Base::k_step; + QF16 y(info); + QF16 x(cx + ix0*bx, bx); + QF16Base::Data xv[nrc_x]; + QF16Base::Acc acc[nrc_x*nrc_y]; + auto yv = y.load1(0, 0); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, 0); + acc[ix] = QF16Base::acc_first(yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, 0); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]); + } + for (int i = 1; i < nb; ++i) { + yv = y.load1(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = x.load1(ix, i); + acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load1(iy, i); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + } + } + if constexpr (!is_multiple_of_k_step) { + int nb4 = n/4; + for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) { + yv = y.load_tail(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + xv[ix] = 
x.load_tail(ix, i); + acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]); + } + for (int iy = 1; iy < nrc_y; ++iy) { + yv = y.load_tail(iy, i); + for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix])); +} + +template +void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%4 == 0); + constexpr int k_nx = 5; + const char * cx = (const char *)vx; + if (n%QF16Base::k_step == 0) { + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_f16_f16_NxN(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 2: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 3: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 4: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + } + } else { + for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + mul_mat_f16_f16_NxN(n, cx, bx, ix*k_nx, info); + } + int last_x = k_nx*(nrc_x/k_nx); + if (last_x == nrc_x) return; + int nx = nrc_x - last_x; + switch (nx) { + case 1: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 2: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 3: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + case 4: mul_mat_f16_f16_NxN(n, cx, bx, last_x, info); break; + } + } +} + +template +IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + assert(n%QF16Base::k_step == 0); + int nb = n/QF16Base::k_step; + QF16<1> y(info); + QF16 x(cx + ix0*bx, bx); + QF16Base::Acc acc[4*nrc_x]; + auto yv = y.loadx(0, 0); + for (int ix = 0; ix < nrc_x; ++ix) { + for (int k = 0; k < 4; ++k) { + auto xv = x.load1(ix, k); + acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv); + } + } + for (int i = 1; i < nb/4; ++i) { + yv = y.loadx(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + for (int k = 0; k < 4; ++k) { + auto xv = x.load1(ix, 4*i+k); + acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv); + } + } + } + for (int i = 4*(nb/4); i < nb; ++i) { + auto yv1 = y.load1(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + auto xv1 = x.load1(ix, i); + acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1); + } + } + if constexpr (!is_multiple_of_k_step) { + int nb4 = n/4; + for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) { + auto yv1 = y.load_tail(0, i); + for (int ix = 0; ix < nrc_x; ++ix) { + auto xv1 = x.load_tail(ix, i); + acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1); + } + } + } + for (int ix = 0; ix < nrc_x; ++ix) { + auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]); + auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]); + info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2))); + } +} + +// At least on my M2-Max the version below, which dows the multiplication row-by-row, is faster. +// But let's keep this version commented out for now. 
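// (The active mul_mat_f16_f16_1 further below simply dispatches the Nx1 kernel
// once per output row, picking the is_multiple_of_k_step instantiation from
// n % k_step, instead of tiling k_nx = 2 rows at a time as attempted here.)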
+//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +// GGML_ASSERT(n%4 == 0); +// constexpr int k_nx = 2; +// const char * cx = (const char *)vx; +// if (n%QF16Base::k_step == 0) { +// for (int ix = 0; ix < nrc_x/k_nx; ++ix) { +// mul_mat_f16_f16_Nx1(n, cx, bx, ix*k_nx, info); +// } +// int last_x = k_nx*(nrc_x/k_nx); +// if (last_x == nrc_x) return; +// int nx = nrc_x - last_x; +// switch (nx) { +// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break; +// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break; +// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break; +// } +// } else { +// for (int ix = 0; ix < nrc_x/k_nx; ++ix) { +// mul_mat_f16_f16_Nx1(n, cx, bx, ix*k_nx, info); +// } +// int last_x = k_nx*(nrc_x/k_nx); +// if (last_x == nrc_x) return; +// int nx = nrc_x - last_x; +// switch (nx) { +// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break; +// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break; +// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break; +// } +// } +//} + +void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%4 == 0); + const char * cx = (const char *)vx; + if (n%QF16Base::k_step == 0) { + for (int ix = 0; ix < nrc_x; ++ix) { + mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info); + } + } else { + for (int ix = 0; ix < nrc_x; ++ix) { + mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info); + } + } +} + +template struct Q8_K64 { + + constexpr static int nrc_y = nrc; + + Q8_K64(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + std::memcpy(d + 4*iy, dptr, 4*sizeof(float)); + y[iy] = (const int8_t *)(dptr + 4); + } + } + + inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); } + inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); } + inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); } + + float d[4*nrc_y]; + const int8_t * y[nrc_y]; +}; + +struct DequantizerIQ1BN { + const uint8x16_t m1 = vdupq_n_u8(1); + + static inline uint8x16x4_t load_shuffles() { + static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12, + 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12, + 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12}; + return vld1q_u8_x4(data); + } + static inline uint8x16x4_t load_mult() { + static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, + 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27, + 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9, + 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3}; + return vld1q_u8_x4(data); + } + const uint8x16x4_t shuff = load_shuffles(); + const uint8x16x4_t mult = load_mult(); + + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const { + auto data = vld1q_u8((const uint8_t *)x); + for (int k = 0; k < 4; ++k) { + auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]); + val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6); + v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1); + } + } +}; + +template +static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_IQ1BN; + + Q8_K64 q8(info); + DequantizerIQ1BN 
deq; + + int32x4_t accd[nrc_y]; + int8x16x4_t v1, v2; + + const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); + + for (int ix = 0; ix < nrc_x; ++ix) { + + x = (const block_iq1_bn *)((const char *)vx + ix*bx); + + if constexpr (nrc_y == 1) { + int32x4_t acc[4] = {}; + for (int i = 0; i < nb/2; ++i) { + deq.prepare_iq1bn_quants(x+2*i+0, v1); + auto q = q8.load_quants64(0, i, 0); + for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]); + deq.prepare_iq1bn_quants(x+2*i+1, v2); + q = q8.load_quants64(0, i, 1); + for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]); + } + accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3])); + } + else { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0); + + for (int i = 0; i < nb/2; ++i) { + + deq.prepare_iq1bn_quants(x+2*i+0, v1); + deq.prepare_iq1bn_quants(x+2*i+1, v2); + + for (int iy = 0; iy < nrc_y; ++iy) { + auto q = q8.load_quants(iy, i, 0); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i, 1); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]); + q = q8.load_quants(iy, i, 2); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]); + q = q8.load_quants(iy, i, 3); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]); + } + } + } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare_iq1bn_quants(x+i, v1); + if constexpr (nrc_y == 1) { + auto q = q8.load_quants(0, i/2, 0); + for (int j = 0; j < 4; ++j) { + accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + auto q = q8.load_quants(iy, i/2, 0); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i/2, 1); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + } + + } +} + +template +static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_IQ1BN; + + Q8_K64 q8(info); + + int32x4_t accd[nrc_y]; + + const auto m1 = vdupq_n_u8(1); + const auto mask2 = vdupq_n_s8(3); + + for (int ix = 0; ix < nrc_x; ++ix) { + + const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx); + + if constexpr (nrc_y == 1) { + int8x16x4_t v1; + int32x4_t acc[4] = {}; + for (int i = 0; i < nb/2; ++i) { + for (int j = 0; j < 2; ++j) { + auto q = q8.load_quants64(0, i, j); + auto q2bits = vld1q_u8(x[2*i+j].qs); + v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]); + acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]); + acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]); + acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]); + } + } + accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3])); + } else { + int8x16x4_t v1, v2; + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0); + for (int i = 0; i < nb/2; ++i) { + auto q2bits = vld1q_u8(x[2*i+0].qs); + v1.val[0] = 
vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + q2bits = vld1q_u8(x[2*i+1].qs); + v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + for (int iy = 0; iy < nrc_y; ++iy) { + auto q = q8.load_quants(iy, i, 0); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i, 1); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]); + q = q8.load_quants(iy, i, 2); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]); + q = q8.load_quants(iy, i, 3); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]); + } + } + } + int i = 2*(nb/2); + if (i < nb) { + auto q2bits = vld1q_u8(x[i].qs); + int8x16x4_t v1; + v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + for (int iy = 0; iy < nrc_y; ++iy) { + auto q = q8.load_quants(iy, i/2, 0); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i/2, 1); + accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + } + } +} + +template void MulMat::set_functions(MulMat& m) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_0_q8_0; + m.funcs[1] = mul_mat_qX_0_q8_0; + m.funcs[2] = mul_mat_qX_0_q8_0; + m.funcs[3] = mul_mat_qX_0_q8_0; + m.funcs[4] = mul_mat_qX_0_q8_0; + m.funcs[5] = mul_mat_qX_0_q8_0; + m.funcs[6] = mul_mat_qX_0_q8_0; + m.funcs[7] = mul_mat_qX_0_q8_0; + } + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_1_q8_1; + m.funcs[1] = mul_mat_qX_1_q8_1; + m.funcs[2] = mul_mat_qX_1_q8_1; + m.funcs[3] = mul_mat_qX_1_q8_1; + m.funcs[4] = mul_mat_qX_1_q8_1; + m.funcs[5] = mul_mat_qX_1_q8_1; + m.funcs[6] = mul_mat_qX_1_q8_1; + m.funcs[7] = mul_mat_qX_1_q8_1; + } + else { + m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; + m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; + m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; + m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; + m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; + m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; + m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; + m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; + } +} + +bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { + + if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) { + if (ne00%4) return false; + for (auto& f : m.funcs) f = nullptr; + m.funcs[0] = mul_mat_f16_f16_1; + m.funcs[1] = mul_mat_f16_f16_T<2>; + m.funcs[2] = mul_mat_f16_f16_T<3>; + m.funcs[3] = mul_mat_f16_f16_T<4>; + m.funcs[4] = mul_mat_f16_f16_T<5>; + return true; + } + + auto expected_Btype = GGML_TYPE_Q8_K; + + switch (typeA) { + case GGML_TYPE_Q2_K: + 
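+// Note on the dispatch being filled in here: funcs[i] holds the kernel
+// instantiated for i+1 right-hand-side columns (the nrc_y template parameter),
+// so the column count is a compile-time constant inside the hot loops, and
+// funcs[7] (8 columns) is the widest variant provided. prepare() additionally
+// records in expected_Btype which quantized activation layout the kernels
+// assume and returns false for any other typeB, so callers can fall back.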
MulMat::set_functions(m); + break; + case GGML_TYPE_IQ2_TN: + //MulMat::set_functions(m); + m.funcs[0] = mul_mat_iq2tn_K_q8_K_1; + m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>; + m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>; + m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>; + m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>; + m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>; + m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>; + m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>; + break; + case GGML_TYPE_Q3_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q4_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q5_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q6_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ4_XS: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ4_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ5_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ6_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ2_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ3_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ2_XXS: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ2_XS: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ2_S: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ3_XXS: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ3_S: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ1_BN: + m.funcs[0] = mul_mat_iq1bn_q8_K64<1>; + m.funcs[1] = mul_mat_iq1bn_q8_K64<2>; + m.funcs[2] = mul_mat_iq1bn_q8_K64<3>; + m.funcs[3] = mul_mat_iq1bn_q8_K64<4>; + m.funcs[4] = mul_mat_iq1bn_q8_K64<5>; + m.funcs[5] = mul_mat_iq1bn_q8_K64<6>; + m.funcs[6] = mul_mat_iq1bn_q8_K64<7>; + m.funcs[7] = mul_mat_iq1bn_q8_K64<8>; + expected_Btype = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_IQ2_BN: + m.funcs[0] = mul_mat_iq2bn_q8_K64<1>; + m.funcs[1] = mul_mat_iq2bn_q8_K64<2>; + m.funcs[2] = mul_mat_iq2bn_q8_K64<3>; + m.funcs[3] = mul_mat_iq2bn_q8_K64<4>; + m.funcs[4] = mul_mat_iq2bn_q8_K64<5>; + m.funcs[5] = mul_mat_iq2bn_q8_K64<6>; + m.funcs[6] = mul_mat_iq2bn_q8_K64<7>; + m.funcs[7] = mul_mat_iq2bn_q8_K64<8>; + expected_Btype = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_Q4_0: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_0; + break; + case GGML_TYPE_Q4_1: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_1; + break; + case GGML_TYPE_Q5_0: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_0; + break; + case GGML_TYPE_Q5_1: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_1; + break; + case GGML_TYPE_Q8_0: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_0; + break; + case GGML_TYPE_IQ4_NL: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_0; + break; + default: + return false; + } + + return typeB == expected_Btype; +} + +} + +#endif // __aarch64__ + +#else // IQK_IMPLEMENT + +bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { + return false; +} + +bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, + const void *, int, int) { + return false; +} + +#endif diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h new file mode 100644 index 0000000000000..6bed5f5afd030 --- /dev/null +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -0,0 +1,27 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#pragma once +#include +#include +#ifdef __cplusplus +extern "C" { +#endif + +bool iqk_mul_mat(long Nx, long Ny, long ne00, + 
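+// The two entry points declared here form the C API of this header. A hedged
+// usage sketch, mirroring the single-row calls made from the vec_dot wrappers
+// in iqk_quantize.cpp below (after the matrix pointers come the strides, then
+// the thread index and thread count):
+//     float result;
+//     if (!iqk_mul_mat(/*Nx*/1, /*Ny*/1, /*ne00*/n, GGML_TYPE_IQ2_K, vx, 0,
+//                      GGML_TYPE_Q8_K, vy, 0, &result, 0, /*ith*/0, /*nth*/1)) {
+//         /* unsupported combination: fall back to the scalar path */
+//     }
+// iqk_mul_mat_moe takes the same operands plus nb1/nb2 output strides and a row
+// mapping table, which suggests it is the indirect variant used for MoE experts.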
int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long stride_C, int ith, int nth); + +bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); + + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp new file mode 100644 index 0000000000000..730de8c9f4df4 --- /dev/null +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -0,0 +1,1984 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#if GGML_USE_IQK_MULMAT +#include "iqk_mul_mat.h" +#endif +#include "ggml-quants.h" +#include "ggml-impl.h" +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "iqk_quantize.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +struct IQ1BNQuantizer { + int8_t L[QK_IQ1BN]; + void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix); + void quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix); + static inline float row_max(int n_per_row, const float * src) { + float max_in_row = 0; + for (int j = 0; j < n_per_row; ++j) { + float ax = fabsf(src[j]); + max_in_row = std::max(max_in_row, ax); + } + return max_in_row; + } + // The Makefile has issues dwaling with this? + //static constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + static const uint8_t k_mult[5]; +}; + +const uint8_t IQ1BNQuantizer::k_mult[5] = {81, 27, 9, 3, 1}; + +void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) { + + static const int k_nb[6] = {1, 3, 9, 27, 81, 243}; + (void)imatrix; + + const int nblock = n_per_row/QK_IQ1BN; + + for (int ib = 0; ib < nblock; ++ib) { + std::memset(&y[ib], 0, sizeof(block_iq1_bn)); + auto xb = src + ib*QK_IQ1BN; + int v13 = 0; + for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) { + for (int k = 0; k < 3; ++k) { + int idx = 0; + for (int j = 0; j < 5; ++j) { + float v = xb[16*i16 + 5*k + j]; + int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2; + idx += k_nb[j]*q; + } + idx = (256*idx + k_nb[5] - 1)/k_nb[5]; + y[ib].ql[3*i16 + k] = idx; + } + float v = xb[16*i16 + 15]; + int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2; + v13 += k_nb[i16]*q; + } + y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5]; + } +} + +void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) { + + (void)imatrix; + + const int nblock = n_per_row/QK_IQ1BN; + + constexpr int Nj = QK_IQ1BN/4; + + for (int ib = 0; ib < nblock; ++ib) { + auto xb = src + QK_IQ1BN*ib; + for (int j = 0; j < QK_IQ1BN; ++j) { + L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 
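+// Worked restatement of the packing in quantize_one_row_1bn above: five ternary
+// digits d0..d4 in {0,1,2} become one byte via
+//     t = d0 + 3*d1 + 9*d2 + 27*d3 + 81*d4;   // 0..242
+//     q = (256*t + 242)/243;                   // ceil(256*t/243), fits in a byte
+// and the leftover 16th value of each group of 16 is collected into the block's
+// `extra` byte with the same encoding; the multiply-by-k_mult decode further down
+// is the exact inverse of this map.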
0 : 2; + } + for (int j = 0; j < Nj; ++j) { + y[ib].qs[j] = L[j] | (L[j + Nj] << 2) | (L[j + 2*Nj] << 4) | (L[j + 3*Nj] << 6); + } + } +} + +} + +size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + IQ1BNQuantizer iq1bn; + int nblock = n_per_row/QK_IQ1BN; + block_iq1_bn * y = (block_iq1_bn *)dst; + for (int row = 0; row < nrows; ++row) { + iq1bn.quantize_one_row_1bn(src + row*n_per_row, y, n_per_row, imatrix); + y += nblock; + } + return sizeof(block_iq1_bn)*nblock*nrows; +} + +void quantize_row_iq1_bn_ref(const float * x, block_iq1_bn * y, int64_t k) { + quantize_iq1_bn(x, y, 1, k, nullptr); +} + +void quantize_row_iq1_bn(const float * x, void * y, int64_t k) { + quantize_iq1_bn(x, y, 1, k, nullptr); +} + +void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) { + assert(k%QK_IQ1BN == 0); + int nblock = k / QK_IQ1BN; + + for (int i = 0; i < nblock; ++i) { + uint8_t extra = x[i].extra; + auto ql = x[i].ql; + for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) { + for (int k = 0; k < 3; ++k) { + for (int j = 0; j < 5; ++j) { + uint8_t v = ql[k]*IQ1BNQuantizer::k_mult[j]; + int8_t vs = ((v + (v >> 1)) >> 7); + *y++ = vs - 1; + } + } + ql += 3; + uint8_t v = extra*IQ1BNQuantizer::k_mult[i16]; + int8_t vs = ((v + (v >> 1)) >> 7); + *y++ = vs - 1; + } + } +} + +size_t quantize_iq2_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + IQ1BNQuantizer iq1bn; + int nblock = n_per_row/QK_IQ1BN; + block_iq2_bn * y = (block_iq2_bn *)dst; + for (int row = 0; row < nrows; ++row) { + iq1bn.quantize_one_row_2bn(src + row*n_per_row, y, n_per_row, imatrix); + y += nblock; + } + return sizeof(block_iq2_bn)*nblock*nrows; +} + +void quantize_row_iq2_bn_ref(const float * x, block_iq2_bn * y, int64_t k) { + quantize_iq2_bn(x, y, 1, k, nullptr); +} + +void quantize_row_iq2_bn(const float * x, void * y, int64_t k) { + quantize_iq2_bn(x, y, 1, k, nullptr); +} + +void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) { + assert(k%QK_IQ1BN == 0); + int nblock = k / QK_IQ1BN; + + auto d1 = 1.f, d2 = 0.25f, d3 = d2*0.25f, d4 = d3*0.25f; + auto m = -1.f; + constexpr int Nj = QK_IQ1BN/4; + for (int i = 0; i < nblock; ++i) { + for (int j = 0; j < Nj; ++j) { + y[j+ 0] = d1*(x[i].qs[j] & 0x03) + m; + y[j+1*Nj] = d2*(x[i].qs[j] & 0x0c) + m; + y[j+2*Nj] = d3*(x[i].qs[j] & 0x30) + m; + y[j+3*Nj] = d4*(x[i].qs[j] & 0xc0) + m; + } + y += QK_IQ1BN; + } +} + +namespace { +inline int8_t iq1bn_dequant(uint8_t q, int i) { + uint8_t v = IQ1BNQuantizer::k_mult[i]*q; + //int8_t vs = (v + (v << 1)) >> 8; + int8_t vs = 3*v >> 8; + return vs - 1; +} +} + +static const int8_t iq1bn_values[1280] = { + -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0, + -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1, + -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1, + -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1, + -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1, + -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1, + 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0, + -1, -1, 1, 0, 0, -1, 0, 1, 
0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1, + 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1, + 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1, + -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0, + 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1, + -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1, + 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1, + -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0, + 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1, + 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1, + 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0, + -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1, + 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1, + 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0, + 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0, + 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0, + 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0, + 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0, + 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0, + -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1, + 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1, + 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0, + 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1, + -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1, + 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1, + -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1, + 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, + 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1, + 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, +}; + +void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, 
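+// Layout assumed for the Q8_K64 activation row used by this dot product and
+// produced by quantize_row_q8_K64 below: four float scales followed by the int8
+// quants, where scale i covers the lanes whose index modulo 16 lies in
+// [4*i, 4*i+3]; that is why the final result combines d8[0..3] with the
+// corresponding pairs of sumi[] accumulators.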
size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(nrc); + + static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64"); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + + const block_iq1_bn * x = (const block_iq1_bn *)vx; + + const float * d8 = (const float *)vy; + const int8_t * q8 = (const int8_t *)(d8 + 4); + int nblock = n / QK_IQ1BN; + + int sumi[8] = {}; + int8_t q1[16]; + + for (int ii = 0; ii < nblock; ii += 32) { + int16_t sum16[8] = {}; + int nb = std::min(ii + 32, nblock); + for (int i = ii; i < nb; ++i) { + auto ql = x[i].ql; + const int8_t * extra = iq1bn_values + 5*x[i].extra; + for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) { + for (int k = 0; k < 3; ++k) { + uint8_t q = *ql++; + const int8_t * vs = iq1bn_values + 5*q; + for (int j = 0; j < 5; ++j) q1[5*k+j] = vs[j]; + } + q1[15] = extra[i16]; + // We collect 8 q8 values per block into each element of sum16 + // => 32 x 8 = 256 values in each loop over i, so this cannot overflow the int16_t range + // (q8 is in -127...127, and hence the sum is in -32512...32512 + for (int j = 0; j < 8; ++j) sum16[j] += q8[2*j+0]*q1[2*j+0] + q8[2*j+1]*q1[2*j+1]; + q8 += 16; + } + } + for (int j = 0; j < 8; ++j) sumi[j] += sum16[j]; + } + + *s = d8[0] * (sumi[0] + sumi[1]) + d8[1] * (sumi[2] + sumi[3]) + d8[2] * (sumi[4] + sumi[5]) + d8[3] * (sumi[6] + sumi[7]); +} + +void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(nrc); + + static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64"); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) { + return; + } + + constexpr int Nj = QK_IQ1BN/4; + + const block_iq2_bn * x = (const block_iq2_bn *)vx; + int nblock = n / QK_IQ1BN; + + const float * d = (const float *)vy; + const int8_t * q8 = (const int8_t *)(d + 4); + + int sum[16] = { }; + int sum0[4] = { }; + + for (int i = 0; i < nblock; ++i) { + for (int j = 0; j < Nj/4; ++j) { + for (int l = 0; l < 4; ++l) { + sum[4*j + 0] += q8[4*j + l + 0] * (x[i].qs[4*j+l] & 0x03); + sum[4*j + 1] += q8[4*j + l + 1*Nj] * (x[i].qs[4*j+l] & 0x0c); + sum[4*j + 2] += q8[4*j + l + 2*Nj] * (x[i].qs[4*j+l] & 0x30); + sum[4*j + 3] += q8[4*j + l + 3*Nj] * (x[i].qs[4*j+l] & 0xc0); + sum0[j] += q8[4*j + l] + q8[4*j + l + 1*Nj] + q8[4*j + l + 2*Nj] + q8[4*j + l + 3*Nj]; + } + } + q8 += QK_IQ1BN; + } + + float sumf = 0; + for (int j = 0; j < 4; ++j) { + sumf += d[j] * (sum[4*j + 0] + 0.25f*sum[4*j + 1] + 0.0625*sum[4*j + 2] + 0.015625*sum[4*j + 3] - sum0[j]); + } + *s = sumf; + +} + +void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) { + + float * dptr = (float *)y; + auto qs = (int8_t *)(dptr + 4); +#ifdef __ARM_NEON + static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + auto shuffle = vld1q_u8(k_shuffle); + float32x4_t max[4] = { }; + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + auto val = vld1q_f32(x + j + 4*i); + val = vabsq_f32(val); + max[i] = vmaxq_f32(max[i], val); + } + } + float32x4_t vid[4]; + for (int i = 0; i < 4; ++i) { + dptr[i] = vmaxvq_f32(max[i])/127; + float id = dptr[i] > 0 
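+// The NEON, AVX and scalar paths of this routine all implement the same
+// per-group absmax quantization; a minimal sketch for one group of values
+// (hypothetical helper, not part of the patch):
+//   static inline void q8_absmax_quantize(const float * x, int n, float * scale, int8_t * q) {
+//       float amax = 0.f;
+//       for (int j = 0; j < n; ++j) amax = fmaxf(amax, fabsf(x[j]));
+//       *scale = amax/127;                          // stored so the dot product can rescale
+//       float id = amax > 0 ? 127/amax : 0.f;       // inverse scale used for rounding
+//       for (int j = 0; j < n; ++j) q[j] = (int8_t)nearbyintf(id*x[j]);
+//   }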
? 1/dptr[i] : 0.f; + vid[i] = vdupq_n_f32(id); + } + int8x16x4_t q; + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + auto val = vld1q_f32(x + j + 4*i); + val = vmulq_f32(vid[i], val); + q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val)); + } + auto qi = vqtbl4q_s8(q, shuffle); + vst1q_s8(qs, qi); + qs += 16; + } +#elif defined __AVX__ + __m128 max[4] = {}; + __m128 sign_bit = _mm_set1_ps(-0.f); + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + auto val = _mm_loadu_ps(x + j + 4*i); + val = _mm_andnot_ps(sign_bit, val); + max[i] = _mm_max_ps(max[i], val); + } + } + __m128 vid[4]; + for (int i = 0; i < 4; ++i) { + max[i] = _mm_max_ps(max[i], _mm_movehl_ps(max[i], max[i])); + max[i] = _mm_max_ss(max[i], _mm_movehdup_ps(max[i])); + float maxi = _mm_cvtss_f32(max[i]); + dptr[i] = maxi/127; + float id = dptr[i] > 0 ? 1/dptr[i] : 0.f; + vid[i] = _mm_set1_ps(id); + } + __m128i q[4]; + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + auto val = _mm_loadu_ps(x + j + 4*i); + val = _mm_round_ps(_mm_mul_ps(vid[i], val), _MM_ROUND_NEAREST); + q[i] = _mm_cvtps_epi32(val); + } + auto q1 = _mm_packs_epi32(q[0], q[1]); + auto q2 = _mm_packs_epi32(q[2], q[3]); + auto qi = _mm_packs_epi16(q1, q2); + _mm_storeu_si128((__m128i *)qs, qi); + qs += 16; + } +#else + float aux[4] = {0.f, 0.f, 0.f, 0.f}; + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + for (int l = 0; l < 4; ++l) { + float ax = fabsf(x[j+4*i+l]); + aux[i] = std::max(aux[i], ax); + } + } + } + for (int i = 0; i < 4; ++i) { + dptr[i] = aux[i]/127; + aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f; + } + for (int j = 0; j < k; j += 16) { + for (int i = 0; i < 4; ++i) { + for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]); + } + } +#endif +} + +void quantize_row_q8_K64(const float * x, void * y, int64_t k) { + quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k); +} + +// +// ============================================== iq2_K +// + +namespace { + +inline int best_index_iq2nl(const int8_t * values, float x) { + int idx = x < values[1] ? 0 : x > values[2] ? 2 : 1; + return x - values[idx] < values[idx+1] - x ? 
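+// best_index_iq2nl here picks the nearest of the four nonlinear codebook values;
+// each block can use either the base table iq2nl_values or the shifted table at
+// iq2nl_values + 4, and the per-block choice is recorded as one bit of `extra`
+// (tested as extra & (1 << ib) during quantization and as extra & 1 / extra & 2
+// when dequantizing two 16-value halves at a time).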
idx : idx + 1; +} + +void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { + + constexpr int kBlockSize = 16; + + block_iq2_k * y = (block_iq2_k *)vy; + + float scales[QK_K/kBlockSize]; + float weight[kBlockSize]; + float sumx[kBlockSize+1], sumw[kBlockSize+1]; + + std::array, kBlockSize> pairs; + + const int8_t * shifted_values = iq2nl_values + 4; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 1.5f*sumx2/QK_K; + + uint16_t extra = 0; + + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + for (int j = 0; j < kBlockSize; ++j) pairs[j] = {xb[j], j}; + std::sort(pairs.begin(), pairs.end()); + sumx[0] = sumw[0] = 0; + for (int j = 0; j < kBlockSize; ++j) { + int jj = pairs[j].second; + sumw[j+1] = sumw[j] + weight[jj]; + sumx[j+1] = sumx[j] + weight[jj]*xb[jj]; + } + float best = 0, d = 0; + bool is_shifted = false; + float sumqx, sumq2; + for (int i1 = 0; i1 < kBlockSize; ++i1) { + for (int i2 = i1; i2 < kBlockSize; ++i2) { + for (int i3 = i2; i3 < kBlockSize; ++i3) { + sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[0] + (sumx[i2] - sumx[i1])*iq2nl_values[1] + + (sumx[i3] - sumx[i2])*iq2nl_values[2] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[3]; + sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[0]*iq2nl_values[0] + (sumw[i2] - sumw[i1])*iq2nl_values[1]*iq2nl_values[1] + + (sumw[i3] - sumw[i2])*iq2nl_values[2]*iq2nl_values[2] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[3]*iq2nl_values[3]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = false; + } + sumqx = (sumx[i1] - sumx[ 0])*shifted_values[0] + (sumx[i2] - sumx[i1])*shifted_values[1] + + (sumx[i3] - sumx[i2])*shifted_values[2] + (sumx[kBlockSize] - sumx[i3])*shifted_values[3]; + sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[0]*shifted_values[0] + (sumw[i2] - sumw[i1])*shifted_values[1]*shifted_values[1] + + (sumw[i3] - sumw[i2])*shifted_values[2]*shifted_values[2] + (sumw[kBlockSize] - sumw[i3])*shifted_values[3]*shifted_values[3]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = true; + } + sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[3] + (sumx[i2] - sumx[i1])*iq2nl_values[2] + + (sumx[i3] - sumx[i2])*iq2nl_values[1] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[0]; + sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[3]*iq2nl_values[3] + (sumw[i2] - sumw[i1])*iq2nl_values[2]*iq2nl_values[2] + + (sumw[i3] - sumw[i2])*iq2nl_values[1]*iq2nl_values[1] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[0]*iq2nl_values[0]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = false; + } + sumqx = (sumx[i1] - sumx[ 0])*shifted_values[3] + (sumx[i2] - sumx[i1])*shifted_values[2] + + (sumx[i3] - sumx[i2])*shifted_values[1] + (sumx[kBlockSize] - sumx[i3])*shifted_values[0]; + sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[3]*shifted_values[3] + (sumw[i2] - sumw[i1])*shifted_values[2]*shifted_values[2] + + (sumw[i3] - 
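+// What this (i1,i2,i3) triple loop computes, spelled out: for a candidate split
+// of the sorted block values into four runs mapped to codebook entries v0..v3,
+// the best block scale under weights w is the weighted least-squares fit
+//     d = sum_j w_j*v(j)*x_j / sum_j w_j*v(j)^2 = sumqx/sumq2,
+// whose error reduction is d*sumqx, so maximizing sumqx^2/sumq2 over all
+// i1 <= i2 <= i3 picks the best assignment; the prefix sums sumx/sumw make each
+// candidate O(1) instead of O(block size), and the four evaluations per candidate
+// cover the base and shifted codebooks in both ascending and descending order.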
sumw[i2])*shifted_values[1]*shifted_values[1] + (sumw[kBlockSize] - sumw[i3])*shifted_values[0]*shifted_values[0]; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; is_shifted = true; + } + } + } + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + + float abs_scale = fabsf(scales[ib]); + max_abs_scale = MAX(max_abs_scale, abs_scale); + } + + if (!max_abs_scale) continue; + + float d = max_abs_scale/15; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + int ls = nearest_int(0.5f*(id*scales[ib]+15)); + ls = MAX(0, MIN(15, ls)); + y[ibl].scales[ib/2] |= (ls << 4*(ib%2)); + ls = 2*ls - 15; + float dl = d * ls; + if (dl) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values; + const float * xb = xbl + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/4) + offset; + for (int j = 0; j < 16; ++j) { + const float al = idl*xb[j]; + int ibest = best_index_iq2nl(block_values, al); + qs[j] |= (ibest << 2*(ib32%4)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + } +} +} + +void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_k * y = (block_iq2_k *)vy; + quantize_row_iq2_k_ref(x, y, k); +} + +size_t quantize_iq2_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq2_k_impl(src, (void *)qrow, n_per_row, imatrix); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_k); + } + return nrows * nblock * sizeof(block_iq2_k); +} + +void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + + uint16_t extra = x[i].extra; + + int shift = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + float dl1 = d * (2*(x[i].scales[ib32] & 0xf) - 15); + float dl2 = d * (2*(x[i].scales[ib32] >> 4) - 15); + const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values; + const int8_t * values2 = extra & 2 ? 
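+// Scale packing for iq2_k, with a worked example: block scales are stored as
+// 4-bit codes ls in 0..15 (ls = nearest_int(0.5*(id*scale + 15)) above) and decode
+// to the odd values 2*ls - 15 as in the dl1/dl2 lines here, with the super-block
+// float d = max|scale|/15. For instance scale = -7*d gives
+// ls = nearest_int(0.5*(-7 + 15)) = 4, which decodes back to (2*4 - 15)*d = -7*d.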
iq2nl_values + 4 : iq2nl_values; + extra >>= 2; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[(qs[j+ 0] >> shift) & 3]; + y[j+16] = dl2 * values2[(qs[j+16] >> shift) & 3]; + } + y += 32; + shift += 2; + if (shift == 8) { qs += 32; shift = 0; } + } + + } + +} + +void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq2_k * x = (const block_iq2_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; +} + +// +// ============================================== iq3_k +// +namespace { +const int8_t iq3nl_index[111] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, + 9, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 10, 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 11, 11, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 12, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 13, 13, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 14, 14, 7, 7, 7, 7, 7, 7, 7, 7, 7 +}; +inline int best_index_iq3nl(const int8_t * values, float x) { + int ix = (int)x - values[0]; + if (ix < 0 || ix >= 111) return ix < 0 ? 0 : 7; + ix = iq3nl_index[ix]; + return ix < 8 ? ix : x - values[ix-8] < values[ix-7] - x ? ix-8 : ix-7; +} + +static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { + + const int ntry = 5; + + block_iq3_k * y = (block_iq3_k *)vy; + + float scales[QK_K/16]; + float weight[16]; + + const int8_t * shifted_values = iq3nl_values + 8; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq3_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = sumx2/QK_K; + + uint16_t extra = 0; + + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < 16; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? 
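+// The iq3nl/iq4nl/iq5nl/iq6nl index tables in this file all follow the same
+// pattern: entry k holds the codebook index nearest to values[0] + k, except near
+// decision boundaries, where the stored value is N + left_index and the two
+// neighbouring codebook entries are compared explicitly. A generic sketch of that
+// lookup (hypothetical helper; the real code keeps one specialized copy per codebook):
+//   static inline int nearest_code_index(const int8_t * values, const int8_t * tab,
+//                                        int range, int N, float x) {
+//       int ix = (int)x - values[0];
+//       if (ix < 0 || ix >= range) return ix < 0 ? 0 : N - 1;
+//       ix = tab[ix];
+//       if (ix < N) return ix;                 // unambiguous bucket
+//       ix -= N;                               // boundary: pick the closer neighbour
+//       return x - values[ix] < values[ix+1] - x ? ix : ix + 1;
+//   }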
-max/iq3nl_values[0] : max/iq3nl_values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(iq3nl_values, al); + float q = iq3nl_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(iq3nl_values, -al); + q = iq3nl_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + bool is_shifted = false; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + iq3nl_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(iq3nl_values, al); + float q = iq3nl_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(iq3nl_values, -al); + q = iq3nl_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq3nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (d) { + const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(block_values, al); + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + } + scales[ib] = d; + + if (is_shifted) extra |= (1 << ib); + + float abs_scale = fabsf(scales[ib]); + max_abs_scale = MAX(max_abs_scale, abs_scale); + } + + if (!max_abs_scale) continue; + + float d = max_abs_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/16; ++ib) { + int ls = nearest_int(0.5f*(id*fabsf(scales[ib])-1)); + ls = MAX(0, MIN(15, ls)); + y[ibl].scales_l[ib/2] |= (ls << 4*(ib%2)); + if (scales[ib] < 0) y[ibl].scales_h |= (1 << ib); + ls = (2*ls + 1) * (scales[ib] < 0 ? -1 : 1); + float dl = d * ls; + if (dl) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? 
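+// The itry loop above is a small 1-D search over candidate block scales: each
+// candidate inverse scale id = (itry + values[0])/max (times a step factor in the
+// 5- and 6-bit variants) quantizes the block with both the base and the shifted
+// codebook and with both signs of the input (the _p/_m accumulators), keeps
+// whichever maximizes sumqx^2/sumq2 (the weighted least-squares gain), refines
+// d = sumqx/sumq2 once at the end, and records in `extra` whether the shifted
+// codebook won for this block. The iq4_k, iq5_k and iq6_k quantizers further down
+// repeat the same pattern with their own codebooks.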
shifted_values : iq3nl_values; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/4) + offset; + uint8_t * qh = y[ibl].qh + 32*(ib32/8) + offset; + for (int j = 0; j < 16; ++j) { + const float al = idl*xb[j]; + int ibest = best_index_iq3nl(block_values, al); + qs[j] |= ((ibest & 3) << 2*(ib32%4)); + qh[j] |= ((ibest >> 2) << (ib32%8)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + } +} + +} + +void quantize_row_iq3_k_ref(const float * x, block_iq3_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq3_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_k * y = (block_iq3_k *)vy; + quantize_row_iq3_k_ref(x, y, k); +} + +size_t quantize_iq3_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq3_k_impl(src, (void *)qrow, n_per_row, imatrix); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_k); + } + return nrows * nblock * sizeof(block_iq3_k); +} + +void dequantize_row_iq3_k(const block_iq3_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + + uint16_t sh = x[i].scales_h; + uint16_t extra = x[i].extra; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + float dl1 = d * ((2*(x[i].scales_l[ib32] & 0xf) + 1) * ((sh & 1) ? -1 : 1)); + float dl2 = d * ((2*(x[i].scales_l[ib32] >> 4) + 1) * ((sh & 2) ? -1 : 1)); + sh >>= 2; + const int8_t * values1 = extra & 1 ? iq3nl_values + 8 : iq3nl_values; + const int8_t * values2 = extra & 2 ? 
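+// iq3_k block scales, as encoded above and decoded here: a 4-bit magnitude code
+// in scales_l plus one sign bit per block in scales_h, reconstructed as
+// (2*code + 1) * (sign ? -1 : 1), i.e. odd magnitudes 1..31 scaled by the
+// super-block d = max|scale|/31.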
iq3nl_values + 8 : iq3nl_values; + extra >>= 2; + int shift_l = 2*(ib32%4); + int shift_h = ib32%8; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[((qs[j+ 0] >> shift_l) & 3) | (((qh[j+ 0] >> shift_h) & 1) << 2)]; + y[j+16] = dl2 * values2[((qs[j+16] >> shift_l) & 3) | (((qh[j+16] >> shift_h) & 1) << 2)]; + } + y += 32; + if (shift_l == 6) qs += 32; + } + + } +} + +void vec_dot_iq3_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq2_k * x = (const block_iq2_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; +} + +// +// ============================================== iq4_K +// +void dequantize_row_iq4_k(const block_iq4_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + + uint16_t extra = x[i].extra; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2); + const float dl1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32); + const float dl2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); + const int8_t * values1 = extra & 1 ? iq4k_values + 16 : iq4k_values; + const int8_t * values2 = extra & 2 ? iq4k_values + 16 : iq4k_values; + extra >>= 2; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[qs[j] & 0xf]; + y[j+16] = dl2 * values2[qs[j] >> 4]; + } + y += 32; + qs += 16; + } + } +} + +void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq4_k * x = (const block_iq4_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t extra = x[ibl].extra; + uint32_t h = *((const uint32_t *)x[ibl].scales_h); + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + int32_t sum = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + const int8_t * values1 = iq4k_values + 16*(extra & 1); + const int8_t * values2 = iq4k_values + 8*(extra & 2); + extra >>= 2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * values1[qs[j] & 0xf]; + sumi2 += q8[j+16] * values2[qs[j] >> 4]; + } + sum += ls1*sumi1 + ls2*sumi2; + qs += 16; + q8 += 32; + } + sumf += d4d8 * sum; + } + *s = sumf; + +} + +namespace { +const int8_t iq4nl_index[241] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 
6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15 +}; +inline int best_index_iq4nl(const int8_t * values, float x) { + int ix = (int)x - values[0]; + if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15; + ix = iq4nl_index[ix]; + return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15; +} + +static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int block_size, const float * x, + block_iq4_k * y, + float * scales, float * weight, uint8_t * L, + const int8_t * values, + const float * quant_weights, + const int ntry) { + + GGML_ASSERT(super_block_size == 256 && block_size == 16); + + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/super_block_size; + + memset(y, 0, sizeof(block_iq4_k)); + y->d = GGML_FP32_TO_FP16(0.f); + + uint16_t * scales_h = (uint16_t *)y->scales_h; + + const int8_t * shifted_values = values + 16; + + float max_scale = 0, amax_scale = 0; + uint16_t extra = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = x + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? 
-max/values[0] : max/values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + bool is_shifted = false; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (is_shifted) extra |= (1 << ib); + scales[ib] = d; + float abs_d = fabsf(d); + if (abs_d > amax_scale) { + amax_scale = abs_d; max_scale = d; + } + } + float d = -max_scale/32; + y->d = GGML_FP32_TO_FP16(d); + y->extra = extra; + float id = d ? 1/d : 0.f; + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const int8_t * block_values = extra & (1 << ib) ? shifted_values : values; + int l = nearest_int(id*scales[ib]); + l = MAX(-32, MIN(31, l)); + float dl = d * l; + float idl = dl ? 
1/dl : 0.f; + uint8_t * Lb = L + ib*block_size; + const float * xb = x + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + for (int j = 0; j < block_size; ++j) { + Lb[j] = best_index_iq4nl(block_values, idl*xb[j]); + float w = weight[j]; + float q = block_values[Lb[j]]*l; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + l += 32; + uint8_t l_l = l & 0xf; + uint8_t l_h = l >> 4; + if (ib%2 == 0) y->scales_l[ib/2] = l_l; + else y->scales_l[ib/2] |= (l_l << 4); + scales_h[ib/8] |= (l_h << 2*(ib%8)); + } + if (sumq2 > 0) y->d = GGML_FP32_TO_FP16(sumqx/sumq2); + + for (int i = 0; i < super_block_size/32; ++i) { + for (int j = 0; j < 16; ++j) { + y->qs[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); + } + } +} + +} + +void quantize_row_iq4_k_ref(const float * x, block_iq4_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq4_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_k * y = (block_iq4_k *)vy; + quantize_row_iq4_k_ref(x, y, k); +} + +size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + uint8_t L[QK_K]; + float weight[16]; + float scales[QK_K/16]; + for (int64_t row = 0; row < nrows; ++row) { + block_iq4_k * iq4 = (block_iq4_k *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = imatrix ? imatrix + QK_K*ibl : NULL; + quantize_row_iq4_k_impl_bs16(QK_K, 16, src + QK_K*ibl, iq4 + ibl, + scales, weight, L, iq4k_values, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_k); + } + return nrows * nblock * sizeof(block_iq4_k); +} + +// +// ============================================== iq5_K +// +void dequantize_row_iq5_k(const block_iq5_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * sl = x[i].scales_l; + const uint8_t * sh = x[i].scales_h; + + uint16_t extra = x[i].extra; + + int shift = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + + float dl1 = d * (((sl[2*ib64+0] & 0xf) | ((sh[ib64] << 4) & 0x30)) - 32); + float dl2 = d * (((sl[2*ib64+0] >> 4) | ((sh[ib64] << 2) & 0x30)) - 32); + float dl3 = d * (((sl[2*ib64+1] & 0xf) | ((sh[ib64] >> 0) & 0x30)) - 32); + float dl4 = d * (((sl[2*ib64+1] >> 4) | ((sh[ib64] >> 2) & 0x30)) - 32); + const int8_t * values1 = iq5nl_values + ((extra & 1) << 5); + const int8_t * values2 = iq5nl_values + ((extra & 2) << 4); + const int8_t * values3 = iq5nl_values + ((extra & 4) << 3); + const int8_t * values4 = iq5nl_values + ((extra & 8) << 2); + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * values1[(qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4)]; + y[j+16] = dl2 * values2[(qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4)]; + y[j+32] = dl3 * values3[(qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3)]; + y[j+48] = dl4 * values4[(qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3)]; + } + y += 64; + qs += 32; + extra >>= 4; + shift += 2; + if (shift == 8) { qh += 32; shift = 0; } + } + + } +} + +void vec_dot_iq5_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, 
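+// Common packing in iq4_k and iq5_k, as written above and unpacked here: block
+// scales are 6-bit signed values in [-32,31], stored offset by +32 with the low
+// 4 bits in scales_l (two per byte) and the high 2 bits in scales_h, so decode is
+// (low | (high << 4)) - 32. iq5_k codes themselves are 5 bits: the low nibble
+// lives in qs (two codes per byte) and the fifth bit is gathered in qh, extracted
+// with (qh >> shift) & 1 for the low-nibble half and & 2 for the high-nibble half.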
const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ5_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq5_k * x = (const block_iq5_k *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + float sumf = 0; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * sl = x[i].scales_l; + const uint8_t * sh = x[i].scales_h; + const int8_t * q8 = y[i].qs; + + uint16_t extra = x[i].extra; + + int shift = 0; + int sumb = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + + int dl1 = (((sl[2*ib64+0] & 0xf) | ((sh[ib64] << 4) & 0x30)) - 32); + int dl2 = (((sl[2*ib64+0] >> 4) | ((sh[ib64] << 2) & 0x30)) - 32); + int dl3 = (((sl[2*ib64+1] & 0xf) | ((sh[ib64] >> 0) & 0x30)) - 32); + int dl4 = (((sl[2*ib64+1] >> 4) | ((sh[ib64] >> 2) & 0x30)) - 32); + const int8_t * values1 = iq5nl_values + ((extra & 1) << 5); + const int8_t * values2 = iq5nl_values + ((extra & 2) << 4); + const int8_t * values3 = iq5nl_values + ((extra & 4) << 3); + const int8_t * values4 = iq5nl_values + ((extra & 8) << 2); + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * values1[(qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4)]; + sumi2 += q8[j+16] * values2[(qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4)]; + sumi3 += q8[j+32] * values3[(qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3)]; + sumi4 += q8[j+48] * values4[(qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3)]; + } + sumb += dl1 * sumi1 + dl2 * sumi2 + dl3 * sumi3 + dl4 * sumi4; + q8 += 64; + qs += 32; + extra >>= 4; + shift += 2; + } + sumf += d * sumb; + + } + + *s = sumf; + +} + +namespace { +const int8_t iq5nl_index[248] = { + 0, 0, 0, 0, 0, 0, 32, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 33, 33, 2, 2, 2, 2, 2, 2, 2, 2, 2, 34, 34, 3, 3, + 3, 3, 3, 3, 3, 3, 35, 35, 4, 4, 4, 4, 4, 4, 4, 36, 36, 5, 5, 5, 5, 5, 5, 5, 37, 37, 6, 6, 6, 6, 6, 6, + 6, 38, 7, 7, 7, 7, 7, 7, 39, 39, 8, 8, 8, 8, 8, 40, 40, 9, 9, 9, 9, 9, 41, 41, 10, 10, 10, 10, 10, 42, 11, 11, + 11, 11, 11, 43, 12, 12, 12, 12, 12, 44, 13, 13, 13, 13, 13, 45, 14, 14, 14, 14, 14, 46, 15, 15, 15, 15, 47, 47, 16, 16, 16, 16, + 48, 17, 17, 17, 17, 17, 49, 18, 18, 18, 18, 18, 50, 19, 19, 19, 19, 19, 51, 20, 20, 20, 20, 20, 52, 21, 21, 21, 21, 21, 53, 53, + 22, 22, 22, 22, 22, 54, 54, 23, 23, 23, 23, 23, 23, 55, 24, 24, 24, 24, 24, 24, 24, 56, 25, 25, 25, 25, 25, 25, 25, 57, 57, 26, + 26, 26, 26, 26, 26, 26, 58, 58, 27, 27, 27, 27, 27, 27, 27, 27, 59, 28, 28, 28, 28, 28, 28, 28, 28, 28, 60, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 61, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 62, 31, 31, 31, 31, 31, 31 +}; +inline int best_index_iq5nl(const int8_t * values, float x) { + int ix = (int)x - values[0]; + if (ix < 0 || ix >= 247) return ix < 0 ? 0 : 31; + ix = iq5nl_index[ix]; + return ix < 32 ? ix : x - values[ix-32] < values[ix-31] - x ? 
ix-32 : ix-31; +} + +void quantize_row_iq5_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { + const int ntry = 5; + const float step = 1.f; + + block_iq5_k * y = (block_iq5_k *)vy; + + float scales[QK_K/16]; + float weight[16]; + + const int8_t * shifted_values = iq5nl_values + 32; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq5_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 2*sumx2/QK_K; + + float max_scale = 0, max_abs_scale = 0; + uint16_t extra = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < 16; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? -max/iq5nl_values[0] : max/iq5nl_values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq5nl(iq5nl_values, al); + float q = iq5nl_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq5nl(iq5nl_values, -al); + q = iq5nl_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + bool is_shifted = false; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry*step + iq5nl_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq5nl(iq5nl_values, al); + float q = iq5nl_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq5nl(iq5nl_values, -al); + q = iq5nl_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry*step + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq5nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq5nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (d) { + const int8_t * block_values = is_shifted ? 
shifted_values : iq5nl_values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq5nl(block_values, al); + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + + float abs_scale = fabsf(scales[ib]); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; max_scale = scales[ib]; + } + + } + + if (!max_abs_scale) continue; + float d = -max_scale/32; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + + float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/16; ++ib) { + int ls = nearest_int(id*scales[ib]); + ls = MAX(-32, MIN(31, ls)); + int uls = ls + 32; + y[ibl].scales_l[ib/2] |= ((uls & 0xf) << 4*(ib%2)); + y[ibl].scales_h[ib/4] |= ((uls >> 4) << 2*(ib%4)); + float dl = d * ls; + if (dl) { + const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq5nl_values; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/2) + offset; + uint8_t * qh = y[ibl].qh + 32*(ib32/8) + offset; + for (int j = 0; j < 16; ++j) { + const float al = idl*xb[j]; + int ibest = best_index_iq5nl(block_values, al); + qs[j] |= ((ibest & 0xf) << 4*(ib32%2)); + qh[j] |= ((ibest >> 4) << (ib32%8)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + } + +} + +} + +void quantize_row_iq5_k_ref(const float * x, block_iq5_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq5_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq5_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq5_k * y = (block_iq5_k *)vy; + quantize_row_iq5_k_ref(x, y, k); +} + +size_t quantize_iq5_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq5_k_impl(src, (void *)qrow, n_per_row, imatrix); + src += n_per_row; + qrow += nblock*sizeof(block_iq5_k); + } + return nrows * nblock * sizeof(block_iq5_k); +} + +// +// ============================================== iq6_K +// +#define A_IQ6K -127.f +#define B_IQ6K 6.2568f +#define C_IQ6K 0.11218f +#define D_IQ6K 0.0011972f +#define S_IQ6K 1.f + +void dequantize_row_iq6_k(const block_iq6_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const int8_t * sl = x[i].scales; + + uint16_t extra = x[i].extra; + + int shift = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + + float dl1 = d * sl[4*ib64 + 0]; + float dl2 = d * sl[4*ib64 + 1]; + float dl3 = d * sl[4*ib64 + 2]; + float dl4 = d * sl[4*ib64 + 3]; + float m1 = extra & 1 ? S_IQ6K : 0; + float m2 = extra & 2 ? S_IQ6K : 0; + float m3 = extra & 4 ? S_IQ6K : 0; + float m4 = extra & 8 ? 
S_IQ6K : 0; + for (int j = 0; j < 16; ++j) { + float q1 = ((qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 0x03) << 4)); + float q2 = ((qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 0x03) << 4)); + float q3 = ((qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 0x0c) << 2)); + float q4 = ((qs[j+16] >> 4) | (((qh[j+16] >> shift) & 0x0c) << 2)); + y[j+ 0] = dl1 * (A_IQ6K + q1*(B_IQ6K + q1*(-C_IQ6K + q1*D_IQ6K)) + m1); + y[j+16] = dl2 * (A_IQ6K + q2*(B_IQ6K + q2*(-C_IQ6K + q2*D_IQ6K)) + m2); + y[j+32] = dl3 * (A_IQ6K + q3*(B_IQ6K + q3*(-C_IQ6K + q3*D_IQ6K)) + m3); + y[j+48] = dl4 * (A_IQ6K + q4*(B_IQ6K + q4*(-C_IQ6K + q4*D_IQ6K)) + m4); + } + y += 64; + qs += 32; + extra >>= 4; + shift += 4; + if (shift == 8) { qh += 32; shift = 0; } + } + + } +} + +void vec_dot_iq6_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ6_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + // TODO + //const int nb = n / QK_K; + + //const block_iq5_k * x = (const block_iq5_k *)vx; + //const block_q8_K * y = (const block_q8_K *)vy; + + //float sumf = 0; + + //for (int i = 0; i < nb; i++) { + + // const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + // const uint8_t * qs = x[i].qs; + // const uint8_t * qh = x[i].qh; + // const uint8_t * sl = x[i].scales_l; + // const uint8_t * sh = x[i].scales_h; + // const int8_t * q8 = y[i].qs; + + // uint16_t extra = x[i].extra; + + // int shift = 0; + // int sumb = 0; + // for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + + // int dl1 = (((sl[2*ib64+0] & 0xf) | ((sh[ib64] << 4) & 0x30)) - 32); + // int dl2 = (((sl[2*ib64+0] >> 4) | ((sh[ib64] << 2) & 0x30)) - 32); + // int dl3 = (((sl[2*ib64+1] & 0xf) | ((sh[ib64] >> 0) & 0x30)) - 32); + // int dl4 = (((sl[2*ib64+1] >> 4) | ((sh[ib64] >> 2) & 0x30)) - 32); + // const int8_t * values1 = iq5nl_values + ((extra & 1) << 5); + // const int8_t * values2 = iq5nl_values + ((extra & 2) << 4); + // const int8_t * values3 = iq5nl_values + ((extra & 4) << 3); + // const int8_t * values4 = iq5nl_values + ((extra & 8) << 2); + // int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + // for (int j = 0; j < 16; ++j) { + // sumi1 += q8[j+ 0] * values1[(qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4)]; + // sumi2 += q8[j+16] * values2[(qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4)]; + // sumi3 += q8[j+32] * values3[(qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3)]; + // sumi4 += q8[j+48] * values4[(qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3)]; + // } + // sumb += dl1 * sumi1 + dl2 * sumi2 + dl3 * sumi3 + dl4 * sumi4; + // q8 += 64; + // qs += 32; + // extra >>= 4; + // shift += 2; + // } + // sumf += d * sumb; + + //} + + //*s = sumf; + +} + +namespace { + +inline int best_index(int n, const float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? 
mu-1 : mu; +} +uint8_t iq6nl_index[249] = { + 0, 0, 0, 64, 1, 1, 1, 1, 1, 65, 2, 2, 2, 2, 2, 66, 3, 3, 3, 3, 67, 67, 4, 4, 4, 4, 68, 5, 5, 5, 5, 69, + 69, 6, 6, 6, 70, 70, 7, 7, 7, 71, 8, 8, 8, 72, 72, 9, 9, 9, 73, 73, 10, 10, 10, 74, 11, 11, 11, 75, 12, 12, 12, 76, + 13, 13, 13, 77, 14, 14, 14, 78, 15, 15, 79, 79, 16, 16, 80, 17, 17, 81, 81, 18, 18, 82, 19, 19, 83, 83, 20, 84, 84, 21, 85, 85, + 22, 86, 86, 23, 87, 87, 24, 88, 88, 25, 89, 89, 26, 90, 90, 27, 91, 91, 28, 92, 29, 93, 93, 30, 94, 94, 31, 95, 95, 32, 96, 33, + 97, 97, 34, 98, 98, 35, 99, 99, 36, 100, 100, 37, 101, 38, 102, 102, 39, 103, 103, 40, 104, 104, 41, 41, 105, 42, 42, 106, 106, 43, 107, 107, + 44, 108, 108, 45, 45, 109, 46, 46, 46, 110, 47, 47, 111, 111, 48, 48, 112, 49, 49, 49, 113, 50, 50, 50, 114, 51, 51, 51, 115, 52, 52, 52, + 116, 116, 53, 53, 53, 117, 54, 54, 54, 118, 118, 55, 55, 55, 119, 119, 56, 56, 56, 120, 120, 57, 57, 57, 121, 121, 58, 58, 58, 58, 122, 59, + 59, 59, 59, 123, 123, 60, 60, 60, 60, 124, 61, 61, 61, 61, 61, 125, 62, 62, 62, 62, 62, 126, 63, 63, 63, +}; +inline int best_index_iq6nl(const float * values, float x) { + int ix = (int)(x - values[0]); + if (ix < 0 || ix >= 249) return ix < 0 ? 0 : 63; + ix = iq6nl_index[ix]; + return ix < 64 ? ix : x - values[ix-64] < values[ix-63] - x ? ix-64 : ix-63; + //if (x <= val[0]) return 0; + //if (x >= val[63]) return 63; + //int index = iq6nl_index[int(x - val[0])]; + //return index < 64 ? index : x - val[index-64] < val[index-63] - x ? index - 64 : index - 63; +} + + +void quantize_row_iq6_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, const float * values, const float * shifted_values) { + const int ntry = 5; + const float step = 1.f; + + block_iq6_k * y = (block_iq6_k *)vy; + + float scales[QK_K/16]; + float weight[16]; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq6_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 2*sumx2/QK_K; + + float max_scale = 0, max_abs_scale = 0; + uint16_t extra = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < 16; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? 
-max/values[0] : max/values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + //int l = best_index(64, values, al); + int l = best_index_iq6nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + //l = best_index(64, values, -al); + l = best_index_iq6nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + bool is_shifted = false; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry*step + values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + //int l = best_index(64, values, al); + int l = best_index_iq6nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + //l = best_index(64, values, -al); + l = best_index_iq6nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry*step + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + //int l = best_index(64, shifted_values, al); + int l = best_index_iq6nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + //l = best_index(64, shifted_values, -al); + l = best_index_iq6nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (d) { + const float * block_values = is_shifted ? shifted_values : values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + //int l = best_index(64, block_values, al); + int l = best_index_iq6nl(block_values, al); + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + + float abs_scale = fabsf(scales[ib]); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; max_scale = scales[ib]; + } + + } + + if (!max_abs_scale) continue; + float d = -max_scale/127; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + + float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/16; ++ib) { + int ls = nearest_int(id*scales[ib]); + ls = MAX(-127, MIN(127, ls)); + y[ibl].scales[ib] |= ls; + float dl = d * ls; + if (dl) { + const float * block_values = y[ibl].extra & (1 << ib) ? 
shifted_values : values; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 1/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/2) + offset; + uint8_t * qh = y[ibl].qh + 32*(ib32/4) + offset; + for (int j = 0; j < 16; ++j) { + const float al = idl*xb[j]; + //int ibest = best_index(64, block_values, al); + int ibest = best_index_iq6nl(block_values, al); + qs[j] |= ((ibest & 0xf) << 4*(ib32%2)); + qh[j] |= ((ibest >> 4) << 2*(ib32%4)); + float w = weight[j]; + float q = block_values[ibest]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + } + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + } +} + +} + +void quantize_row_iq6_k_ref(const float * x, block_iq6_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq6_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq6_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq6_k * y = (block_iq6_k *)vy; + quantize_row_iq6_k_ref(x, y, k); +} + +size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + float values[128]; + for (int i = 0; i < 64; ++i) { + values[i] = iq6nl_values[i]; + values[i+64] = values[i] + S_IQ6K; + } + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq6_k_impl(src, (void *)qrow, n_per_row, imatrix, values, values + 64); + src += n_per_row; + qrow += nblock*sizeof(block_iq6_k); + } + return nrows * nblock * sizeof(block_iq6_k); +} + +// +// ========================== IQ2_TN +// + +void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) { + GGML_ASSERT(k%QK_K == 0); + + int nb = k/QK_K; + + auto quantize = [] (float xmax, float x) { + return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 
1 : 2; + }; + + for (int ibl = 0; ibl < nb; ++ibl) { + auto xb = x + QK_K*ibl; + float max = xb[0]; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(xb[j]); + max = std::max(ax, max); + } + y[ibl].d = GGML_FP32_TO_FP16(max); + auto qs = y[ibl].qs; + for (int l = 0; l < QK_K/128; ++l) { + for (int j = 0; j < 32; ++j) { + qs[j] = quantize(max, xb[j]) | (quantize(max, xb[j+32]) << 2) | (quantize(max, xb[j+64]) << 4) | (quantize(max, xb[j+96]) << 6); + } + xb += 128; + qs += 32; + } + } +} + +void quantize_row_iq2_tn(const float * x, void * y, int64_t k) { + quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k); +} + +size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * /*imatrix*/) { + auto row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; ++row) { + quantize_row_iq2_tn_ref(src, (block_iq2_tn *)qrow, n_per_row); + qrow += row_size; + src += n_per_row; + } + return row_size*nrows; +} + +void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) { + GGML_ASSERT(k%QK_K == 0); + int nb = k/QK_K; + for (int ibl = 0; ibl < nb; ++ibl) { + float d = GGML_FP16_TO_FP32(x[ibl].d); + auto qs = x[ibl].qs; + for (int l = 0; l < QK_K/128; ++l) { + for (int j = 0; j < 32; ++j) { + y[j+ 0] = d*((qs[j] >> 0) & 3) - d; + y[j+32] = d*((qs[j] >> 2) & 3) - d; + y[j+64] = d*((qs[j] >> 4) & 3) - d; + y[j+96] = d*((qs[j] >> 6) & 3) - d; + } + y += 128; + qs += 32; + } + } +} + +void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq2_tn * x = (const block_iq2_tn *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + float sumf = 0; + + for (int i = 0; i < nb; i++) { + float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + auto qs = x[i].qs; + auto q8 = y[i].qs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0; + for (int j = 0; j < QK_K/16; ++j) sumi1 -= y[i].bsums[j]; + for (int l = 0; l < QK_K/128; ++l) { + for (int j = 0; j < 32; ++j) { + sumi1 += q8[j+ 0] * (qs[j] & 0x03); + sumi2 += q8[j+32] * (qs[j] & 0x0c); + sumi3 += q8[j+64] * (qs[j] & 0x30); + sumi4 += q8[j+96] * (qs[j] & 0xc0); + } + q8 += 128; + qs += 32; + } + sumf += d * (sumi1 + 0.25f*sumi2 + 0.0625f*sumi3 + 0.015625f*sumi4); + } + *s = sumf; +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h new file mode 100644 index 0000000000000..3c5d27a458d5e --- /dev/null +++ b/ggml/src/iqk/iqk_quantize.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#ifdef __cplusplus +#define GGML_RESTRICT +extern "C" { +#else +#define GGML_RESTRICT restrict +#endif + +void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq3_k_ref(const float * GGML_RESTRICT x, block_iq3_k * GGML_RESTRICT y, 
int64_t k); +void quantize_row_iq3_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq3_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq3_k(const block_iq3_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq3_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq4_k_ref(const float * GGML_RESTRICT x, block_iq4_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_k(const block_iq4_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq5_k_ref(const float * GGML_RESTRICT x, block_iq5_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq5_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq5_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq5_k(const block_iq5_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq5_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq6_k_ref(const float * GGML_RESTRICT x, block_iq6_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq6_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq6_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq6_k(const block_iq6_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq6_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index 6626ceb26213f..9d56af78da7a0 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -845,6 +845,7 @@ class tinyBLAS_Q0_AVX { * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ + bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 
e262368662423..0a0629675c5f6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1170,6 +1170,17 @@ class GGMLQuantizationType(IntEnum): TQ2_0 = 35 Q1_3 = 36 Q2_2 = 37 + IQ1_BN = 134, + IQ2_BN = 135, + Q8_K64 = 136, + IQ2_K = 137, + IQ3_K = 138, + IQ4_K = 139, + IQ5_K = 140, + IQ6_K = 141, + IQ2_TN = 142, + + # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -1209,12 +1220,23 @@ class LlamaFileType(IntEnum): MOSTLY_IQ2_M = 29 # except 1d tensors MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors + MOSTLY_IQ1_M = 31 # except 1d tensors + MOSTLY_IQ1_XS = 99 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors MOSTLY_Q4_0_4_4 = 33 # except 1d tensors MOSTLY_Q4_0_4_8 = 34 # except 1d tensors MOSTLY_Q4_0_8_8 = 35 # except 1d tensors MOSTLY_Q2_2 = 36 # except 1d tensors MOSTLY_Q1_3 = 37 # except 1d tensors + MOSTLY_IQ1_BN = 136, # except 1d tensors + MOSTLY_IQ2_BN = 137, # except 1d tensors + MOSTLY_IQ2_K = 138, # except 1d tensors + MOSTLY_IQ3_K = 139, # except 1d tensors + MOSTLY_IQ4_K = 140, # except 1d tensors + MOSTLY_IQ5_K = 141, # except 1d tensors + MOSTLY_IQ6_K = 142, # except 1d tensors + MOSTLY_IQ2_TN = 143, # except 1d tensors + GUESSED = 1024 # not specified in the model file diff --git a/gguf-py/tests/test_quants.py b/gguf-py/tests/test_quants.py old mode 100755 new mode 100644 diff --git a/include/llama.h b/include/llama.h index acf3b75fabff0..6591d75066f7c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -232,6 +232,14 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q8_OE8 = 97, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_OE16 = 98, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_XS = 99, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_BN = 136, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_BN = 137, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_K = 138, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_K = 139, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_K = 140, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ5_K = 141, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ6_K = 142, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_TN = 143, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -417,6 +425,7 @@ extern "C" { bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards + bool ignore_imatrix_rules; // If set to true, the built-in rules for refusing to quantize into certain quants without imatrix are ignored void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides } llama_model_quantize_params; diff --git a/otherarch/sdcpp/SDCPP_LICENSE b/otherarch/sdcpp/SDCPP_LICENSE index 56e1e5a63852c..79a4a90eda21f 100644 --- a/otherarch/sdcpp/SDCPP_LICENSE +++ b/otherarch/sdcpp/SDCPP_LICENSE @@ -1,6 +1,8 @@ MIT License Copyright (c) 2023 leejet +Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2024 Iwan Kawrakow Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/llama.cpp b/src/llama.cpp index 628e1f66cdda5..19aac55c7649b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -40,6 +40,7 @@ // TODO: replace with ggml API call #define QK_K 256 +#define QK_IQ1BN 64 #ifdef __has_include #if __has_include() @@ -3816,8 +3817,16 @@ struct llama_model_loader { case GGML_TYPE_IQ3_XXS: ftype = 
LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; + case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; + case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; + case GGML_TYPE_IQ2_TN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_TN; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; + case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; + case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; + case GGML_TYPE_IQ5_K: ftype = LLAMA_FTYPE_MOSTLY_IQ5_K; break; + case GGML_TYPE_IQ6_K: ftype = LLAMA_FTYPE_MOSTLY_IQ6_K; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; @@ -4530,6 +4539,14 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet"; + case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet"; + case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.06 bpw TriLM"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; @@ -5015,6 +5032,7 @@ static void llm_load_hparams( } break; case LLM_ARCH_PHI3: { + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { @@ -5501,14 +5519,21 @@ static void llm_load_vocab( vocab.tokenizer_add_space_prefix = false; vocab.tokenizer_clean_spaces = true; if (tokenizer_pre.empty()) { - LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + // OK - I don't feel like recreating the LLaMA-v3 models. Considering that, at least for now, + // LLaMA-v3 is the only model where we end up here, let's just force the pre-tokenizer to be + // llama3. + tokenizer_pre = "llama3"; + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'llama3'\n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); LLAMA_LOG_WARN("%s: ************************************ \n", __func__); - LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY MAY BE DEGRADED! \n", __func__); LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); LLAMA_LOG_WARN("%s: ************************************ \n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + //vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
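+ // Note: every GGUF whose tokenizer.ggml.pre key is missing or empty lands in this branch, so models other
+ // than LLaMA-v3 that were converted before that key existed are also forced onto the llama3 pre-tokenizer.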
} else if (tokenizer_pre == "default") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( @@ -6078,7 +6103,6 @@ static bool llm_load_tensors( // there is very little benefit to offloading the input layer, so always keep it on the CPU model.buft_input = llama_default_buffer_type_cpu(true); - //model.buft_input = llama_default_buffer_type_offload(main_gpu); model.buft_layer.resize(n_layer); @@ -7523,36 +7547,39 @@ static bool llm_load_tensors( // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading } + const uint32_t n_ff = hparams.n_ff(); + model.layers.resize(n_layer); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_layer, 
tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); } } break; case LLM_ARCH_T5: @@ -7861,6 +7888,27 @@ static bool llm_load_tensors( } } + if (model.arch == LLM_ARCH_BITNET) { + auto set_scale = [] (ggml_tensor * w, ggml_tensor * s) { + float scale = 1; + if (ggml_backend_buffer_is_host(s->buffer)) { + scale = *(const float *)s->data; + } else { + ggml_backend_tensor_get(s, &scale, 0, sizeof(float)); + } + std::memcpy(w->op_params, &scale, sizeof(scale)); + }; + for (auto& l : model.layers) { + set_scale(l.ffn_up, l.ffn_up_scale); + set_scale(l.ffn_gate, l.ffn_gate_scale); + set_scale(l.ffn_down, l.ffn_down_scale); + set_scale(l.wq, l.wq_scale); + set_scale(l.wk, l.wk_scale); + set_scale(l.wv, l.wv_scale); + set_scale(l.wo, l.wo_scale); + } + } + // loading time will be recalculate after the first eval, so // we take page faults deferred by mmap() into consideration model.t_load_us = ggml_time_us() - model.t_start_us; @@ -8090,10 +8138,10 @@ static struct ggml_tensor * llm_build_norm( struct ggml_tensor * mb, llm_norm_type type, const llm_build_cb & cb, - int il) { + int il, float scale_eps = 1) { switch (type) { case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break; - case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break; + case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, scale_eps * hparams.f_norm_rms_eps); break; } if (mw || mb) { @@ -13271,8 +13319,11 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + float q_scale; std::memcpy(&q_scale, model.layers[il].wq->op_params, sizeof(float)); + // Note: we could save this scale operation by applying the Q scale on the K * Q product further down + // (which also uses a scale). This works on the CPU and Metal backends, but produces NaNs on CUDA. 
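+ // q_scale is the per-tensor Bitnet weight scale that the set_scale lambda (added earlier in this diff)
+ // copied into wq->op_params at load time, so the former ggml_mul with the 1-element wq_scale tensor
+ // becomes a plain ggml_scale here.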
+ Qcur = ggml_scale(ctx0, Qcur, q_scale); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); @@ -13280,8 +13331,9 @@ struct llm_build_context { } // B1.K - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + float k_scale; std::memcpy(&k_scale, model.layers[il].wk->op_params, sizeof(float)); + Kcur = ggml_scale(ctx0, Kcur, k_scale); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); @@ -13289,12 +13341,14 @@ struct llm_build_context { } // B1.V - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + float v_scale; std::memcpy(&v_scale, model.layers[il].wv->op_params, sizeof(float)); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { + Vcur = ggml_scale(ctx0, Vcur, v_scale); Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); + v_scale = 1; } Qcur = ggml_rope_ext( @@ -13311,21 +13365,85 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - NULL, NULL, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].attn_sub_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "attn_sub_norm", il); + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head(); + const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + // We would use this if we did not apply the Q scale above. Sadly, this fails on CUDA. 
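+ // Scaling Q before the matmul is mathematically equivalent to scaling the K*Q product, since
+ // K^T*(q_scale*Q) == q_scale*(K^T*Q); with q_scale already applied to Qcur above, kq_scale keeps the
+ // standard 1/sqrt(n_embd_head) attention factor that is passed to ggml_soft_max_ext / ggml_flash_attn_ext below.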
+ //float kq_scale = q_scale/sqrtf(float(n_embd_head)); + struct ggml_tensor * cur_attn; + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + cb(q, "q", il); - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + struct ggml_tensor * k = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + if (cparams.flash_attn) { + + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), + 0); + cb(v, "v", il); + + cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + + cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv_self.size == n_ctx); + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur_attn, "kqv_merged_cont", il); } - cb(cur, "attn_o_out", il); + + cur_attn = llm_build_norm(ctx0, cur_attn, hparams, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il, 1/(v_scale*v_scale)); + cb(cur_attn, "attn_sub_norm", il); + + ggml_build_forward_expand(gf, cur_attn); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn); + float wo_scale; std::memcpy(&wo_scale, model.layers[il].wo->op_params, sizeof(float)); + cur = ggml_scale(ctx0, cur, wo_scale); + + cb(cur, "kqv_out", il); } if (il == n_layer - 1) { @@ -13339,28 +13457,41 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward forward - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, lctx, cur, - model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, - model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, - NULL, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_sub_out", il); + struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + float ffn_up_scale; std::memcpy(&ffn_up_scale, model.layers[il].ffn_up->op_params, sizeof(float)); - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].ffn_sub_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_sub_norm", il); + cb(tmp, "ffn_up", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); + float ffn_gate_scale; 
std::memcpy(&ffn_gate_scale, model.layers[il].ffn_gate->op_params, sizeof(float)); + cur = ggml_scale(ctx0, cur, ffn_gate_scale); + + cb(cur, "ffn_gate", il); + + + // combine this with the above scale into ggml_scaled_silu + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); - cb(cur, "ffn_down", il); + cur = ggml_mul(ctx0, cur, tmp); + cb(cur, "ffn_gate_par", il); + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il, 1/(ffn_up_scale*ffn_up_scale)); + cb(cur, "ffn_sub_norm", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + float ffn_down_scale; std::memcpy(&ffn_down_scale, model.layers[il].ffn_down->op_params, sizeof(float)); + cur = ggml_scale(ctx0, cur, ffn_down_scale); + cb(cur, "ffn_down", il); + } cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); @@ -13376,7 +13507,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -14622,7 +14753,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { // TODO: use a per-batch flag for logits presence instead const bool has_logits = !cparams.embeddings; - const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)); const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; @@ -15762,7 +15893,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || - ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { new_type = GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_XS) { @@ -15869,6 +16000,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_Q8_OE16) { new_type = GGML_TYPE_F16; } + else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) { + new_type = GGML_TYPE_Q6_K; + } } } else if (name == "token_embd.weight") { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { @@ -15998,6 +16132,12 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ9_LR || ftype == LLAMA_FTYPE_MOSTLY_IQ9_BLR) { new_type = GGML_TYPE_IQ2_S; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) { + new_type = GGML_TYPE_IQ4_NL; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) { + new_type = GGML_TYPE_Q4_K; + } else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) { new_type = GGML_TYPE_Q4_0; @@ -17716,14 +17856,17 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { new_type = qs.model.hparams.n_gqa() >= 4 ? 
GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { + if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_IQ4_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; } - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) { - new_type = GGML_TYPE_Q4_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K) && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_IQ4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { new_type = GGML_TYPE_Q4_K; @@ -17774,7 +17917,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) { + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K) && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q5_K; } else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && @@ -17792,6 +17935,13 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n // TODO: explore better strategies new_type = GGML_TYPE_Q8_0; } + else if (qs.model.hparams.n_gqa() >= 4) { + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; + else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S ) new_type = GGML_TYPE_Q4_K; + else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS || new_type == GGML_TYPE_IQ4_K) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_IQ4_NL) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K; + } ++qs.i_attention_wv; } else if (name.find("attn_k.weight") != std::string::npos) { if (qs.model.hparams.n_expert == 8) { @@ -17971,11 +18121,12 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { if (arch != LLM_ARCH_FALCON) { - if (qs.model.hparams.n_expert == 8) { + if (qs.model.hparams.n_expert >= 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || - ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || - ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || - ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) { + ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || + ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K) { new_type = GGML_TYPE_Q5_K; } else if ( ftype == LLAMA_FTYPE_MOSTLY_IQ3_XSR || ftype == 
LLAMA_FTYPE_MOSTLY_IQ3_SR || ftype == LLAMA_FTYPE_MOSTLY_IQ3_MR || @@ -18109,7 +18260,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || - new_type == GGML_TYPE_IQ1_M) { + new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || + new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN || + new_type == GGML_TYPE_IQ6_K) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -18119,6 +18272,12 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.n_k_quantized; } } + if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN) { + int nx = tensor->ne[0]; + if (nx % QK_IQ1BN != 0) { + convert_incompatible_tensor = true; + } + } if (convert_incompatible_tensor) { switch (new_type) { case GGML_TYPE_IQ2_XXS: @@ -18128,11 +18287,17 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_Q4_0; break; + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_IQ4_K: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_IQ5_K: case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_IQ6_K: case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); } @@ -18235,8 +18400,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; case LLAMA_FTYPE_MOSTLY_IQ1_XS: default_type = GGML_TYPE_IQ1_S; break; + case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break; + case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; + case LLAMA_FTYPE_MOSTLY_IQ2_TN: default_type = GGML_TYPE_IQ2_TN; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; + case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; + case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break; + case LLAMA_FTYPE_MOSTLY_IQ5_K: default_type = GGML_TYPE_IQ5_K; break; + case LLAMA_FTYPE_MOSTLY_IQ6_K: default_type = GGML_TYPE_IQ6_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break; @@ -18592,12 +18765,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } } - if ((new_type == GGML_TYPE_IQ2_XXS || + if (!params->ignore_imatrix_rules && !imatrix && + (new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ1_S || (new_type == GGML_TYPE_IQ2_S && strcmp(tensor->name, "token_embd.weight")) || (new_type == GGML_TYPE_IQ1_M 
&& strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { + (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { LLAMA_LOG_ERROR("\n\n============================================================\n"); LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); @@ -18974,6 +19148,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.only_copy =*/ false, /*.pure =*/ false, /*.keep_split =*/ false, + /*.ignore_imatrix_rules =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, }; diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 819b0b2d50113..e6836b5cca741 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -32,6 +33,14 @@ static void generate_data(float offset, size_t n, float * dst) { dst[i] = 0.1 + 2*cosf(i + offset); } } +static void generate_bitnet_data(size_t n, float * dst) { + std::mt19937 rndm(1234); + for (size_t i = 0; i < n; i++) { + auto r = rndm(); + dst[i] = r > std::mt19937::max()/2 ? 0.f : r < std::mt19937::max()/4 ? -1.f : 1.f; + } +} + // Calculate RMSE between two float arrays static float array_rmse(const float * a1, const float * a2, size_t n) { @@ -85,7 +94,7 @@ static float dot_product_error( auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type); - qfns.from_float(test_data1, tmp_q1.data(), test_size); + qfns.from_float_ref(test_data1, tmp_q1.data(), test_size); vdot.from_float(test_data2, tmp_q2.data(), test_size); float result = INFINITY; @@ -114,9 +123,11 @@ int main(int argc, char * argv[]) { std::vector test_data(test_size); std::vector test_data2(test_size); + std::vector test_data_bitnet(test_size); generate_data(0.0, test_data.size(), test_data.data()); generate_data(1.0, test_data2.size(), test_data2.data()); + generate_bitnet_data(test_data_bitnet.size(), test_data_bitnet.data()); // Initialize GGML, ensures float conversion tables are initialized struct ggml_init_params ggml_params = { @@ -138,13 +149,21 @@ int main(int argc, char * argv[]) { continue; } + auto test_data_quantize = test_data.data(); + auto test_data_vecdot = test_data2.data(); const ggml_type ei = (ggml_type)i; + if (ei == GGML_TYPE_IQ1_BN || ei == GGML_TYPE_IQ2_BN) { + test_data_quantize = test_data_bitnet.data(); + test_data_vecdot = test_data_bitnet.data(); + //printf("Skipping %s because test data does not satisfy Bitnet requirements\n", ggml_type_name(ei)); + //continue; + } printf("Testing %s\n", ggml_type_name((ggml_type) i)); ggml_quantize_init(ei); if (qfns.from_float && qfns.to_float) { - const float total_error = total_quantization_error(qfns, test_size, test_data.data()); + const float total_error = total_quantization_error(qfns, test_size, test_data_quantize); const float max_quantization_error = type == GGML_TYPE_Q1_3 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_Q2_2 ? 
MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : @@ -161,14 +180,14 @@ int main(int argc, char * argv[]) { printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); } - const float reference_error = reference_quantization_error(qfns, test_size, test_data.data()); + const float reference_error = reference_quantization_error(qfns, test_size, test_data_quantize); failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR); num_failed += failed; if (failed || verbose) { printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error); } - const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data()); + const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data_vecdot); const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT