diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 287 |
1 files changed, 264 insertions, 23 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8a3beb0e..b6a7754d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -32,9 +32,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } \ } while (0) +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); +typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization #define QK4_0 32 +#define QR4_0 2 typedef struct { float d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants @@ -42,6 +48,7 @@ typedef struct { static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 +#define QR4_1 2 typedef struct { float d; // delta float m; // min @@ -50,6 +57,7 @@ typedef struct { static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 +#define QR5_0 2 typedef struct { half d; // delta uint8_t qh[4]; // 5-th bit of quants @@ -58,6 +66,7 @@ typedef struct { static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); #define QK5_1 32 +#define QR5_1 2 typedef struct { half d; // delta half m; // min @@ -67,12 +76,100 @@ typedef struct { static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 +#define QR8_0 1 typedef struct { float d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +#define CUDA_DMMV_BLOCK_SIZE 32 + +static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const float d = x[ib].d; + + const uint8_t vui = x[ib].qs[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + v0 = (vi0 - 8)*d; + v1 = (vi1 - 8)*d; +} + +static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const float d = x[ib].d; + const float m = x[ib].m; + + const uint8_t vui = x[ib].qs[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + v0 = vi0*d + m; + v1 = vi1*d + m; +} + +static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const float d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16; + const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16; + + v0 = x0*d; + v1 = x1*d; +} + +static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const float d = x[ib].d; + const float m = x[ib].m; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); + const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1); + + v0 = x0*d + m; + v1 = x1*d + m; +} + +static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const float d = x[ib].d; + + const int8_t vi0 = x[ib].qs[iqs + 0]; + const int8_t vi1 = x[ib].qs[iqs + 1]; + + v0 = vi0*d; + v1 = vi1*d; +} + +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const half * x = (const half *) vx; + + v0 = __half2float(x[ib + 0]); + v1 = __half2float(x[ib + 1]); +} + static __global__ void dequantize_block_q4_0(const void * vx, float * y) { static const int qk = QK4_0; @@ -173,6 +270,44 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) { } } +template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel> +static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + + const int y_offset = qr == 1 ? 1 : qk/2; + + __shared__ float tmp[block_size]; // separate sum for each thread + tmp[tid] = 0; + + for (int i = 0; i < ncols/block_size; i += 2) { + const int col = i*block_size + 2*tid; + const int ib = (row*ncols + col)/qk; // block index + const int iqs = (col%qk)/qr; // quant index + const int iybs = col - col%qk; // y block start index + + // dequantize + float v0, v1; + dequantize_kernel(vx, ib, iqs, v0, v1); + + // matrix multiplication + tmp[tid] += v0 * y[iybs + iqs + 0]; + tmp[tid] += v1 * y[iybs + iqs + y_offset]; + } + + // sum up partial sums and write back result + __syncthreads(); + for (int s=block_size/2; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + dst[row] = tmp[0]; + } +} + static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y); @@ -198,6 +333,36 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y); } +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0> + <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1> + <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0> + <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1> + <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0> + <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); +} + // TODO: optimize static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { const half * x = (const half *) vx; @@ -211,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y); } +static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16> + <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); +} + static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -230,8 +401,27 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } +static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_mul_mat_vec_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_mul_mat_vec_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_mul_mat_vec_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_mul_mat_vec_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_mul_mat_vec_q8_0_cuda; + case GGML_TYPE_F16: + return dequantize_mul_mat_vec_q8_0_cuda; + default: + return nullptr; + } +} + // buffer pool for cuda -#define MAX_CUDA_BUFFERS 16 +#define MAX_CUDA_BUFFERS 256 struct scoped_spin_lock { std::atomic_flag& lock; @@ -528,6 +718,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; const ggml_type type = src0->type; + const bool mul_mat_vec = ne11 == 1; const float alpha = 1.0f; const float beta = 0.0f; @@ -538,12 +729,16 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type); size_t x_size, y_size, d_size, q_size; - float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_X = nullptr; + if (!mul_mat_vec) { + d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + } float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type); + dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type); GGML_ASSERT(to_fp32_cuda != nullptr); for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -553,31 +748,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS]; cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS]; - float * c_X = d_X + i * x_ne; float * c_Y = d_Y + i * y_ne; float * c_D = d_D + i * d_ne; char * c_Q = d_Q + i * q_sz; - // copy src0 and convert to fp32 on device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); - to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + // copy src0 to device if necessary + if (src0->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); + } else if (src0->backend == GGML_BACKEND_CUDA) { + c_Q = ((char *) src0->data) + i * q_sz; + } else { + GGML_ASSERT(false); + } + if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); - // copy src1 to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + // copy src1 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); - // wait for conversion - CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + // wait for data + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); - // compute - CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); - CUBLAS_CHECK( - cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha, c_X, ne00, - c_Y, ne10, - &beta, c_D, ne01)); + // compute + dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); + CUDA_CHECK(cudaGetLastError()); + + } else { // general dequantization kernel + cuBLAS matrix matrix multiplication + float * c_X = d_X + i * x_ne; + + // convert src0 to fp32 on device + to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // copy src1 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // wait for conversion + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); + } // copy dst to host float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); @@ -586,7 +804,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor } CUDA_CHECK(cudaDeviceSynchronize()); - ggml_cuda_pool_free(d_X, x_size); + if (!mul_mat_vec) { + ggml_cuda_pool_free(d_X, x_size); + } ggml_cuda_pool_free(d_Y, y_size); ggml_cuda_pool_free(d_D, d_size); ggml_cuda_pool_free(d_Q, q_size); @@ -602,8 +822,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - + ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) { return true; } @@ -655,3 +874,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct return 0; } } + +void ggml_cuda_transform_tensor(ggml_tensor * tensor) { + const int64_t ne0 = tensor->ne[0]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne2 = tensor->ne[2]; + const int64_t ne3 = tensor->ne[3]; + + const ggml_type type = tensor->type; + const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); + + size_t q_size; + char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); + + cudaStream_t cudaStream2 = g_cudaStreams2[0]; + + // copy tensor to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); + CUDA_CHECK(cudaDeviceSynchronize()); + + tensor->data = d_Q; + tensor->backend = GGML_BACKEND_CUDA; +} |