Diffstat (limited to 'ggml.c')
-rw-r--r--  ggml.c  63
1 file changed, 40 insertions(+), 23 deletions(-)
diff --git a/ggml.c b/ggml.c
index 5145ceec..023077ca 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4882,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back(
 // ggml_concat
 
 struct ggml_tensor * ggml_concat(
-    struct ggml_context* ctx,
-    struct ggml_tensor* a,
-    struct ggml_tensor* b) {
-    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
 
     bool is_node = false;
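A hedged usage sketch of the new signature (illustrative shapes, assuming an already-initialized ggml_context named ctx): the result copies a's shape except along dim, where the two sizes add; every other dimension must match or the new per-dimension assert fires.

    // illustrative only -- ctx is assumed to be a valid ggml_context
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 3, 2, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 5, 2, 1);

    // dim = 1 -> c->ne = {4, 3 + 5, 2, 1}
    // any other dim would trip GGML_ASSERT(a->ne[d] == b->ne[d]) because ne[1] differs
    struct ggml_tensor * c = ggml_concat(ctx, a, b, 1);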
@@ -4893,7 +4904,9 @@ struct ggml_tensor * ggml_concat(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, dim);
 
     result->op = GGML_OP_CONCAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
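Callers that relied on the previous hard-coded behavior (concatenation along ne[2]) would now pass the dimension explicitly; the chosen dim is stored in the op's int32 parameter slot 0 and read back in the compute path. A minimal sketch, assuming two F32 tensors a and b that match in every dimension except ne[2]:

    // illustrative only -- equivalent to the old two-argument ggml_concat(ctx, a, b)
    struct ggml_tensor * r = ggml_concat(ctx, a, b, 2);
    // internally, ggml_set_op_params_i32(r, 0, 2) records the dim, and
    // ggml_compute_forward_concat_f32() retrieves it with ggml_get_op_params_i32(dst, 0)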
@@ -5013,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
     result->op = GGML_OP_LEAKY_RELU;
@@ -10967,26 +10981,29 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const float * x;
+
+    // TODO: smarter multi-threading
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = ith; i2 < ne2; i2 += nth) {
-            if (i2 < ne02) { // src0
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
-                    }
-                }
-            } // src1
-            else {
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
                     }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
                 }
             }
         }
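The rewritten loop drops the old split on i2 and instead selects the source per element: an index that stays within src0's extents in every dimension reads from src0, anything else reads from src1 shifted back by o[dim]. A minimal reference sketch of the same logic on plain contiguous arrays (hypothetical helper, not part of ggml; single-threaded, and real tensors use the nb* byte strides rather than computed row-major offsets):

    #include <stdint.h>

    // illustrative only: concatenate two contiguous 4D f32 arrays along "dim"
    static void concat_f32_ref(const float * a, const int64_t ane[4],
                               const float * b, const int64_t bne[4],
                               float * dst, int dim) {
        int64_t ne[4], o[4] = {0, 0, 0, 0};
        for (int k = 0; k < 4; k++) {
            ne[k] = ane[k] + (k == dim ? bne[k] : 0);
        }
        o[dim] = ane[dim];

        for (int64_t i3 = 0; i3 < ne[3]; i3++) {
            for (int64_t i2 = 0; i2 < ne[2]; i2++) {
                for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                    for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                        float v;
                        if (i0 < ane[0] && i1 < ane[1] && i2 < ane[2] && i3 < ane[3]) {
                            // still inside a's extents in every dimension -> copy from a
                            v = a[((i3*ane[2] + i2)*ane[1] + i1)*ane[0] + i0];
                        } else {
                            // otherwise the element belongs to b, shifted back along dim
                            v = b[(((i3 - o[3])*bne[2] + (i2 - o[2]))*bne[1] + (i1 - o[1]))*bne[0] + (i0 - o[0])];
                        }
                        dst[((i3*ne[2] + i2)*ne[1] + i1)*ne[0] + i0] = v;
                    }
                }
            }
        }
    }

Only the concat dimension can push an index outside src0: the shape asserts guarantee every other dimension matches, so the subtraction of o[k] is a no-op except along dim.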
@@ -10994,7 +11011,7 @@ static void ggml_compute_forward_concat_f32(
 }
 
 static void ggml_compute_forward_concat(
-    const struct ggml_compute_params* params,
+    const struct ggml_compute_params * params,
     struct ggml_tensor* dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];