summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c123
1 files changed, 107 insertions, 16 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 88820438..a904464e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -12589,6 +12589,43 @@ static void ggml_compute_forward_repeat_f16(
}
}
+static void ggml_compute_forward_repeat_any(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src = dst->src[0];
+
+ GGML_ASSERT(ggml_can_repeat(src, dst));
+ GGML_ASSERT(src->type == dst->type);
+ GGML_ASSERT(src->nb[0] == ggml_type_size(src->type));
+ int64_t src_row_size = ggml_row_size(src->type, src->ne[0]);
+ GGML_ASSERT((int64_t )dst->nb[1] == src_row_size*dst->ne[0]/src->ne[0]);
+
+ int ith = params->ith;
+ int nth = params->nth;
+
+ int64_t nrows = ggml_nrows(dst);
+ int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+ int64_t first_row = ith*nrows_per_thread;
+ if (first_row >= nrows) return;
+ int64_t last_row = MIN(first_row + nrows_per_thread, nrows);
+
+ for (int64_t row = first_row; row < last_row; ++row) {
+ int64_t i3 = row/(dst->ne[1]*dst->ne[2]);
+ int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1];
+ int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1];
+ char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3];
+ int64_t i03 = i3 % src->ne[3];
+ int64_t i02 = i2 % src->ne[2];
+ int64_t i01 = i1 % src->ne[1];
+ const char * x = (const char *)src->data + i01*src->nb[1] + i02*src->nb[2] + i03*src->nb[3];
+ for (int64_t ir = 0; ir < dst->ne[0]/src->ne[0]; ++ir) {
+ memcpy(y, x, src_row_size);
+ y += src_row_size;
+ }
+ }
+}
+
static void ggml_compute_forward_repeat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -12609,7 +12646,8 @@ static void ggml_compute_forward_repeat(
} break;
default:
{
- GGML_ABORT("fatal error");
+ ggml_compute_forward_repeat_any(params, dst);
+ //GGML_ABORT("fatal error");
}
}
}
@@ -12762,6 +12800,44 @@ static void ggml_compute_forward_concat_f32(
}
}
+static void ggml_compute_forward_concat_any(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
+ // Let's do it for dim = 0 only for now
+ GGML_ASSERT(dim == 0);
+
+ int ith = params->ith;
+ int nth = params->nth;
+
+ int64_t nrows = ggml_nrows(dst);
+ int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+ int64_t first_row = ith*nrows_per_thread;
+ if (first_row >= nrows) return;
+ int64_t last_row = MIN(first_row + nrows_per_thread, nrows);
+
+ int64_t src0_row_size = ggml_row_size(src0->type, src0->ne[0]);
+ int64_t src1_row_size = ggml_row_size(src1->type, src1->ne[0]);
+
+ for (int64_t row = first_row; row < last_row; ++row) {
+ int64_t i3 = row/(dst->ne[1]*dst->ne[2]);
+ int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1];
+ int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1];
+ char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3];
+ const char * x0 = (const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3];
+ const char * x1 = (const char *)src1->data + i1*src1->nb[1] + i2*src1->nb[2] + i3*src1->nb[3];
+ memcpy(y, x0, src0_row_size);
+ memcpy(y + src0_row_size, x1, src1_row_size);
+ }
+
+}
+
static void ggml_compute_forward_concat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -12776,7 +12852,8 @@ static void ggml_compute_forward_concat(
} break;
default:
{
- GGML_ABORT("fatal error");
+ ggml_compute_forward_concat_any(params, dst);
+ //GGML_ABORT("fatal error");
}
}
}
@@ -14302,7 +14379,17 @@ UseGgmlGemm1:;
const size_t nbw3 = nbw2*ne12;
assert(params->wsize >= ne13*nbw3);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ if (src1->type != GGML_TYPE_F32) {
+#if GGML_USE_IQK_MULMAT
+ char * work_buffer = wdata + ne13*nbw3 + ith*ne10*sizeof(float);
+ GGML_ASSERT(params->wsize >= ne13*nbw3 + nth*ne10*sizeof(float));
+ iqk_quantize_any(src1->type, vec_dot_type, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+ src1->data, wdata, work_buffer, type_traits[src1->type].to_float, from_float, ith, nth);
+#else
+ GGML_ABORT("fatal error");
+#endif
+ }
+ else {
//#ifdef GGML_USE_IQK_MULMAT
// int ts = type_traits[vec_dot_type].type_size;
@@ -14348,6 +14435,7 @@ UseGgmlGemm1:;
}
}
//#endif
+ }
ggml_barrier(params->shared);
@@ -16250,28 +16338,28 @@ static void ggml_compute_forward_soft_max_f32(
}
}
-#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
-#endif
+//#ifndef NDEBUG
+// for (int i = 0; i < nc; ++i) {
+// //printf("p[%d] = %f\n", i, p[i]);
+// assert(!isnan(wp[i]));
+// }
+//#endif
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, wp);
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ //assert(sum > 0.0);
sum = 1.0/sum;
ggml_vec_scale_f32(nc, dp, sum);
-#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
-#endif
+//#ifndef NDEBUG
+// for (int i = 0; i < nc; ++i) {
+// assert(!isnan(dp[i]));
+// assert(!isinf(dp[i]));
+// }
+//#endif
}
}
@@ -21498,6 +21586,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (node->src[1]->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]);
+ if (node->src[1]->type != GGML_TYPE_F32) {
+ cur += n_tasks*node->src[1]->ne[0]*sizeof(float); // src1->type -> f32 -> vec_dot_type
+ }
}
} break;
case GGML_OP_MUL_MAT_ID: