diff --git a/common/arg.cpp b/common/arg.cpp index 98baac4c14da2..b7e9a639bb49d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2749,6 +2749,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.show_statistics = true; } ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--activation-statistics"}, + string_format("generate data to compute activation-based statistics (default: %s)", params.activation_statistics ? "true" : "false"), + [](common_params & params) { + params.activation_statistics = true; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); add_opt(common_arg( {"--parse-special"}, string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"), diff --git a/common/common.h b/common/common.h index 75596e6b32979..153ba8d701fd4 100644 --- a/common/common.h +++ b/common/common.h @@ -473,10 +473,11 @@ struct common_params { int32_t i_chunk = 0; // start processing from this chunk int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat) - bool process_output = false; // collect data for the output tensor - bool compute_ppl = true; // whether to compute perplexity - bool show_statistics = false; // show imatrix statistics per tensor - bool parse_special = false; // whether to parse special tokens during imatrix tokenization + bool process_output = false; // collect data for the output tensor + bool compute_ppl = true; // whether to compute perplexity + bool show_statistics = false; // show imatrix statistics per tensor + bool activation_statistics = false; // generate data to calculate activation-based statistics + bool parse_special = false; // whether to parse special tokens during imatrix tokenization // cvector-generator params int n_pca_batch = 100; diff --git a/tools/imatrix/README.md b/tools/imatrix/README.md index 4505cb4ce8c7d..c25fa097a9dfb 100644 --- a/tools/imatrix/README.md +++ b/tools/imatrix/README.md @@ -10,7 +10,7 @@ More information is available in =2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`. * `-o | --output-file` specifies the name of the file where the computed data will be stored. If missing `imatrix.gguf` is used. * `-ofreq | --output-frequency` specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks) -* `--output-format` specifies the output format of the generated imatrix file. Either "gguf", or "dat" (the legacy format). Defaults to "gguf". +* `--output-format` specifies the output format of the generated imatrix file. Either `gguf`, or `dat` (the legacy format). Defaults to `gguf`. * `--save-frequency` specifies how often to save a copy of the imatrix in a separate file. Default is 0 (i.e., never) * `--process-output` specifies if data will be collected for the `output.weight` tensor. Typically, it is better not to utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default. * `--in-file` one or more existing imatrix files to load and combine. Useful for merging files from multiple runs/datasets. * `--parse-special` enables parsing of special tokens (e.g., `<|im_start|>` in some models). Useful for models with custom tokenizers. * `--chunk | --from-chunk` to skip the first `n` chunks of tokens from the input data. Useful for resuming or skipping initial low-quality data. -* `--chunks` maximum number of chunks to process. Default is -1 for all available chunks. 
+* `--chunks` maximum number of chunks to process. Default is `-1` for all available chunks. * `--no-ppl` disables the calculation of perplexity for the processed chunks. Useful if you want to speed up the processing and do not care about perplexity. * `--show-statistics` displays imatrix file's statistics. +* `--activation-statistics` enables the collection of activation statistics for each tensor. If set, the imatrix file size will double, but reported statistics will be more accurate. For faster computation, make sure to use GPU offloading via the `-ngl | --n-gpu-layers` argument. -Recent versions of `llama-imatrix` store data in GGUF format by default. For the legacy format, use an extension other than `.gguf` when saving the output file. More information is available in . +Versions **b5942** and newer of `llama-imatrix` store data in GGUF format by default. For the legacy format, use `--output-format dat` when saving the output file. More information is available in . ## Examples @@ -69,30 +70,37 @@ Recent versions of `llama-imatrix` store data in GGUF format by default. For the ./llama-imatrix -m ggml-model-f16.gguf -f calibration-data.txt --chunk 5 --output-frequency 20 --save-frequency 50 --parse-special ``` +```bash +# generate imatrix and enable activation-based statistics +./llama-imatrix -m ggml-model-f16.gguf -f calibration-data.txt --activation-statistics -ngl 99 +``` + ```bash # analyse imatrix file and display summary statistics instead of running inference ./llama-imatrix --in-file imatrix.gguf --show-statistics ``` -`--show-statistics` will display the following statistics: +## Statistics + +For current versions of `llama-imatrix`, the `--show-statistics` option has two modes of operation: If `--activation-statistics` was used to generate the imatrix and `--output-format` was set to `gguf`, precise activations statistics will be calculated. Otherwise, it will report less accurate, albeit still useful, metrics based on average squared activations. #### Per tensor -* Σ(Act²): sum of all squared activations (the importance scores) -* Min & Max: minimum and maximum squared activations values -* μ & σ: Squared activations' mean and standard deviation -* % Active: proportion of elements whose average squared activation exceeds a small threshold (1e-5). Helpful to determine how alive/dormant the tensor is during inference -* N: number of squared activations -* Entropy: entropy of the squared activation distribution, in bits (standard Shannon entropy measurement) $S = -\sum_{i=1}^N p_i \log_2 p_i$ -* E (norm): Normalized entropy. $E(norm)=\frac{-\sum_{i=1}^N p_i \log_2 p_i}{log_2 N}$. These two metrics can be used to determine how well a prompt "exercises" the model's capabilities -* ZD Score: z-score distribution as described in _3.1 Layer Importance Scores_ of [Layer-Wise Quantization](https://arxiv.org/abs/2406.17415) -* CosSim: cosine similarity with respect to the previous layer's tensor. Useful to determine how similar the squared activations of the current layer are to the previous layer's squared activations. +* **Σ(Act²)** *(legacy mode)* / **L₂ Norm** *(preferred)*: If in legacy mode, the raw sum of squares of activations (sum of `Act²`). In preferred mode, the Euclidean Distance (L₂ Norm) between this tensor’s average activations and those of the previous layer. +* **Min / Max / μ / σ**: Tensor elements Min, Max, Mean, and Standard Deviation. +* **N**: Number of tensor elements considered. +* **H Norm**: Shannon Entropy normalized over log₂(N). 
Defined as $H_{norm}=\frac{-\sum_{i=1}^N p_i \log_2 p_i}{\log_2 N}$. Used to determine how well a prompt "exercises" the model's capabilities. +* **H** *(legacy mode)* / **ECS** *(preferred)*: If legacy, Shannon Entropy defined as $H = -\sum_{i=1}^N p_i \log_2 p_i$. If preferred, the *Euclidean-Cosine Score*, defined as $ECS = K \cdot e^{-\alpha a} \cdot |b|^{\gamma}$, where `a` is the L₂ Norm and `b` the Cosine Similarity between this tensor's elements and those of the previous layer, with `α = 0.01`, `γ = 10` and `K = 100` in the current implementation. A higher score means more similarity and less change between consecutive layers. +* **ZD**: % of elements whose Z-score is > 1.0 in magnitude (an indicator of outliers), as described in _3.1 Layer Importance Scores_ of [Layer-Wise Quantization](https://arxiv.org/abs/2406.17415). +* **CosSim**: Cosine Similarity between this tensor's elements and those of the previous layer. #### Per layer -Weighted averages of Σ(Act²), ZD Score and CosSim are also calculated. +Aggregated metrics per block/layer: -#### Important note on the computed Statistics +* **Σ(Act²)** *(legacy mode)* / **L₂ Norm** *(preferred)*: If in legacy mode, the sum of squared activations (sum of Act²) for the layer's concatenated tensors. In preferred mode, the Euclidean Distance (L₂ Norm) between this layer's average concatenated tensor activations and those of the previous layer. +* **ZD**: % of this layer's concatenated tensors' elements with |Z| > 1. +* **CosSim**: Cosine Similarity between this layer's concatenated tensors' elements and those of the previous layer. +* **ECS** *(preferred only)*: Euclidean-Cosine Score applied to the layer. -When using these statistics, please note that they are computed on the squared activations, **not on the actual (raw) activations**. -Whilst the results are still useful, they're less realiable than using the raw values, and in the case of the cosine similarity, could be misleading if the tensor contains opposite vectors. +More information is available in https://github.com/ggml-org/llama.cpp/pull/14891 diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index f28a036deebe3..e1c962f5beb56 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -1,8 +1,8 @@ #include "arg.h" #include "common.h" -#include "log.h" -#include "llama.h" #include "gguf.h" +#include "llama.h" +#include "log.h" #include #include @@ -10,14 +10,15 @@ #include #include #include -#include -#include -#include #include -#include #include -#include +#include #include +#include +#include +#include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -29,7 +30,7 @@ static void print_usage(int, char ** argv) { " -m model.gguf -f some-text.txt [-o imatrix.gguf] [--output-format {gguf,dat}] [--no-ppl] \\\n" " [--process-output] [--chunk 123] [--save-frequency 0] [--output-frequency 10] \\\n" " [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...] 
[--parse-special] \\\n" - " [--show-statistics] [...]\n" , argv[0]); + " [--activation-statistics] [--show-statistics] [...]\n" , argv[0]); LOG("\n"); } @@ -38,6 +39,7 @@ static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; struct Stats { + std::vector activations; std::vector values; std::vector counts; }; @@ -45,22 +47,23 @@ struct Stats { struct tensor_statistics { std::string tensor; Stats stats; - float total_sqract = 0.0f; - float mean_sqract = 0.0f; - float max_sqract = 0.0f; - float min_sqract = 0.0f; - int elements = 0; - float stddev = 0.0f; - float active = 0.0f; - float entropy = 0.0f; - float zd = 0.0f; - float cossim = 0.0f; + float sum_values = 0.0f; + float mean_values = 0.0f; + float max_values = 0.0f; + float min_values = 0.0f; + int elements = 0; + float std_deviation = 0.0f; + float entropy = 0.0f; + float zd_score = 0.0f; + float cossim = 0.0f; + float l2_norm = 0.0f; }; class IMatrixCollector { public: IMatrixCollector() = default; void set_params(common_params params) { m_params = std::move(params); } + bool activation_statistics() const { return m_params.activation_statistics; } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix_legacy(int32_t ncall = -1) const; void save_imatrix(int32_t n_chunk = -1) const; @@ -125,95 +128,241 @@ static void process_tensor_name(const std::string & input, std::string & layer, } } -static void compute_statistics(std::vector & tstats, const std::string & name, const Stats & e) { +static std::vector compute_tensor_averages(const Stats & tstats) { + if (tstats.counts.empty()) return {}; + const size_t n_mat = tstats.counts.size(); + const size_t len = !tstats.activations.empty() ? tstats.activations.size() : tstats.values.size(); + + if (len == 0 || len % n_mat != 0) return {}; + const size_t row = len / n_mat; + std::vector vec; + vec.reserve(len); + + if (tstats.activations.empty()) { + for (size_t m = 0; m < n_mat; ++m) { + const float c = (float)tstats.counts[m]; + if (c <= 0) return {}; + const size_t off = m * row; + for (size_t j = 0; j < row; ++j) { + vec.push_back(tstats.values[off + j] / c); + } + } + } else { + for (size_t m = 0; m < n_mat; ++m) { + const float c = (float)tstats.counts[m]; + if (c <= 0) return {}; + const size_t off = m * row; + for (size_t j = 0; j < row; ++j) { + vec.push_back(tstats.activations[off + j] / c); + } + } + } + + return vec; +} + +static bool compute_vector_statistics(std::vector & tstats, const std::string & name, const Stats & e) { if (e.values.size() % e.counts.size() != 0) { LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.values.size()); - return; + return false; } if (e.counts.empty()) { LOG_ERR("%s: there are no activations for tensor %s. 
The imatrix may be suboptimal\n", __func__, name.c_str()); - return; + return false; } const int n_mat = e.counts.size(); const int row_size = e.values.size() / n_mat; std::vector activations; - activations.reserve(e.values.size()); - for (int i = 0; i < n_mat; ++i) { - for (int j = 0; j < row_size; ++j) { - activations.push_back(e.values[i*row_size + j] / e.counts[i]); + if (e.activations.empty()) { + activations.reserve(e.values.size()); + + for (int i = 0; i < n_mat; ++i) { + for (int j = 0; j < row_size; ++j) { + activations.push_back(e.values[i*row_size + j] / e.counts[i]); + } + } + } else { + activations.reserve(e.activations.size()); + + for (int i = 0; i < n_mat; ++i) { + for (int j = 0; j < row_size; ++j) { + activations.push_back(e.activations[i*row_size + j] / e.counts[i]); + } } } - const float act_total = std::accumulate(activations.begin(), activations.end(), 0.0f); - const float act_max = *std::max_element(activations.begin(), activations.end()); - const float act_min = *std::min_element(activations.begin(), activations.end()); - const float act_mean = act_total / activations.size(); - const float act_sqr_total = std::inner_product(activations.begin(), activations.end(), activations.begin(), 0.0f); - const float act_var = (act_sqr_total / activations.size()) - (act_mean * act_mean); - const float act_dev = std::sqrt(std::max(0.0f, act_var)); - float threshold = 1e-5f; - const int inactive_count = std::count_if(activations.begin(), activations.end(), - [threshold](const float v) { return fabsf(v) <= threshold; }); - const float active_ratio = 1 - static_cast(inactive_count) / activations.size(); + const float sum = std::accumulate(activations.begin(), activations.end(), 0.0f); + const float max = *std::max_element(activations.begin(), activations.end()); + const float min = *std::min_element(activations.begin(), activations.end()); + const float mean = sum / activations.size(); + const float sqr_sum = std::inner_product(activations.begin(), activations.end(), activations.begin(), 0.0f); + const float variance = (sqr_sum / activations.size()) - (mean * mean); + const float std_deviation = std::sqrt(std::max(0.0f, variance)); + float entropy = 0; + + if (e.activations.empty()) { + if (sum > 0) { + for (const auto act : activations) { + if (const float p = act / sum; p > 0) { + entropy -= p * std::log2(p); + } + } + } + } else { + float div = 0.0; + std::vector weights(activations.size()); + for (size_t i = 0; i < activations.size(); ++i) { + const float w = activations[i] * activations[i]; + weights[i] = w; + div += w; + } - float entropy = 0; - if (act_total > 0) { - for (const auto act : activations) { - if (const float p = act / act_total; p > 0) { - entropy -= p * std::log2(p); + if (div > 0.0) { + for (float w : weights) { + const float p = w / div; + if (p > 0.0) entropy -= p * std::log2(p); } } } - int z_score = 0; - if (act_dev > 0.0f) { + int zd_score = 0; + if (std_deviation > 0.0f) { for (const auto act : activations) { - if (const float p = (act - act_mean) / act_dev; p > 1) { - z_score++; - } + if (const float z = (act - mean) / std_deviation; std::fabs(z) > 1.0f) zd_score++; } } auto & ts = tstats.emplace_back(); - ts.tensor = name; - ts.stats = e; - ts.total_sqract = act_total; - ts.mean_sqract = act_mean; - ts.max_sqract = act_max; - ts.min_sqract = act_min; - ts.elements = static_cast(activations.size()); - ts.stddev = act_dev; - ts.active = active_ratio; - ts.entropy = entropy; - ts.zd = static_cast(z_score) / ts.elements; + ts.tensor = name; + ts.stats = e; 
+ ts.sum_values = sum; + ts.mean_values = mean; + ts.max_values = max; + ts.min_values = min; + ts.elements = static_cast(activations.size()); + ts.std_deviation = std_deviation; + ts.entropy = entropy; + ts.zd_score = static_cast(zd_score) / ts.elements; + + return e.activations.empty(); } -static void compute_cossim(std::vector & tstats) { +static void compute_tensor_statistics(std::vector & tstats) { static const std::regex pattern(R"(blk\.(\d+)\.)"); + + // compute the Cosine Similarity between the same tensors in consecutive layers for (auto & ts : tstats) { + ts.cossim = 0; + if (std::smatch match; std::regex_search(ts.tensor, match, pattern)) { const int blk = std::stoi(match[1]); + if (blk <= 0) continue; std::string tname(ts.tensor); tname.replace(match.position(1), match.length(1), std::to_string(blk-1)); auto prev = std::find_if(tstats.begin(), tstats.end(), [tname](const tensor_statistics & t) { return t.tensor == tname; }); - if (prev != tstats.end()) { - const float dp = std::inner_product(ts.stats.values.begin(), ts.stats.values.end(), - prev->stats.values.begin(), 0.0f); - const float curr_mag = std::sqrt(std::inner_product(ts.stats.values.begin(), ts.stats.values.end(), - ts.stats.values.begin(), 0.0f)); - const float prev_mag = std::sqrt(std::inner_product(prev->stats.values.begin(), prev->stats.values.end(), - prev->stats.values.begin(), 0.0f)); - const float cs = dp / (curr_mag * prev_mag); - ts.cossim = cs; + if (prev == tstats.end()) continue; + const auto curr_avg = compute_tensor_averages(ts.stats); + const auto prev_avg = compute_tensor_averages(prev->stats); + if (curr_avg.size() == prev_avg.size() && !curr_avg.empty()) { + float dot_prod = 0.0f, vec1 = 0.0f, vec2 = 0.0f; + for (size_t i = 0; i < curr_avg.size(); ++i) { + dot_prod += curr_avg[i] * prev_avg[i]; + vec1 += curr_avg[i] * curr_avg[i]; + vec2 += prev_avg[i] * prev_avg[i]; + } + if (vec1 > 0 && vec2 > 0) ts.cossim = dot_prod / (std::sqrt(vec1) * std::sqrt(vec2)); } - } else { - ts.cossim = 0; } } + + // compute the L2 Norm (Euclidian Distance) between the same tensors in consecutive layers + for (auto & ts : tstats) { + ts.l2_norm = 0.0f; + if (ts.stats.activations.empty()) continue; + + if (std::smatch match; std::regex_search(ts.tensor, match, pattern)) { + const int blk = std::stoi(match[1]); + if (blk <= 0) continue; + std::string tname(ts.tensor); + tname.replace(match.position(1), match.length(1), std::to_string(blk - 1)); + auto prev = std::find_if(tstats.begin(), tstats.end(), + [tname](const tensor_statistics & t) { return t.tensor == tname; }); + if (prev == tstats.end()) continue; + const auto cur_avg = compute_tensor_averages(ts.stats); + const auto prev_avg = compute_tensor_averages(prev->stats); + if (cur_avg.empty() || prev_avg.empty() || cur_avg.size() != prev_avg.size()) continue; + + float dist = 0.0; + for (size_t i = 0; i < cur_avg.size(); ++i) { + const float act = cur_avg[i] - prev_avg[i]; + dist += act * act; + } + ts.l2_norm = std::sqrt(dist); + } + } +} + +static void compute_layer_statistics(const std::vector & tstats, + std::map & layer_cossim, + std::map & layer_l2_norm, + const std::unordered_map & stats_map) { + struct layer_aggregation { + std::vector curr_avg; + std::vector prev_avg; + }; + static const std::regex pattern(R"(blk\.(\d+)\.)"); + std::unordered_map tidx; + tidx.reserve(tstats.size()); + for (const auto & ts : tstats) tidx[ts.tensor] = &ts; + std::map taggr; + + for (const auto & ts : tstats) { + std::smatch match; + if (!std::regex_search(ts.tensor, match, 
pattern)) continue; + const int blk = std::stoi(match[1]); + if (blk <= 0) continue; + std::string prev_lyr(ts.tensor); + prev_lyr.replace(match.position(1), match.length(1), std::to_string(blk-1)); + if (auto it_prev = tidx.find(prev_lyr); it_prev == tidx.end()) continue; + const auto curr_avg = compute_tensor_averages(stats_map.at(ts.tensor)); + const auto prev_avg = compute_tensor_averages(stats_map.at(prev_lyr)); + if (curr_avg.empty() || prev_avg.empty() || curr_avg.size() != prev_avg.size()) continue; + auto & [curr, prev] = taggr[blk]; + curr.insert(curr.end(), curr_avg.begin(), curr_avg.end()); + prev.insert(prev.end(), prev_avg.begin(), prev_avg.end()); + } + + // compute the aggregated Cosine Similarity between consecutive layers + for (auto & kv : taggr) { + const auto & curr = kv.second.curr_avg; + const auto & prev = kv.second.prev_avg; + if (curr.size() != prev.size() || curr.empty()) continue; + float dot_prod = 0.0, lyr1 = 0.0, lyr2 = 0.0; + for (size_t i = 0; i < curr.size(); ++i) { + dot_prod += curr[i] * prev[i]; + lyr1 += curr[i] * curr[i]; + lyr2 += prev[i] * prev[i]; + } + float cossim = 0.0f; + if (lyr1 > 0.0 && lyr2 > 0.0) cossim = dot_prod / (std::sqrt(lyr1) * std::sqrt(lyr2)); + layer_cossim[kv.first] = cossim; + } + + // compute the aggregated L2 Norm (Euclidian Distance) between consecutive layers + for (auto & kv : taggr) { + const auto & curr = kv.second.curr_avg; + const auto & prev = kv.second.prev_avg; + if (curr.size() != prev.size() || curr.empty()) continue; + float dist = 0.0f; + for (size_t i = 0; i < curr.size(); ++i) { + dist += (curr[i] - prev[i]) * (curr[i] - prev[i]); + } + layer_l2_norm[kv.first] = std::sqrt(dist); + } } bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { @@ -281,6 +430,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts.resize(n_as, e.counts[0]); } if (e.values.empty()) { + if (activation_statistics()) e.activations.resize(src1->ne[0]*n_as, 0); e.values.resize(src1->ne[0]*n_as, 0); e.counts.resize(n_as, 0); } @@ -312,6 +462,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts[ex]++; for (int64_t j = 0; j < src1->ne[0]; ++j) { + if (activation_statistics()) e.activations[e_start + j] += x[j]; e.values[e_start + j] += x[j] * x[j]; if (!std::isfinite((float)e.values[e_start + j])) { LOG_ERR("%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); @@ -351,6 +502,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } } if (e.values.empty()) { + if (activation_statistics()) e.activations.resize(src1->ne[0] * n_mat, 0); e.values.resize(src1->ne[0] * n_mat, 0); e.counts.resize(1, 0); } @@ -369,6 +521,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int64_t row = 0; row < src1->ne[1]; ++row) { const float * x = (const float *) (data + row * src1->nb[1] + i2 * src1->nb[2] + i3 * src1->nb[3]); for (int64_t j = 0; j < src1->ne[0]; ++j) { + if (activation_statistics()) e.activations[mat_start + j] += x[j]; e.values[mat_start + j] += x[j] * x[j]; if (!std::isfinite((float)e.values[j])) { LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str()); @@ -550,6 +703,7 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { } to_store.push_back(kv.first); + if (activation_statistics()) data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.activations.size(), GGML_MEM_ALIGN); data_size += 
GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); } @@ -602,6 +756,16 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { gguf_add_tensor(ctx_gguf, in_sum2); gguf_add_tensor(ctx_gguf, counts); + + if (!stat.activations.empty() && activation_statistics()) { + const int32_t nact = (int32_t) stat.activations.size(); + struct ggml_tensor * in_sum = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nact / nmat, nmat); + ggml_format_name(in_sum, "%s.in_sum", name.c_str()); + for (int32_t j = 0; j < nval; ++j) { + ((float *) in_sum->data)[j] = (float) stat.activations[j]; + } + gguf_add_tensor(ctx_gguf, in_sum); + } } } @@ -740,6 +904,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { } } + const std::string in_sum_suffix{ ".in_sum" }; const std::string in_sum2_suffix{ ".in_sum2" }; const std::string counts_suffix{ ".counts" }; @@ -747,7 +912,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { // checking for completeness of *each* loaded imatrix file // and also makes it easier to re-use a similar implementation in quantize.cpp // Using an ordered map to get a deterministic iteration order. - std::map> sums_counts_for; + std::map> sums_counts_for; for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { std::string name = cur->name; @@ -756,19 +921,24 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { if (string_remove_suffix(name, in_sum2_suffix)) { // in_sum2 - sums_counts_for[std::move(name)].first = cur; + std::get<0>(sums_counts_for[std::move(name)]) = cur; } else if (string_remove_suffix(name, counts_suffix)) { // counts - sums_counts_for[std::move(name)].second = cur; - } else { + std::get<1>(sums_counts_for[std::move(name)]) = cur; + } else if (string_remove_suffix(name, in_sum_suffix)) { + // in_sum + std::get<2>(sums_counts_for[std::move(name)]) = cur; + } + else { // ignore other tensors } } for (const auto & sc : sums_counts_for) { const std::string & name = sc.first; - const struct ggml_tensor * in_sum2 = sc.second.first; - const struct ggml_tensor * counts = sc.second.second; + const struct ggml_tensor * in_sum2 = std::get<0>(sc.second); + const struct ggml_tensor * counts = std::get<1>(sc.second); + const struct ggml_tensor * in_sum = std::get<2>(sc.second); if (!in_sum2 || !counts) { LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str()); @@ -782,6 +952,9 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { int64_t nval = ggml_nelements(in_sum2); if (e.values.empty()) { e.values.resize(nval, 0.0f); + if (in_sum != nullptr) { + e.activations.resize(nval, 0.0f); + } } else if ((size_t) nval != e.values.size()) { LOG_ERR("%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size()); gguf_free(ctx_gguf); @@ -809,6 +982,11 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { for (int64_t j = 0; j < ncounts; j++) { e.counts[j] += std::lround(((const float *) counts->data)[j]); } + if (in_sum != nullptr) { + for (int64_t j = 0; j < nval; j++) { + e.activations[j] += ((const float *) in_sum->data)[j]; + } + } } // TODO: extract into its own method; this is also used by the legacy format @@ -1083,51 +1261,65 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c static bool show_statistics(const common_params & params) { std::vector ts; + 
bool legacy_mode = false; + if (params.in_files.empty() || params.in_files.size() > 1) { LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n"); return false; } if (g_collector.load_imatrix(params.in_files[0].c_str())) { for (const auto & [name, stats] :g_collector.get_mstats()) { - compute_statistics(ts, name, stats); + legacy_mode = compute_vector_statistics(ts, name, stats); } } else { LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str()); return false; } if (!ts.empty()) { - compute_cossim(ts); + compute_tensor_statistics(ts); } else { LOG_ERR("Error: cannot compute statistics for %s\n\n", params.in_files[0].c_str()); return false; } struct tensor_comparer { + bool legacy_mode; + explicit tensor_comparer(const bool legacy) : legacy_mode(legacy) {} + bool operator()(const tensor_statistics & a, const tensor_statistics & b) const { std::string layer, name_a, name_b; - ; process_tensor_name(a.tensor, layer, name_a); process_tensor_name(b.tensor, layer, name_b); - return name_a < name_b || (name_a == name_b && a.total_sqract > b.total_sqract); + return legacy_mode ? name_a < name_b || (name_a == name_b && a.sum_values > b.sum_values) : + name_a < name_b || (name_a == name_b && a.cossim > b.cossim); } }; - std::sort(ts.begin(), ts.end(), tensor_comparer()); + std::sort(ts.begin(), ts.end(), tensor_comparer(legacy_mode)); - struct weighted_stats { - float weighted_bias = 0.0f; - float weighted_zd = 0.0f; - float weighted_cossim = 0.0f; - int total_elements = 0; + struct layer_stats { + float lyr_sum = 0.0f; + float lyr_zd = 0.0f; + int n = 0; }; - std::map ws; - - LOG_INF("\nComputing statistics for %s (%d tensors)\n", params.in_files[0].c_str(), static_cast(ts.size())); - LOG_INF("\n%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n", " Layer", " Tensor", " Σ(Act²)", - " Min", " Max", " μ", " σ", " % Active", "N", " Entropy", "E (norm)", "ZD", - " CosSim"); + std::map ls; + + LOG_INF("\nComputing tensor statistics for %s (%d tensors)\n", params.in_files[0].c_str(), static_cast(ts.size())); + LOG_INF("\n%6s\t%18s\t%13s\t%8s\t%8s\t%7s\t%15s\t%13s\t%11s\t%8s\t%5s\t%10s\n", + "Layer", + "Tensor", + legacy_mode ? "Σ(Act²)" : "L₂ Norm", + "Min", + "Max", + "μ", + "σ", + "N", + "H Norm", + legacy_mode ? 
"H" : "ECS", + "ZD", + "CosSim"); LOG_INF( "==============================================================================================================" - "===========================================================\n"); + "=============================================================\n"); for (const auto & tstat : ts) { std::string layer, name; process_tensor_name(tstat.tensor, layer, name); @@ -1139,48 +1331,71 @@ static bool show_statistics(const common_params & params) { blk = -1; // not a block layer } - LOG_INF("%5s\t%-20s\t%10.2f\t%8.4f\t%11.4f\t%6.2f\t%6.2f\t%8.2f%%\t%6d\t%10.4f\t%6.2f%%\t%10.2f%%\t%8.4f\n", - layer.c_str(), name.c_str(), tstat.total_sqract, tstat.min_sqract, tstat.max_sqract, tstat.mean_sqract, - tstat.stddev, tstat.active * 100.0f, tstat.elements, tstat.entropy, - 100.0f * (tstat.entropy / std::log2(tstat.elements)), 100.0f * tstat.zd, tstat.cossim); - - const float weighted_bias = tstat.elements * tstat.total_sqract; - const float weighted_zd = tstat.elements * tstat.zd; - const float weighted_cossim = tstat.elements * tstat.cossim; - - if (ws.find(blk) != ws.end()) { - ws[blk].weighted_bias += weighted_bias; - ws[blk].weighted_zd += weighted_zd; - ws[blk].weighted_cossim += weighted_cossim; - ws[blk].total_elements += tstat.elements; + LOG_INF("%5s\t%-20s\t%11.2f\t%10.4f\t%10.4f\t%8.2f\t%8.2f\t%7d\t%10.2f%%\t%10.4f\t%6.2f%%\t%10.4f\n", + layer.c_str(), + name.c_str(), + legacy_mode ? tstat.sum_values : tstat.l2_norm, + tstat.min_values, + tstat.max_values, + tstat.mean_values, + tstat.std_deviation, + tstat.elements, + 100.0f * (tstat.entropy / std::log2(tstat.elements)), + legacy_mode ? tstat.entropy : 100.0f * std::exp(-0.01f * tstat.l2_norm) * std::pow(fabs(tstat.cossim), 10.0f), + 100.0f * tstat.zd_score, + tstat.cossim); + + const float zd = tstat.elements * tstat.zd_score; + + if (ls.find(blk) != ls.end()) { + ls[blk].lyr_sum += tstat.sum_values; + ls[blk].lyr_zd += zd; + ls[blk].n += tstat.elements; } else { - weighted_stats temp_ws; - temp_ws.weighted_bias = weighted_bias; - temp_ws.weighted_zd = weighted_zd; - temp_ws.weighted_cossim = weighted_cossim; - temp_ws.total_elements = tstat.elements; - ws[blk] = temp_ws; + layer_stats temp_ls; + temp_ls.lyr_sum = tstat.sum_values; + temp_ls.lyr_zd = zd; + temp_ls.n = tstat.elements; + ls[blk] = temp_ls; } } - const int layers = std::count_if(ws.begin(), ws.end(), [](const auto & kv) { return kv.first >= 0; }); - LOG_INF("\nComputing weighted average statistics per layer (%d layers)\n", layers); - LOG_INF("\n%s\t%s\t%s\t%s\n", " Layer", " μΣ(Act²)", " μZD", "μCosSim"); - LOG_INF("================================================\n"); - for (const auto & [first, second] : ws) { - const auto & layer = first; - const auto & stats = second; - - if (stats.total_elements == 0) { - continue; - } - - if (layer >= 0) { - const float bias = stats.weighted_bias / stats.total_elements; - const float zd = stats.weighted_zd / stats.total_elements; - const float cossim = stats.weighted_cossim / stats.total_elements; - - LOG_INF("%5d\t%14.2f\t%10.4f%%\t%6.4f\n", layer, bias, 100.0f * zd, cossim); + std::map lyr_cossim; + std::map lyr_l2_norm; + compute_layer_statistics(ts, lyr_cossim, lyr_l2_norm, g_collector.get_mstats()); + + const auto layers = std::count_if(ls.begin(), ls.end(), [](const auto & kv) { return kv.first >= 0; }); + LOG_INF("\nComputing layer statistics (%ld layers)\n", layers); + LOG_INF("\n%6s\t%13s\t%6s\t%11s\t%6s\n", + "Layer", + legacy_mode ? "Σ(Act²)" : "L₂ Norm", + "ZD", + "CosSim", + legacy_mode ? 
"" : "ECS"); + if (legacy_mode) { + LOG_INF("============================================\n"); + } else { + LOG_INF("=========================================================\n"); + } + for (const auto & [layer, stats] : ls) { + if (layer < 0 || stats.n == 0) continue; + const auto lcs = lyr_cossim.find(layer); + const float lyr_cs = lcs != lyr_cossim.end() ? lcs->second : 0.0f; + const auto ll2n = lyr_l2_norm.find(layer); + const float lyr_l2n = ll2n != lyr_l2_norm.end() ? ll2n->second : 0.0f; + if (legacy_mode) { + LOG_INF("%5d\t%11.2f\t%6.2f%%\t%11.4f\n", + layer, + stats.lyr_sum, + 100.0f * stats.lyr_zd / stats.n, + lyr_cs); + } else { + LOG_INF("%5d\t%11.2f\t%6.2f%%\t%11.4f\t%8.4f\n", + layer, + lyr_l2n, + 100.0f * stats.lyr_zd / stats.n, + lyr_cs, + 100.0f * std::exp(-0.01f * lyr_l2n) * std::pow(fabs(lyr_cs), 10.0f)); } } LOG_INF("\n");