summaryrefslogtreecommitdiff
path: root/ggml
diff options
context:
space:
mode:
Diffstat (limited to 'ggml')
-rw-r--r--ggml/src/ggml.c32
-rw-r--r--ggml/src/iqk/iqk_flash_attn.cpp61
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp53
3 files changed, 134 insertions, 12 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 25694fc7..83a48cb6 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21781,19 +21781,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const struct ggml_tensor * q = node->src[0];
const struct ggml_tensor * k = node->src[1];
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
- int nstep_k = k->ne[1]/32;
- int gcd_k = simple_gcd(nstep_k, n_tasks);
- if (gcd_k > 1) {
- int nth_k = n_tasks/gcd_k;
- int rk2 = q->ne[2]/k->ne[2];
- int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
- size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
- if (ggml_is_quantized(k->type)) {
- enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
- size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
- size += q->ne[2]*row_size;
- }
+ if (k->ne[2] > 1) {
+ int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
+ int nstep_k = k->ne[2]*k->ne[1]/nk;
+ size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
+ size_t size = nstep_k*result_size;
cur = MAX(cur, size);
+ } else {
+ int nstep_k = k->ne[1]/32;
+ int gcd_k = simple_gcd(nstep_k, n_tasks);
+ if (gcd_k > 1) {
+ int nth_k = n_tasks/gcd_k;
+ int rk2 = q->ne[2]/k->ne[2];
+ int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+ size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
+ if (ggml_is_quantized(k->type)) {
+ enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+ size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+ size += q->ne[2]*row_size;
+ }
+ cur = MAX(cur, size);
+ }
}
}
#endif
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index e302c944..3f1d6dc2 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -153,6 +153,67 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
}
}
+ if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
+ int nk = 32 * (nek2*nek1/(32*nth));
+ int nkk = (nek1 + nk - 1)/nk;
+ int nstep_k = nek2*nkk;
+ auto result_size = (Dv + 16)*rk2*sizeof(float);
+ //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
+ for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
+ int ik02 = istep_k/nkk;
+ int ik01 = nk*(istep_k - ik02*nkk);
+ int this_nk = ik01 + nk <= nek1 ? nk : nek1 - ik01;
+ if (this_nk <= 0) break;
+ auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
+ auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
+ auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
+ auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
+ auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
+ if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+ Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
+ this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
+ scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
+ }
+
+ barrier(barrier_data);
+
+ // We have nkk results for each head
+ for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
+ // ik02*rk2 + il = iq2 (il = 0...rk2-1) => ik02 = iq2/rk2, il = iq2%rk2;
+ int ik02 = iq2/rk2;
+ int il = iq2 - ik02*rk2;
+ auto Racc = qkv + iq2*nb1/sizeof(float);
+ std::memset(Racc, 0, Dv*sizeof(float));
+ float M = -INFINITY, S = 0;
+ for (int ikk = 0; ikk < nkk; ++ikk) {
+ int istep_k = ik02*nkk + ikk;
+ auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
+ const float * R = this_result + il*Dv;
+ const float * Mj = this_result + Dv*rk2;
+ const float * Sj = Mj + rk2;
+ if (Mj[il] == -INFINITY) continue;
+ if (Mj[il] > M) {
+ if (M == -INFINITY) {
+ std::memcpy(Racc, R, Dv*sizeof(float));
+ S = Sj[il];
+ } else {
+ float c = exp(M - Mj[il]);
+ S = c*S + Sj[il];
+ for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
+ }
+ M = Mj[il];
+ } else {
+ float c = exp(Mj[il] - M);
+ S += c*Sj[il];
+ for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
+ }
+ }
+ float norm = S > 0 ? 1/S : 1;
+ for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
+ }
+ return true;
+ }
+
// I keep changing my mind what is the best strategy to split the threads when processing
// multiple heads. This is my current thinking, the commented out code below was the previous.
int ntg = nth/simple_gcd(neq2*neq3, nth);
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 78270f5e..424a65af 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -451,6 +451,51 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
auto r3 = ne13 / ne03;
if (ne13 == 1 && Ny == 1 && r2 > 1) {
+ if (Nx >= 256 && Nx%32 == 0) {
+ int nx32 = Nx/32;
+ int nchunk = nx32*ne02;
+ if (r2 <= 8) {
+ MulMat mm;
+ if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
+ int nx64 = Nx/64;
+ int nchunk64 = nx64*ne02;
+ for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
+ int i02 = ichunk/nx64;
+ int ix = 64*(ichunk - i02*nx64);
+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
+ }
+ int ix0 = 64*nx64;
+ if (ix0 < Nx) {
+ nx32 -= 2*nx64;
+ nchunk = nx32*ne02;
+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ int i02 = ichunk/nx32;
+ int ix = ix0 + 32*(ichunk - i02*nx32);
+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
+ }
+ }
+ //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ // int i02 = ichunk/nx32;
+ // int ix = 32*(ichunk - i02*nx32);
+ // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
+ //}
+ return true;
+ }
+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ int i02 = ichunk/nx32;
+ int ix = ichunk - i02*nx32;
+ if (!iqk_mul_mat(32, r2, ne00,
+ typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA,
+ typeB, (const char *)B + i02*r2*nb12, nb12,
+ C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false;
+
+ }
+ return true;
+ }
+ //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
int gcd = simple_gcd(ne02, nth);
int counter = 0;
for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -17153,6 +17198,14 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
+ else if (nq1 >= 4) {
+ FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ }
+ else if (nq1 >= 2) {
+ FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ }
else {
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);