summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c143
1 files changed, 138 insertions, 5 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 7ba5e1ad..31fbc57e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3875,6 +3875,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
+ "ARGSORT_THRESH",
"LEAKY_RELU",
"SOFTCAP",
"SOFT_CAP_MAX",
@@ -3905,7 +3906,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3969,6 +3970,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
+ "argsort_thresh(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
"soft_max(k2*tanh(k1*x))",
@@ -3999,7 +4001,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -8497,6 +8499,27 @@ struct ggml_tensor * ggml_argsort(
return result;
}
+// ggml_argsort
+
+struct ggml_tensor * ggml_argsort_thresh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int min_entries,
+ float thresh) {
+ bool is_node = false;
+
+ //printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) min_entries);
+ ggml_set_op_params_f32(result, 1, thresh);
+
+ result->op = GGML_OP_ARGSORT_THRESH;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
// ggml_top_k
@@ -8516,6 +8539,32 @@ struct ggml_tensor * ggml_top_k(
return result;
}
+// ggml_top_k_thresh
+
+struct ggml_tensor * ggml_top_k_thresh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int k,
+ int min_entries,
+ float thresh) {
+ GGML_ASSERT(a->ne[0] >= k);
+
+ //printf("%s: k = %d, min_entries = %d, thresh = %g\n", __func__, k, min_entries, (double)thresh);
+ struct ggml_tensor * result;
+ if (min_entries <= 0 || thresh <= 0) {
+ result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+ } else {
+ result = ggml_argsort_thresh(ctx, a, min_entries, thresh);
+ }
+
+ result = ggml_view_4d(ctx, result,
+ k, result->ne[1], result->ne[2], result->ne[3],
+ result->nb[1], result->nb[2], result->nb[3],
+ 0);
+
+ return result;
+}
+
// ggml_flash_attn_ext
struct ggml_tensor * ggml_flash_attn_ext(
@@ -14485,7 +14534,8 @@ static void ggml_compute_forward_mul_mat_id(
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
- assert(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
@@ -14737,7 +14787,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
- assert(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
@@ -15580,7 +15631,11 @@ static void ggml_compute_forward_get_rows_q(
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- assert(i01 >= 0 && i01 < ne01);
+ if (i01 < 0 || i01 >= ne01) {
+ memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
+ continue;
+ }
+ //assert(i01 >= 0 && i01 < ne01);
dequantize_row_q(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -17667,6 +17722,75 @@ static void ggml_compute_forward_argsort(
}
}
+// ggml_compute_forward_argsort_thresh
+
+static void ggml_compute_forward_argsort_thresh_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ int min_entries = ggml_get_op_params_i32(dst, 0);
+ float thresh = ggml_get_op_params_f32(dst, 1);
+
+ //if (ith == 0) printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+ // C doesn't have a functional sort, so we do a bubble sort instead
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if (src_data[dst_data[j]] < src_data[dst_data[k]]) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ float max_value = src_data[dst_data[0]];
+ //printf("Row %ld: max_value is %g, next is %g\n", i, (double)max_value, (double)src_data[dst_data[1]]);
+ for (int j = min_entries; j < ne0; ++j) {
+ if (src_data[dst_data[j]] < max_value*thresh) {
+ //printf(" row %ld: turning off expert %d(%d) with value %g\n", i, j, dst_data[j], (double)src_data[dst_data[j]]);
+ dst_data[j] = -1;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_argsort_thresh(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_argsort_thresh_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16(
@@ -19476,6 +19600,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argsort(params, tensor);
} break;
+ case GGML_OP_ARGSORT_THRESH:
+ {
+ ggml_compute_forward_argsort_thresh(params, tensor);
+ } break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@@ -20461,6 +20589,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
+ case GGML_OP_ARGSORT_THRESH:
+ {
+ GGML_ABORT("fatal error"); // TODO: not implemented
+ }
case GGML_OP_LEAKY_RELU:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -21181,6 +21313,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
+ case GGML_OP_ARGSORT_THRESH:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV: