diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 14 |
1 files changed, 14 insertions, 0 deletions
@@ -394,6 +394,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); +ggml_collect_imatrix_t g_imatrix_collect = NULL; + +void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) { + g_imatrix_collect = imatrix_collect; +} + static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { .type_name = "i8", @@ -9763,6 +9769,10 @@ static void ggml_compute_forward_mul_mat( const int ith = params->ith; const int nth = params->nth; + if (ith == 1 && g_imatrix_collect) { + g_imatrix_collect(src0, src1); + } + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -10066,6 +10076,10 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; + if (ith == 1 && g_imatrix_collect) { + g_imatrix_collect(src0_cur, src1); + } + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); |