| field     | value                                                    | date                      |
|-----------|----------------------------------------------------------|---------------------------|
| author    | Iwan Kawrakow <iwan.kawrakow@gmail.com>                  | 2024-07-30 09:14:18 +0300 |
| committer | Kawrakow <48489457+ikawrakow@users.noreply.github.com>   | 2024-08-01 09:38:06 +0200 |
| commit    | f6813cac0e79837830967b4d89d3e744a0ddc6fe (patch)         |                           |
| tree      | 279e5204a9e43a7ad326c3cd14cdb7bc54258801 /ggml/src       |                           |
| parent    | 22d1568c1c3fa03433c65bb0abae4e731321bd31 (diff)          |                           |
Factor out iqk CUDA dot products
I cannot possibly wait for a 5-minute nvcc compilation
each time I touch vecdotq.cuh.
Also, cmake was adding --options-file X.rsp to the nvcc
compile commands, which confuses clangd, so I have turned
that off.
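
For orientation, here is a minimal host-side sketch of what the split buys: the heavy IQ*_K dot-product templates now compile in their own translation unit (iqk_mmvq.cu), and mmvq.cu only needs the declarations from iqk_mmvq.cuh. The entry-point names and signatures below come from the new header in the diff; the wrapper mul_mat_vec_q_switch_iqk is a hypothetical helper for illustration, not code from this commit.

```cpp
// Sketch only: forward the IQ*_K quant types to the entry points declared in
// iqk_mmvq.cuh, so touching vecdotq.cuh no longer rebuilds these kernels.
#include "iqk_mmvq.cuh"   // declares mul_mat_vec_iq{2,4,5}_k_q8_1_cuda (see diff below)

static void mul_mat_vec_q_switch_iqk(   // hypothetical helper name
        ggml_type type, const void * vx, const void * vy, float * dst,
        int ncols_x, int nrows_x, int nrows_y, int ncols_y, int nrows_dst,
        cudaStream_t stream) {
    switch (type) {
        case GGML_TYPE_IQ2_K:
            mul_mat_vec_iq2_k_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_K:
            mul_mat_vec_iq4_k_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ5_K:
            mul_mat_vec_iq5_k_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
```
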
Diffstat (limited to 'ggml/src')
| -rw-r--r-- | ggml/src/CMakeLists.txt          |   4 |
| -rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu   | 346 |
| -rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh  |  14 |
| -rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu       |  28 |
| -rw-r--r-- | ggml/src/ggml-cuda/vecdotq.cuh   | 142 |
5 files changed, 365 insertions, 169 deletions
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 9888313d..0b1ad48a 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -259,6 +259,10 @@ if (GGML_CUDA) find_package(CUDAToolkit) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0) + if (CUDAToolkit_FOUND) message(STATUS "CUDA found") diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu new file mode 100644 index 00000000..d33f4804 --- /dev/null +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -0,0 +1,346 @@ +//#include "common.cuh" +#include "iqk_mmvq.cuh" + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); + +namespace { +template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y> +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__global__ void iqk_mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int qi = ggml_cuda_type_traits<type>::qi; + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int nwarps = 1; + constexpr int rows_per_cuda_block = 1; +#else + constexpr int nwarps = ncols_y <= 4 ? 4 : 2; + constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; + +// partial sum for each thread + float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; + + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx + + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? 
nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE]; + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + // sum up partial sums and write back result +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + } + tmp[j][i] = warp_reduce_sum(tmp[j][i]); + } + + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + } + } +} + +template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> +void iqk_mul_mat_vec_q_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); + //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); + + int id = ggml_cuda_get_device(); + + int64_t nwarps = 1; + int64_t rows_per_cuda_block = 1; + + if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + switch(ncols_y) { + case 1: + nwarps = 4; + rows_per_cuda_block = 1; + break; + case 2: + case 3: + case 4: + nwarps = 4; + rows_per_cuda_block = 2; + break; + case 5: + case 6: + case 7: + case 8: + nwarps = 2; + rows_per_cuda_block = 2; + break; + default: + GGML_ASSERT(false); + break; + } + } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; + const dim3 block_nums(nblocks, 1, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + switch (ncols_y) { + case 1: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 2: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 3: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 4: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 5: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 6: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 7: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 8: + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + default: + GGML_ASSERT(false); + break; + } +} + +__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + const uint8_t * values = all_values + 16*(shift & 1); + uint16_t v1 = 
values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + values = all_values + 8*(shift & 2); + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +#define VDR_IQ4_K_Q8_1_MMVQ 4 +#define VDR_IQ4_K_Q8_1_MMQ 4 + +__device__ __forceinline__ float vec_dot_iq4_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq4k_values; + + // iqs is 0...28 + const int ib32 = iqs/4; + // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; + const uint16_t extra = bq4->extra >> 2*ib32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); + get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); + sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); + } + const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); + const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); + const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; + return d * (sumi1 * ls1 + sumi2 * ls2); +} + +#define VDR_IQ5_K_Q8_1_MMVQ 4 +#define VDR_IQ5_K_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq5_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq5nl_values; + + // iqs is 0...28 + const int il = iqs/2; // 0...14 + const int is = iqs%2; // 0 or 1 + const int ib32 = 2*(il/4); // 0, 2, 4, 6 + const int32_t * q8_1 = (const int *)bq8_1[ib32+0].qs + 4*is; + const int32_t * q8_2 = (const int *)bq8_1[ib32+1].qs + 4*is; + const uint32_t * q4 = (const uint32_t *)bq5->qs + 8*(ib32/2) + 4*is; + const uint32_t * qh = (const uint32_t *)bq5->qh + 4*is; + const uint16_t extra = bq5->extra >> (2*ib32 + is); + const uint8_t * values1 = all_values + 32*(extra & 1); + const uint8_t * values2 = all_values + 8*(extra & 4); + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)aux32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + uint32_t h = qh[j] >> ib32; + aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24); + v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24); + sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2); + } + // Blocks of 16: 2*ib32 + is, 2*ib32 + is + 2 + const float d5 = __half2float(bq5->d); + const uint8_t sh = bq5->scales_h[ib32/2] >> 2*(is%2); + const int ls1 = (((bq5->scales_l[ib32+0] >> 4*is) & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = (((bq5->scales_l[ib32+1] >> 4*is) & 0xf) | ((sh << 0) & 0x30)) - 32; + return d5 * (__low2float(bq8_1[ib32+0].ds) * sumi1 * ls1 + __low2float(bq8_1[ib32+1].ds) * sumi2 * ls2); +} + +//#define VDR_IQ5_K_Q8_1_MMVQ 2 +//#define VDR_IQ5_K_Q8_1_MMQ 8 +// +//static __device__ __forceinline__ float 
vec_dot_iq5_k_q8_1( +// const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { +// +// const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx; +// +// // iqs is 0...28 +// // iqs = 0...7 -> bq8_offset = 0, iqs = 8...15 -> bq8_offset = 2, iqs = 16...23 -> bq8_offset = 4, iqs = 24...31 -> bq8_offset = 6 +// // bq8_offset = 0 -> 0...3, bq8_offset = 2 -> 8...11, bq8_offset = 4 -> 16...19, bq8_offset = 6 -> 24...27 +// const int bq8_offset = 2*((iqs/2)/4); +// const int32_t * q8_1 = (const int *)bq8_1[bq8_offset+0].qs; +// const int32_t * q8_2 = (const int *)bq8_1[bq8_offset+1].qs; +// const uint32_t * q4 = (const uint32_t *)bq5->qs + 4*bq8_offset + ((iqs/2)%4); +// const uint32_t * qh = (const uint32_t *)bq5->qh + ((iqs/2)%4); +// const uint16_t extra = bq5->extra >> 2*bq8_offset; +// const float d5 = __half2float(bq5->d); +// +// const uint8_t * values1; +// const uint8_t * values2; +// uint32_t indx[2]; +// const uint8_t * a8 = (const uint8_t *)indx; +// int v1, v2; +// +// indx[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | (((qh[0] >> (bq8_offset+0)) << 4) & 0x10101010); +// indx[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | (((qh[0] >> (bq8_offset+1)) << 4) & 0x10101010); +// values1 = (const uint8_t *)iq5nl_values + 32*(extra & 1); +// values2 = (const uint8_t *)iq5nl_values + 8*(extra & 4); +// v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24); +// v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24); +// int s1 = ggml_cuda_dp4a(v1, q8_1[0], 0) * (((bq5->scales_l[bq8_offset+0] & 0xf) | ((bq5->scales_h[bq8_offset/2] << 4) & 0x30)) - 32); +// int s2 = ggml_cuda_dp4a(v2, q8_2[0], 0) * (((bq5->scales_l[bq8_offset+1] & 0xf) | ((bq5->scales_h[bq8_offset/2] >> 0) & 0x30)) - 32); +// +// indx[0] = ((q4[4] >> 0) & 0x0f0f0f0f) | (((qh[4] >> (bq8_offset+0)) << 4) & 0x10101010); +// indx[1] = ((q4[4] >> 4) & 0x0f0f0f0f) | (((qh[4] >> (bq8_offset+1)) << 4) & 0x10101010); +// values1 = (const uint8_t *)iq5nl_values + 16*(extra & 2); +// values2 = (const uint8_t *)iq5nl_values + 4*(extra & 8); +// v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24); +// v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24); +// int s3 = ggml_cuda_dp4a(v1, q8_1[4], 0) * (((bq5->scales_l[bq8_offset+0] >> 4) | ((bq5->scales_h[bq8_offset/2] << 2) & 0x30)) - 32); +// int s4 = ggml_cuda_dp4a(v2, q8_2[4], 0) * (((bq5->scales_l[bq8_offset+1] >> 4) | ((bq5->scales_h[bq8_offset/2] >> 2) & 0x30)) - 32); +// +// return d5*(__low2float(bq8_1[bq8_offset+0].ds) * (s1 + s3) + __low2float(bq8_1[bq8_offset+1].ds) * (s2 + s4)); +// +//} + +#define VDR_IQ2_K_Q8_1_MMVQ 4 +#define VDR_IQ2_K_Q8_1_MMQ 4 + +// TODO +static __device__ __forceinline__ float vec_dot_iq2_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + return 0; +// +// const block_iq2_k * bq4 = (const block_iq2_k *) vbq + kbx; +// const uint8_t * all_values = (const uint8_t *)iq4k_values; +// +// // iqs is 0...28 +// const int ib32 = iqs/4; +// // Why iqs/4 ? 
+// const int32_t * q8 = (const int *)bq8_1[ib32].qs; +// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; +// const uint16_t extra = bq4->extra >> 2*ib32; +// int v1, v2; +// int sumi1 = 0, sumi2 = 0; +// for (int j = 0; j < 4; ++j) { +// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); +// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); +// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); +// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); +// } +// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); +// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); +// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; +// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; +// return d * (sumi1 * ls1 + sumi2 * ls2); +} + +} + +void mul_mat_vec_iq2_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq4_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + +void mul_mat_vec_iq5_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh new file mode 100644 index 00000000..14e5c1c7 --- /dev/null +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -0,0 +1,14 @@ +#include "common.cuh" + +void mul_mat_vec_iq2_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq4_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + +void mul_mat_vec_iq5_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 776ca80f..93c8ac29 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,5 +1,6 @@ #include "mmvq.cuh" #include "vecdotq.cuh" +#include "iqk_mmvq.cuh" typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); @@ -24,9 +25,6 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 : type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : - type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 : - type == GGML_TYPE_IQ5_K ? vec_dot_iq5_k_q8_1 : - type == GGML_TYPE_IQ2_K ? vec_dot_iq2_k_q8_1 : type == GGML_TYPE_IQ3_S ? 
vec_dot_iq3_s_q8_1 : nullptr; } @@ -49,9 +47,6 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ5_K ? VDR_IQ5_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_K ? VDR_IQ2_K_Q8_1_MMVQ : 1; } @@ -349,27 +344,6 @@ static void mul_mat_vec_iq4_xs_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } -static void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6690b19a..b1b465a3 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1247,150 +1247,8 @@ static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_ #define VDR_IQ4_K_Q8_1_MMVQ 4 #define VDR_IQ4_K_Q8_1_MMQ 4 -static __device__ __forceinline__ float vec_dot_iq4_k_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - - const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx; - const uint8_t * all_values = (const uint8_t *)iq4k_values; - - // iqs is 0...28 - const int ib32 = iqs/4; - // Why iqs/4 ? 
- const int32_t * q8 = (const int *)bq8_1[ib32].qs; - const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; - const uint16_t extra = bq4->extra >> 2*ib32; - int v1, v2; - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 4; ++j) { - const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); - get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); - sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); - sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); - } - const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); - const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); - const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; - return d * (sumi1 * ls1 + sumi2 * ls2); -} - #define VDR_IQ5_K_Q8_1_MMVQ 4 #define VDR_IQ5_K_Q8_1_MMQ 4 -static __device__ __forceinline__ float vec_dot_iq5_k_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - - const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx; - const uint8_t * all_values = (const uint8_t *)iq5nl_values; - - // iqs is 0...28 - const int il = iqs/2; // 0...14 - const int is = iqs%2; // 0 or 1 - const int ib32 = 2*(il/4); // 0, 2, 4, 6 - const int32_t * q8_1 = (const int *)bq8_1[ib32+0].qs + 4*is; - const int32_t * q8_2 = (const int *)bq8_1[ib32+1].qs + 4*is; - const uint32_t * q4 = (const uint32_t *)bq5->qs + 8*(ib32/2) + 4*is; - const uint32_t * qh = (const uint32_t *)bq5->qh + 4*is; - const uint16_t extra = bq5->extra >> (2*ib32 + is); - const uint8_t * values1 = all_values + 32*(extra & 1); - const uint8_t * values2 = all_values + 8*(extra & 4); - uint32_t aux32[2]; - const uint8_t * a8 = (const uint8_t *)aux32; - int v1, v2; - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 4; ++j) { - uint32_t h = qh[j] >> ib32; - aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); - aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); - v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24); - v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24); - sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1); - sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2); - } - // Blocks of 16: 2*ib32 + is, 2*ib32 + is + 2 - const float d5 = __half2float(bq5->d); - const uint8_t sh = bq5->scales_h[ib32/2] >> 2*(is%2); - const int ls1 = (((bq5->scales_l[ib32+0] >> 4*is) & 0xf) | ((sh << 4) & 0x30)) - 32; - const int ls2 = (((bq5->scales_l[ib32+1] >> 4*is) & 0xf) | ((sh << 0) & 0x30)) - 32; - return d5 * (__low2float(bq8_1[ib32+0].ds) * sumi1 * ls1 + __low2float(bq8_1[ib32+1].ds) * sumi2 * ls2); -} - -//#define VDR_IQ5_K_Q8_1_MMVQ 2 -//#define VDR_IQ5_K_Q8_1_MMQ 8 -// -//static __device__ __forceinline__ float vec_dot_iq5_k_q8_1( -// const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -// -// const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx; -// -// // iqs is 0...28 -// // iqs = 0...7 -> bq8_offset = 0, iqs = 8...15 -> bq8_offset = 2, iqs = 16...23 -> bq8_offset = 4, iqs = 24...31 -> bq8_offset = 6 -// // bq8_offset = 0 -> 0...3, bq8_offset = 2 -> 8...11, bq8_offset = 4 -> 16...19, bq8_offset = 6 -> 24...27 -// const int bq8_offset = 2*((iqs/2)/4); -// const int32_t * q8_1 = (const int *)bq8_1[bq8_offset+0].qs; -// const int32_t * q8_2 = (const int *)bq8_1[bq8_offset+1].qs; -// const uint32_t * q4 = (const uint32_t *)bq5->qs + 4*bq8_offset + ((iqs/2)%4); -// const 
uint32_t * qh = (const uint32_t *)bq5->qh + ((iqs/2)%4); -// const uint16_t extra = bq5->extra >> 2*bq8_offset; -// const float d5 = __half2float(bq5->d); -// -// const uint8_t * values1; -// const uint8_t * values2; -// uint32_t indx[2]; -// const uint8_t * a8 = (const uint8_t *)indx; -// int v1, v2; -// -// indx[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | (((qh[0] >> (bq8_offset+0)) << 4) & 0x10101010); -// indx[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | (((qh[0] >> (bq8_offset+1)) << 4) & 0x10101010); -// values1 = (const uint8_t *)iq5nl_values + 32*(extra & 1); -// values2 = (const uint8_t *)iq5nl_values + 8*(extra & 4); -// v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24); -// v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24); -// int s1 = ggml_cuda_dp4a(v1, q8_1[0], 0) * (((bq5->scales_l[bq8_offset+0] & 0xf) | ((bq5->scales_h[bq8_offset/2] << 4) & 0x30)) - 32); -// int s2 = ggml_cuda_dp4a(v2, q8_2[0], 0) * (((bq5->scales_l[bq8_offset+1] & 0xf) | ((bq5->scales_h[bq8_offset/2] >> 0) & 0x30)) - 32); -// -// indx[0] = ((q4[4] >> 0) & 0x0f0f0f0f) | (((qh[4] >> (bq8_offset+0)) << 4) & 0x10101010); -// indx[1] = ((q4[4] >> 4) & 0x0f0f0f0f) | (((qh[4] >> (bq8_offset+1)) << 4) & 0x10101010); -// values1 = (const uint8_t *)iq5nl_values + 16*(extra & 2); -// values2 = (const uint8_t *)iq5nl_values + 4*(extra & 8); -// v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24); -// v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24); -// int s3 = ggml_cuda_dp4a(v1, q8_1[4], 0) * (((bq5->scales_l[bq8_offset+0] >> 4) | ((bq5->scales_h[bq8_offset/2] << 2) & 0x30)) - 32); -// int s4 = ggml_cuda_dp4a(v2, q8_2[4], 0) * (((bq5->scales_l[bq8_offset+1] >> 4) | ((bq5->scales_h[bq8_offset/2] >> 2) & 0x30)) - 32); -// -// return d5*(__low2float(bq8_1[bq8_offset+0].ds) * (s1 + s3) + __low2float(bq8_1[bq8_offset+1].ds) * (s2 + s4)); -// -//} - #define VDR_IQ2_K_Q8_1_MMVQ 4 #define VDR_IQ2_K_Q8_1_MMQ 4 - -// TODO -static __device__ __forceinline__ float vec_dot_iq2_k_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - return 0; -// -// const block_iq2_k * bq4 = (const block_iq2_k *) vbq + kbx; -// const uint8_t * all_values = (const uint8_t *)iq4k_values; -// -// // iqs is 0...28 -// const int ib32 = iqs/4; -// // Why iqs/4 ? -// const int32_t * q8 = (const int *)bq8_1[ib32].qs; -// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; -// const uint16_t extra = bq4->extra >> 2*ib32; -// int v1, v2; -// int sumi1 = 0, sumi2 = 0; -// for (int j = 0; j < 4; ++j) { -// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); -// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); -// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); -// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); -// } -// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); -// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); -// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; -// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; -// return d * (sumi1 * ls1 + sumi2 * ls2); -} - |
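
The scale handling in vec_dot_iq4_k_q8_1 above packs each signed 6-bit block scale as a low nibble in scales_l plus two high bits in scales_h. A standalone host-side sketch of that bit manipulation, as a worked illustration (not part of the commit; decode_iq4_k_scales is a hypothetical helper, and the input bytes are made-up example data):

```cpp
#include <cstdint>
#include <cstdio>

// Reconstruct the two signed 6-bit scales for 32-value block ib32 the same way
// vec_dot_iq4_k_q8_1 does: low 4 bits from scales_l, high 2 bits from scales_h,
// then shift to the signed range [-32, 31].
static void decode_iq4_k_scales(const uint8_t * scales_l, const uint8_t * scales_h,
                                int ib32, int * ls1, int * ls2) {
    const uint8_t sh = scales_h[ib32/2] >> 4*(ib32%2);
    *ls1 = ((scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
    *ls2 = ((scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
}

int main() {
    const uint8_t scales_l[8] = {0x3f};  // example: low nibbles 0xf and 0x3 for block 0
    const uint8_t scales_h[4] = {0x0f};  // example: high bit pairs for blocks 0 and 1
    int ls1, ls2;
    decode_iq4_k_scales(scales_l, scales_h, 0, &ls1, &ls2);
    printf("ls1 = %d, ls2 = %d\n", ls1, ls2);  // prints ls1 = 31, ls2 = 19
    return 0;
}
```
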