summary | refs | log | tree | commit | diff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r-- ggml.c 92
1 file changed, 62 insertions, 30 deletions
diff --git a/ggml.c b/ggml.c
index 939ab4d6..d86e5942 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3776,6 +3776,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
(t1->ne[3]%t0->ne[3] == 0);
}
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
static inline int ggml_up32(int n) {
return (n + 31) & ~31;
}
@@ -4658,11 +4664,15 @@ struct ggml_tensor * ggml_mul_impl(
struct ggml_tensor * a,
struct ggml_tensor * b,
bool inplace) {
- GGML_ASSERT(ggml_are_same_shape(a, b));
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
is_node = true;
}
@@ -7960,7 +7970,7 @@ static void ggml_compute_forward_mul_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@@ -7968,10 +7978,25 @@ static void ggml_compute_forward_mul_f32(
const int ith = params->ith;
const int nth = params->nth;
- const int nr = ggml_nrows(src0);
- const int64_t ne0 = src0->ne[0];
- const int64_t ne1 = src0->ne[1];
- const int64_t ne2 = src0->ne[2];
+#ifdef GGML_USE_CUBLAS
+ if (src1->backend == GGML_BACKEND_CUDA) {
+ if (ith == 0) {
+ ggml_cuda_mul(src0, src1, dst);
+ }
+ return;
+ }
+#endif
+
+ const int64_t nr = ggml_nrows(src0);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
@@ -7990,44 +8015,51 @@ static void ggml_compute_forward_mul_f32(
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(ne00 == ne10);
if (nb10 == sizeof(float)) {
- for (int ir = ith; ir < nr; ir += nth) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
#ifdef GGML_USE_ACCELERATE
UNUSED(ggml_vec_mul_f32);
- vDSP_vmul(
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
- ne0);
+ vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
#else
- ggml_vec_mul_f32(ne0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+ ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
#endif
// }
// }
}
} else {
// src1 is not contiguous
- for (int ir = ith; ir < nr; ir += nth) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i0 = 0; i0 < ne0; i0++) {
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
}