summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorslaren <2141330+slaren@users.noreply.github.com>2023-04-20 03:14:14 +0200
committerGitHub <noreply@github.com>2023-04-20 03:14:14 +0200
commit02d6988121510c067e06d498a273a351a888f5b9 (patch)
tree98c6204ad4f3db40bc49595bb7705e8bcd699e5d /ggml.c
parent834695fe3a3ed2a962e774c9615e3f7b41d360a8 (diff)
Improve cuBLAS performance by dequantizing on the GPU (#1065)
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c80
1 files changed, 55 insertions, 25 deletions
diff --git a/ggml.c b/ggml.c
index 431cdb9c..9a343085 100644
--- a/ggml.c
+++ b/ggml.c
@@ -150,23 +150,25 @@ inline static void* ggml_aligned_malloc(size_t size) {
#elif defined(GGML_USE_CUBLAS)
#include <cublas_v2.h>
#include <cuda_runtime.h>
-#define CUDA_CHECK(err) \
- do { \
- cudaError_t err_ = (err); \
- if (err_ != cudaSuccess) { \
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
- cudaGetErrorString(err_)); \
- exit(1); \
- } \
+#include "ggml-cuda.h"
+
+#define CUDA_CHECK(err) \
+ do { \
+ cudaError_t err_ = (err); \
+ if (err_ != cudaSuccess) { \
+ printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+ cudaGetErrorString(err_)); \
+ exit(1); \
+ } \
} while (0)
-#define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
- exit(1); \
- } \
+#define CUBLAS_CHECK(err) \
+ do { \
+ cublasStatus_t err_ = (err); \
+ if (err_ != CUBLAS_STATUS_SUCCESS) { \
+ printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+ exit(1); \
+ } \
} while (0)
static cublasHandle_t cublasH = NULL;
@@ -177,6 +179,7 @@ static void init_cublas(void) {
CUBLAS_CHECK(cublasCreate(&cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
// configure logging to stdout
@@ -7311,7 +7314,6 @@ static void ggml_compute_forward_mul_mat_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7323,6 +7325,7 @@ static void ggml_compute_forward_mul_mat_f32(
}
}
#if defined(GGML_USE_CUBLAS)
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
@@ -7535,7 +7538,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7553,6 +7555,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
}
#if defined(GGML_USE_CUBLAS)
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
@@ -7722,13 +7725,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
return;
}
- float * const wdata = params->wdata;
- dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
+ float *d_Q = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
@@ -7738,10 +7739,41 @@ static void ggml_compute_forward_mul_mat_q_f32(
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+ CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
+ if (type == GGML_TYPE_Q4_0) {
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+ }
+ else if (type == GGML_TYPE_Q4_1) {
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+ }
+ else if (type == GGML_TYPE_Q4_2) {
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+#else
+ float * const wdata = params->wdata;
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+ // copy and dequantize on device
+ CUDA_CHECK(
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+ CUDA_CHECK(cudaGetLastError());
+#else
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7749,15 +7781,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
id += ne00;
}
}
-
const float * x = wdata;
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+#endif
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
@@ -7770,7 +7799,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7783,9 +7811,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
#if defined(GGML_USE_CUBLAS)
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
+ CUDA_CHECK(cudaFree(d_Q));
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);