summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c9
1 files changed, 5 insertions, 4 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4546eac3..771bc8ca 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -16154,11 +16154,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
mask && mask->type == GGML_TYPE_F16) {
int64_t work_per_slice = D*nek1*neq1;
int ntg = 1;
- if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
- else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
- else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+ if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+ else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+ else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
if ((neq2*neq3)%(nth/ntg) == 0) {
- //if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1);
+ //if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d, ntg = %d, neq1/ntg = %d\n", __func__,
+ // (int)D, (int)neq2, (int)neq1, (int)nek1, ntg, (int)(neq1/ntg));
int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {