summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-05-13 17:55:04 +0300
committerGitHub <noreply@github.com>2025-05-13 17:55:04 +0300
commit13740622e973b78ae662bbb785c2fc5926a324eb (patch)
treeff928528dea67b6bfe0c1461129cacfd29b9d586
parent0c57f84dc41aa756dae7b1aaee0d3db6ecc14300 (diff)
Fix SER (CPU) (#415)
* Fixing SER bugs * Cleanup --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml.c33
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp42
2 files changed, 51 insertions, 24 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d82466e0..94defa47 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -12472,6 +12472,11 @@ static void ggml_compute_forward_sum_rows_f32(
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
ggml_vec_sum_f32(ne00, &row_sum, src_row);
+ if (!isfinite(row_sum)) {
+ fprintf(stderr, "Oops(%s, %s): found %g for i1 = %d, i2 = %d, i3 = %d. ne00 = %d\n", __func__, dst->name,
+ (double)row_sum, (int)i1, (int)i2, (int)i3, (int)ne00);
+ exit(1);
+ }
dst_row[0] = row_sum;
}
}
@@ -14759,6 +14764,18 @@ static void ggml_compute_forward_mul_mat_id(
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+ GGML_ASSERT(ids->ne[1] == dst->ne[2]);
+ for (int64_t iid1 = ith; iid1 < ids->ne[1]; iid1 += nth) {
+ for (int id = 0; id < n_ids; ++id) {
+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+ if (i02 < 0 || i02 >= n_as) {
+ // This is needed for SER. If fewer experts have been activated for this row, we need to
+ // clear it, else there could be garbage that leads to NaNs later on.
+ memset((char *)dst->data + id*dst->nb[1] + iid1*dst->nb[2], 0, dst->ne[0]*sizeof(float));
+ }
+ }
+ }
+
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -15012,6 +15029,18 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+ GGML_ASSERT(ids->ne[1] == dst->ne[2]);
+ for (int64_t iid1 = ith; iid1 < ids->ne[1]; iid1 += nth) {
+ for (int id = 0; id < n_ids; ++id) {
+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+ if (i02 < 0 || i02 >= n_as) {
+ // This is needed for SER. If fewer experts have been activated for this row, we need to
+ // clear it, else there could be garbage that leads to NaNs later on.
+ memset((char *)dst->data + id*dst->nb[1] + iid1*dst->nb[2], 0, dst->ne[0]*sizeof(float));
+ }
+ }
+ }
+
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -15916,7 +15945,7 @@ static void ggml_compute_forward_get_rows_f16(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
} else {
- memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
+ memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
}
}
@@ -15960,7 +15989,7 @@ static void ggml_compute_forward_get_rows_bf16(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
} else {
- memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
+ memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
}
}
}
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3cb7573b..92f58d55 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -458,31 +458,29 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
if (r2 <= 8) {
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
- int nx64 = Nx/64;
- int nchunk64 = nx64*ne02;
- for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
- int i02 = ichunk/nx64;
- int ix = 64*(ichunk - i02*nx64);
- DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
- mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
- }
- int ix0 = 64*nx64;
- if (ix0 < Nx) {
- nx32 -= 2*nx64;
- nchunk = nx32*ne02;
- for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
- int i02 = ichunk/nx32;
- int ix = ix0 + 32*(ichunk - i02*nx32);
+ int ny = mm.funcs.size();
+ while (ny > 0 && !mm.funcs[ny-1]) --ny;
+ if (ny >= r2) {
+ int nx64 = Nx/64;
+ int nchunk64 = nx64*ne02;
+ for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
+ int i02 = ichunk/nx64;
+ int ix = 64*(ichunk - i02*nx64);
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
- mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
+ }
+ int ix0 = 64*nx64;
+ if (ix0 < Nx) {
+ nx32 -= 2*nx64;
+ nchunk = nx32*ne02;
+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ int i02 = ichunk/nx32;
+ int ix = ix0 + 32*(ichunk - i02*nx32);
+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
+ }
}
}
- //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
- // int i02 = ichunk/nx32;
- // int ix = 32*(ichunk - i02*nx32);
- // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
- // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
- //}
return true;
}
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {