summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c48
1 files changed, 40 insertions, 8 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index faf1902d..a2bdc156 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1756,6 +1756,15 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
return type_traits[type];
}
+static inline int ggml_packed_rows(enum ggml_type type) {
+ return type == GGML_TYPE_BF16_R16 ? 16
+ : type == GGML_TYPE_Q8_K_R8 || type == GGML_TYPE_Q8_KV_R8 ||
+ type == GGML_TYPE_Q8_0_R8 || type == GGML_TYPE_Q4_0_R8 ||
+ type == GGML_TYPE_IQ4_XS_R8 ? 8
+ : type >= GGML_TYPE_Q4_0_R8 && type <= GGML_TYPE_Q8_K_R8 ? 4
+ : 1;
+}
+
//
// simd mappings
//
@@ -10119,9 +10128,11 @@ static void ggml_compute_forward_dup_f32(
}
// parallelize by rows
+ int n_packed = ggml_packed_rows(dst->type);
+ GGML_ASSERT(dst->ne[1] % n_packed == 0);
const int nr = ne01;
// number of rows per thread
- const int dr = (nr + nth - 1) / nth;
+ const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
@@ -10173,10 +10184,10 @@ static void ggml_compute_forward_dup_f32(
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
- id += rs;
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00*n_packed);
+ id += rs*n_packed;
}
id += rs * (ne01 - ir1);
}
@@ -10441,10 +10452,15 @@ static void ggml_compute_forward_dup_bytes(
// parallelize by rows
const int nr = ne01;
+ const int n_packed = ggml_packed_rows(dst->type);
+ GGML_ASSERT(nr%n_packed == 0);
+ const int nrp = nr/n_packed;
// number of rows per thread
- const int dr = (nr + nth - 1) / nth;
+ const int drp = (nrp + nth - 1) / nth;
+ const int dr = drp*n_packed;
// row range for this thread
const int ir0 = dr * ith;
+ if (ir0 >= nr) return;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
@@ -10569,10 +10585,19 @@ static void ggml_compute_forward_dup_q(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_quantized(dst->src[0]->type));
+
int64_t nrows = ggml_nrows(dst);
int ith = params->ith;
int nth = params->nth;
+ if (dst->src[0]->type == dst->type &&
+ dst->src[0]->nb[0] == ggml_type_size(dst->type) &&
+ dst->nb[0] == ggml_type_size(dst->type)) {
+ ggml_compute_forward_dup_bytes(params, dst);
+ return;
+ }
+
if (dst->type == GGML_TYPE_Q8_0 && dst->src[0]->type == GGML_TYPE_Q8_0 &&
ggml_are_same_shape(dst, dst->src[0])) {
@@ -10626,6 +10651,10 @@ static void ggml_compute_forward_dup_q(
return;
}
+ if (dst->type != GGML_TYPE_F32) {
+ printf("%s: %s -> %s is of type %s\n", __func__, dst->src[0]->name, dst->name, ggml_type_name(dst->type));
+ GGML_ABORT("fatal error");
+ }
GGML_ASSERT(dst->type == GGML_TYPE_F32);
struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->ne[0] == dst->ne[0] && src0->nb[0] == ggml_type_size(src0->type));
@@ -10633,12 +10662,15 @@ static void ggml_compute_forward_dup_q(
ggml_to_float_t to_float = type_traits[src0->type].to_float;
GGML_ASSERT(to_float != NULL);
- int64_t n_per_thread = (nrows + nth - 1)/nth;
+ int n_packed = ggml_packed_rows(src0->type);
+ GGML_ASSERT(src0->ne[1] % n_packed == 0);
+
+ int64_t n_per_thread = n_packed*((nrows/n_packed + nth - 1)/nth);
int64_t first_row = ith*n_per_thread;
if (first_row >= nrows) return;
int64_t last_row = MIN(first_row + n_per_thread, nrows);
- for (int64_t ir = first_row; ir < last_row; ++ir) {
+ for (int64_t ir = first_row; ir < last_row; ir += n_packed) {
int64_t i03 = ir/(src0->ne[1]*src0->ne[2]);
int64_t i02 = (ir - i03*src0->ne[1]*src0->ne[2])/src0->ne[1];
int64_t i01 = ir - i03*src0->ne[1]*src0->ne[2] - i02*src0->ne[1];
@@ -10649,7 +10681,7 @@ static void ggml_compute_forward_dup_q(
const char * q = (const char *)src0->data + i03*src0->nb[3] + i02*src0->nb[2] + i01*src0->nb[1];
char * f = ( char *)dst->data + i3* dst->nb[3] + i2* dst->nb[2] + i1* dst->nb[1];
- to_float((const void *)q, (float *)f, src0->ne[0]);
+ to_float((const void *)q, (float *)f, src0->ne[0]*n_packed);
}
}