Skip to content

Commit

Permalink
PoC: Chunk weight quantize tool for LLM [skip ci]
Browse files Browse the repository at this point in the history
- Blockwise quantization for LLM: FullyConnected, Gather
- Decide quantize type by circle-quantizer parameter: `--block_quantize_weights` (Q4_0, Q8_0)
- Skip quantization by circle-quantizer parameter: `--skipsize_block_quantize` (default: 0)

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <[email protected]>
  • Loading branch information
hseok-oh committed Oct 11, 2024
1 parent 63d7ff2 commit 750278f
Show file tree
Hide file tree
Showing 21 changed files with 703 additions and 45 deletions.
1 change: 1 addition & 0 deletions compiler/circle-partitioner/src/HelperPath.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef __CIRCLE_HELPER_PATH_H__
#define __CIRCLE_HELPER_PATH_H__

#include <cstdint>
#include <string>

namespace partee
Expand Down
2 changes: 1 addition & 1 deletion compiler/circle-quantizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ if(NOT Jsoncpp_FOUND)
return()
endif(NOT Jsoncpp_FOUND)

set (SOURCES src/CircleQuantizer.cpp)
set (SOURCES src/CircleQuantizer.cpp src/QuantizeWeightsLLM.cpp)

add_executable(circle-quantizer "${SOURCES}")
target_include_directories(circle-quantizer PRIVATE ${Jsoncpp_INCLUDE_DIRS})
Expand Down
82 changes: 82 additions & 0 deletions compiler/circle-quantizer/src/CircleQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
* limitations under the License.
*/

#include "QuantizeWeightsLLM.h"

#include <luci/ImporterEx.h>
#include <luci/CircleQuantizer.h>
#include <luci/Service/Validate.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
#include <luci/UserSettings.h>
#include <luci/IR/CircleNodeDecl.h>

#include <oops/InternalExn.h>
#include <arser/arser.h>
Expand Down Expand Up @@ -151,6 +154,7 @@ void print_exclusive_options(void)
std::cout << " --requantize" << std::endl;
std::cout << " --force_quantparam" << std::endl;
std::cout << " --quantize_weights" << std::endl;
std::cout << " --quantize_llm" << std::endl;
std::cout << " --quantize_onnx_fq_model" << std::endl;
}

Expand All @@ -176,6 +180,8 @@ int entry(int argc, char **argv)
const std::string fake_quant = "--fake_quantize";
const std::string qw = "--quantize_weights";
const std::string cfg = "--config";
const std::string qllm = "--block_quantize_weights";
const std::string skip_qllm = "--skipsize_block_quantize";

const std::string tf_maxpool = "--TF-style_maxpool";

Expand Down Expand Up @@ -221,6 +227,20 @@ int entry(int argc, char **argv)
.help("Convert a quantized model to a fake-quantized model. NOTE: This feature will "
"generate an fp32 model.");

arser.add_argument(qllm)
.nargs(1)
.type(arser::DataType::STR)
.help("FullyConnected weight and Gather param quantization with block granualrity. "
"One argument requires: type(Q4_0, Q8_0)");

arser.add_argument(skip_qllm)
.nargs(1)
.type(arser::DataType::INT32)
.default_value(0)
.help("Skip weight quantization with block granualrity when "
"weight is smaller than specified elementsize. "
"One argument requires: size (default: 0)");

arser.add_argument(rq)
.nargs(2)
.type(arser::DataType::STR_VEC)
Expand Down Expand Up @@ -289,6 +309,7 @@ int entry(int argc, char **argv)
opt_used += arser[cq] ? 1 : 0;
opt_used += arser[fake_quant] ? 1 : 0;
opt_used += arser[qw] ? 1 : 0;
opt_used += arser[qllm] ? 1 : 0;
opt_used += arser.get<bool>(qofm) ? 1 : 0;
if (opt_used != 1)
{
Expand Down Expand Up @@ -465,6 +486,67 @@ int entry(int argc, char **argv)
if (arser[fake_quant])
options->enable(Algorithms::ConvertToFakeQuantizedModel);

if (arser[qllm])
{
std::string input_path = arser.get<std::string>("input");
std::string output_path = arser.get<std::string>("output");
std::string type_str = arser.get<std::string>(qllm);
auto skip_length = arser.get<int32_t>(skip_qllm);
quantizer::QuantizeWeightsLLM::Type qtype = quantizer::QuantizeWeightsLLM::Type::Q4_0;
if (type_str == "Q8_0")
qtype = quantizer::QuantizeWeightsLLM::Type::Q8_0;
else if (type_str == "skip")
qtype = quantizer::QuantizeWeightsLLM::Type::SKIP;
else if (type_str != "Q4_0")
{
std::cerr << "ERROR: Unsupported chunk quantization type" << std::endl;
return 255;
}
if (skip_length < 0)
{
std::cerr << "ERROR: Skip weight elementsize should be larger than zero" << std::endl;
return 255;
}

// Load model from the file
luci::ImporterEx importerex;
auto module = importerex.importVerifyModule(input_path);
if (module.get() == nullptr)
return EXIT_FAILURE;

for (size_t idx = 0; idx < module->size(); ++idx)
{
auto graph = module->graph(idx);

// Weight quantization for LLM
for (auto node : loco::active_nodes(loco::output_nodes(graph)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
quantizer::QuantizeWeightsLLM qw(qtype, skip_length);
circle_node->accept(&qw);
}

if (!luci::validate(graph))
{
std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
return 255;
}
}

// Export to output Circle file
luci::CircleExporter exporter;

luci::CircleFileExpContract contract(module.get(), output_path);

if (!exporter.invoke(&contract))
{
std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
return 255;
}

return 0;
}

if (arser[qw])
{
auto values = arser.get<std::vector<std::string>>(qw);
Expand Down
222 changes: 222 additions & 0 deletions compiler/circle-quantizer/src/QuantizeUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright (c) 2023 Georgi Gerganov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_QUANTIZE_QUANTIZE_UTIL_H
#define LUCI_QUANTIZE_QUANTIZE_UTIL_H

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Copy from llama.cpp

// fp16 values are stored as raw IEEE 754 binary16 bit patterns.
typedef uint16_t ggml_fp16_t;

// Q4_0 block: 32 weights, 4 bits each, sharing one fp16 scale.
#define QK4_0 32
typedef struct
{
  ggml_fp16_t d;         // delta (per-block fp16 scale)
  uint8_t qs[QK4_0 / 2]; // nibbles / quants (two 4-bit values per byte)
} block_q4_0;

// Q8_0 block: 32 weights, 8 bits each, sharing one fp16 scale.
#define QK8_0 32
typedef struct
{
  ggml_fp16_t d;    // delta (per-block fp16 scale)
  int8_t qs[QK8_0]; // quants (signed 8-bit values)
} block_q8_0;

// Byte-view unions: let a quantized block be copied into a raw uint8_t buffer.
union block_q4_0_u {
  uint8_t u8[sizeof(block_q4_0)];
  block_q4_0 b;
};

union block_q8_0_u {
  uint8_t u8[sizeof(block_q8_0)];
  block_q8_0 b;
};

/// @brief Reinterpret the bits of a 32-bit float as a uint32_t.
/// @param f value whose object representation is copied out
/// @return the IEEE 754 binary32 bit pattern of @p f
///
/// NOTE std::memcpy is the well-defined way to type-pun in C++; reading a
/// union member other than the one last written (the original approach,
/// inherited from the C sources of llama.cpp) is undefined behavior in C++.
static inline uint32_t fp32_to_bits(float f)
{
  static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32-bit");
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return bits;
}

/// @brief Reinterpret a uint32_t bit pattern as a 32-bit float.
/// @param w IEEE 754 binary32 bit pattern
/// @return the float whose object representation equals @p w
///
/// NOTE std::memcpy replaces the original union-based type punning, which is
/// undefined behavior in C++ (unlike in C, where llama.cpp's version lives).
static inline float fp32_from_bits(uint32_t w)
{
  static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32-bit");
  float value;
  std::memcpy(&value, &w, sizeof(value));
  return value;
}

// Convert a 32-bit float to its IEEE 754 binary16 bit pattern with
// round-to-nearest-even, using only float arithmetic and bit operations.
// Bit-exact copy from llama.cpp; do not modify (results must match upstream).
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f)
{
  // Scaling by 2^112 then 2^-110 makes values that overflow fp16 saturate,
  // and positions small magnitudes so later float addition performs the
  // correct rounding for subnormal fp16 results.
  const float scale_to_inf = 0x1.0p+112f;
  const float scale_to_zero = 0x1.0p-110f;

  float base = (fabsf(f) * scale_to_inf) * scale_to_zero;

  const uint32_t w = fp32_to_bits(f);
  const uint32_t shl1_w = w + w; // left shift by 1 drops the sign bit
  const uint32_t sign = w & UINT32_C(0x80000000);
  uint32_t bias = shl1_w & UINT32_C(0xFF000000); // exponent field of |f|, shifted left by 1
  if (bias < UINT32_C(0x71000000))
  {
    // Clamp the exponent so inputs below the fp16 normal range still round
    // correctly into fp16 subnormals.
    bias = UINT32_C(0x71000000);
  }

  // Adding this bias constant aligns the fp16 exponent/mantissa bits inside
  // the float's bit pattern; the bit fields are then extracted directly.
  base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
  const uint32_t bits = fp32_to_bits(base);
  const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
  const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
  const uint32_t nonsign = exp_bits + mantissa_bits;
  // shl1_w > 0xFF000000 means the input is NaN: emit canonical fp16 NaN
  // 0x7E00; otherwise combine sign with the computed exponent+mantissa.
  return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
}

// Aliases kept so the code below stays a line-for-line copy of llama.cpp.
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)

// Classic min/max macros; NOTE arguments may be evaluated twice.
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

/// @brief Quantize @p k floats into Q4_0 blocks (reference scalar version,
///        copied from llama.cpp).
/// @param x input floats; @p k must be a multiple of QK4_0
/// @param y output array of k / QK4_0 blocks; each block stores one fp16
///          scale and 32 4-bit quants packed two per byte (low nibble = first
///          half of the block, high nibble = second half)
/// @param k number of input elements
///
/// NOTE 'inline' is required here: this function is defined in a header, so a
/// plain external definition would violate the ODR (duplicate symbols) once
/// the header is included from more than one translation unit.
inline void quantize_row_q4_0_reference(const float *x, block_q4_0 *y, int k)
{
  static const int qk = QK4_0;

  assert(k % qk == 0);

  const int nb = k / qk;

  for (int i = 0; i < nb; i++)
  {
    float amax = 0.0f; // absolute max
    float max = 0.0f;  // signed value at the absolute max

    for (int j = 0; j < qk; j++)
    {
      const float v = x[i * qk + j];
      if (amax < fabsf(v))
      {
        amax = fabsf(v);
        max = v;
      }
    }

    // The scale maps the extreme value to -8, so quants span [-8, 7].
    const float d = max / -8;
    const float id = d ? 1.0f / d : 0.0f;

    y[i].d = GGML_FP32_TO_FP16(d);

    for (int j = 0; j < qk / 2; ++j)
    {
      const float x0 = x[i * qk + 0 + j] * id;
      const float x1 = x[i * qk + qk / 2 + j] * id;

      // +8.5f shifts to an unsigned range and rounds; clamp at the 4-bit max.
      const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
      const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));

      y[i].qs[j] = xi0;
      y[i].qs[j] |= xi1 << 4;
    }
  }
}

/// @brief Quantize @p n floats (processed in rows of @p k) into contiguous
///        Q4_0 blocks at @p dst.
/// @param src input floats
/// @param dst output buffer; must hold n / QK4_0 block_q4_0 structs
/// @param n total number of elements
/// @param k row length; must be a multiple of QK4_0
/// @return number of bytes written to @p dst
///
/// NOTE 'inline' prevents ODR violations for this header-defined function.
/// NOTE the upstream llama.cpp version re-read the quants after each row to
/// fill a histogram output; that output was dropped in this copy, so the
/// leftover loop (which computed values and discarded them) was removed.
inline size_t ggml_quantize_q4_0(const float *src, void *dst, int n, int k)
{
  assert(k % QK4_0 == 0);

  for (int b = 0; b < n; b += k)
  {
    block_q4_0 *y = (block_q4_0 *)dst + b / QK4_0;

    quantize_row_q4_0_reference(src + b, y, k);
  }

  return (n / QK4_0 * sizeof(block_q4_0));
}

/// @brief Quantize @p k floats into Q8_0 blocks (reference scalar version,
///        copied from llama.cpp).
/// @param x input floats; @p k must be a multiple of QK8_0
/// @param y output array of k / QK8_0 blocks; each block stores one fp16
///          scale and 32 signed 8-bit quants
/// @param k number of input elements
///
/// NOTE 'inline' is required here: this function is defined in a header, so a
/// plain external definition would violate the ODR (duplicate symbols) once
/// the header is included from more than one translation unit.
inline void quantize_row_q8_0_reference(const float *x, block_q8_0 *y, int k)
{
  assert(k % QK8_0 == 0);
  const int nb = k / QK8_0;

  for (int i = 0; i < nb; i++)
  {
    float amax = 0.0f; // absolute max

    for (int j = 0; j < QK8_0; j++)
    {
      const float v = x[i * QK8_0 + j];
      amax = MAX(amax, fabsf(v));
    }

    // The scale maps the absolute max to 127 (the int8 positive limit).
    const float d = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f / d : 0.0f;

    y[i].d = GGML_FP32_TO_FP16(d);

    for (int j = 0; j < QK8_0; ++j)
    {
      const float x0 = x[i * QK8_0 + j] * id;

      y[i].qs[j] = roundf(x0);
    }
  }
}

/// @brief Quantize @p n floats (processed in rows of @p k) into contiguous
///        Q8_0 blocks at @p dst.
/// @param src input floats
/// @param dst output buffer; must hold n / QK8_0 block_q8_0 structs
/// @param n total number of elements
/// @param k row length; must be a multiple of QK8_0
/// @return number of bytes written to @p dst
///
/// NOTE 'inline' prevents ODR violations for this header-defined function.
/// NOTE the upstream llama.cpp version re-read the quants after each row to
/// fill a histogram output; that output was dropped in this copy, so the
/// leftover loop (which read values and discarded them) was removed.
inline size_t ggml_quantize_q8_0(const float *src, void *dst, int n, int k)
{
  assert(k % QK8_0 == 0);

  for (int b = 0; b < n; b += k)
  {
    block_q8_0 *y = (block_q8_0 *)dst + b / QK8_0;

    quantize_row_q8_0_reference(src + b, y, k);
  }

  return (n / QK8_0 * sizeof(block_q8_0));
}

#endif // LUCI_QUANTIZE_QUANTIZE_UTIL_H
Loading

0 comments on commit 750278f

Please sign in to comment.