-rw-r--r-- | ggml/src/ggml-cuda/mmq.cuh | 182
1 file changed, 115 insertions(+), 67 deletions(-)
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 0a1d779e..6def49ef 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -861,61 +861,60 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
-        }
-
-#pragma unroll
+    #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+        mma_B B;
+        float dB[mma_C::ne/2];
+        B.load(y_qs + k01, MMQ_TILE_Y_K);
+        #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
-
-#pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-                const int k0 = k00 + k01;
-
-                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            const int j = mma_C::get_j(l);
+            if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B B;
-            float dB[mma_C::ne/2];
-
+        #pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+                dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+            mma_C C;
+            C.mma_K8(A[n], B);
+            #pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
+        }
+        #pragma unroll
+        for (int j0 = ntx*mma_C::J; j0 < mmq_x; j0 += ntx*mma_C::J) {
             B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
+            #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
-
-                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
-                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
                 } else {
                     dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
                 }
             }
-
-#pragma unroll
+            #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
-
-#pragma unroll
+                C.mma_K8(A[n], B);
+                #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
                 }
             }
         }
@@ -2701,6 +2700,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+//template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+//    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+//
+//#ifdef INT8_MMA_AVAILABLE
+//    int * x_qs = (int *) x_tile;
+//    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+//#else
+//    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+//    int * x_qs = (int *) x_tile;
+//    float * x_df = (float *) (x_qs + txs.qs);
+//#endif // INT8_MMA_AVAILABLE
+//
+//    const int kbx = 0; // threadIdx.x / QI4_XS
+//    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+//
+//#pragma unroll
+//    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+//        int i = i0 + threadIdx.y;
+//
+//        if (need_check) {
+//            i = min(i, i_max);
+//        }
+//
+//        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+//
+//        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+//        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+//        const int2 v = get_int_from_table_16(aux_q4, values);
+//        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+//#ifdef INT8_MMA_AVAILABLE
+//        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+//        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+//#else
+//        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+//        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+//#endif // INT8_MMA_AVAILABLE
+//    }
+//
+//#pragma unroll
+//    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+//        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+//
+//        if (need_check) {
+//            i = min(i, i_max);
+//        }
+//
+//        const float * dptr = (const float *)(x + i*stride);
+//        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+//        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+//
+//#ifdef INT8_MMA_AVAILABLE
+//        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+//#else
+//        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+//#endif // INT8_MMA_AVAILABLE
+//    }
+//}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -2713,52 +2770,43 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx = 0; // threadIdx.x / QI4_XS
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+    const int kqsx = threadIdx.x / 4;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+        int i = i0 + 4*threadIdx.y + threadIdx.x%4;
 
         if (need_check) {
            i = min(i, i_max);
         }
 
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+        const float * dptr = (const float *)(x + i*stride);
+        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+        const int ls = (bxi->scales[kqsx] & 254) - 127;
+        auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
 
-        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
-        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4, values);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+        #pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+            const int2 v = get_int_from_table_16(aux_q4, values);
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
#else
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
 #endif // INT8_MMA_AVAILABLE
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const float * dptr = (const float *)(x + i*stride);
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
-        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
-
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
 #endif // INT8_MMA_AVAILABLE
     }
+    }
+
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
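
For orientation: the first hunk restructures vec_dot_q8_0_q8_1_mma so that the k01 slice loop becomes the outer loop and each mma_A fragment plus its scale is loaded per slice and reused for every j0 tile, replacing the per-k01 arrays A[ntx][WARP_SIZE/QI8_0] and dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]. The remaining hunks rewrite load_tiles_iq4_ks so that a warp loads four rows per iteration: threadIdx.x%4 selects the row and kqsx = threadIdx.x/4 selects the scale byte and the four packed int32 groups that lane decodes through the iq4k_values table. The snippet below is a minimal host-side C++ sketch of the iq4_ks scale decode as it reads off the new loader; the byte value and row scale are made-up example inputs and the snippet is not part of the patch.

    // Sketch only: decode one block_iq4_ks scale byte the way the new
    // load_tiles_iq4_ks does. Example inputs are hypothetical.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const float   d_row      = 0.0312f; // per-row fp32 scale stored ahead of the blocks (dptr[0] in the patch)
        const uint8_t scale_byte = 0xB5;    // one entry of block_iq4_ks::scales (hypothetical value)

        // Low bit selects which 16-entry half of the iq4k_values lookup table the block uses.
        const int table_offset = (scale_byte & 1) << 4;   // 0 or 16
        // Remaining seven bits hold the block scale with an offset of 127.
        const int ls           = (scale_byte & 254) - 127; // in [-127, 127]

        // The tile loader stores the dequantization factor for the block as d_row * ls.
        const float x_df = d_row * ls;

        printf("table offset: %d, ls: %d, x_df: %g\n", table_offset, ls, x_df);
        return 0;
    }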