diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 143 |
1 files changed, 138 insertions, 5 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7ba5e1ad..31fbc57e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3875,6 +3875,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "ARGSORT_THRESH", "LEAKY_RELU", "SOFTCAP", "SOFT_CAP_MAX", @@ -3905,7 +3906,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3969,6 +3970,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "argsort_thresh(x)", "leaky_relu(x)", "k2*tanh(k1*x)", "soft_max(k2*tanh(k1*x))", @@ -3999,7 +4001,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -8497,6 +8499,27 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort + +struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float thresh) { + bool is_node = false; + + //printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) min_entries); + ggml_set_op_params_f32(result, 1, thresh); + + result->op = GGML_OP_ARGSORT_THRESH; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} // ggml_top_k @@ -8516,6 +8539,32 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_top_k_thresh + +struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh) { + GGML_ASSERT(a->ne[0] >= k); + + //printf("%s: k = %d, min_entries = %d, thresh = %g\n", __func__, k, min_entries, (double)thresh); + struct ggml_tensor * result; + if (min_entries <= 0 || thresh <= 0) { + result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + } else { + result = ggml_argsort_thresh(ctx, a, min_entries, thresh); + } + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( @@ -14485,7 +14534,8 @@ static void ggml_compute_forward_mul_mat_id( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -14737,7 +14787,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -15580,7 +15631,11 @@ static void ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 < 0 || i01 >= ne01) { + memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + continue; + } + //assert(i01 >= 0 && i01 < ne01); dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -17667,6 +17722,75 @@ static void ggml_compute_forward_argsort( } } +// ggml_compute_forward_argsort_thresh + +static void ggml_compute_forward_argsort_thresh_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + int min_entries = ggml_get_op_params_i32(dst, 0); + float thresh = ggml_get_op_params_f32(dst, 1); + + //if (ith == 0) printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if (src_data[dst_data[j]] < src_data[dst_data[k]]) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + float max_value = src_data[dst_data[0]]; + //printf("Row %ld: max_value is %g, next is %g\n", i, (double)max_value, (double)src_data[dst_data[1]]); + for (int j = min_entries; j < ne0; ++j) { + if (src_data[dst_data[j]] < max_value*thresh) { + //printf(" row %ld: turning off expert %d(%d) with value %g\n", i, j, dst_data[j], (double)src_data[dst_data[j]]); + dst_data[j] = -1; + } + } + } +} + +static void ggml_compute_forward_argsort_thresh( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_thresh_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -19476,6 +19600,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_ARGSORT_THRESH: + { + ggml_compute_forward_argsort_thresh(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -20461,6 +20589,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_ARGSORT_THRESH: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -21181,6 +21313,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: |