summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-01-12 06:59:57 +0100
committerGitHub <noreply@github.com>2024-01-12 06:59:57 +0100
commit326b418b59b6d48d854c4461a2303e8ac0a311e6 (patch)
tree7c0d22f95b48183e685af4f70facecbe6b5963f5 /ggml.c
parent1d118386fea031f01550f8cd47a5c86296e5333f (diff)
Importance Matrix calculation (#4861)
* imatrix: 1st version * imatrix: WIP * Cleanup * Update examples/imatrix/imatrix.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c14
1 files changed, 14 insertions, 0 deletions
diff --git a/ggml.c b/ggml.c
index d2a8c047..f5caeba0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -394,6 +394,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
+ggml_collect_imatrix_t g_imatrix_collect = NULL;
+
+void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) {
+ g_imatrix_collect = imatrix_collect;
+}
+
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
.type_name = "i8",
@@ -9763,6 +9769,10 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith;
const int nth = params->nth;
+ if (ith == 1 && g_imatrix_collect) {
+ g_imatrix_collect(src0, src1);
+ }
+
const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
@@ -10066,6 +10076,10 @@ static void ggml_compute_forward_mul_mat_id(
const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+ if (ith == 1 && g_imatrix_collect) {
+ g_imatrix_collect(src0_cur, src1);
+ }
+
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);