diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 45 |
1 files changed, 39 insertions, 6 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index eb244f40..5a2701cf 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -227,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y); } +dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q4_2: + return dequantize_row_q4_2_cuda; + case GGML_TYPE_Q5_0: + return dequantize_row_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_row_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_row_q8_0_cuda; + default: + return nullptr; + } +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 16 @@ -286,18 +305,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } -cublasHandle_t g_cublasH = NULL; -cudaStream_t g_cudaStream = NULL; +cublasHandle_t g_cublasH = nullptr; +cudaStream_t g_cudaStream = nullptr; +cudaStream_t g_cudaStream2 = nullptr; +cudaEvent_t g_cudaEvent = nullptr; -void ggml_init_cublas(void) { - if (g_cublasH == NULL) { +void ggml_init_cublas() { + if (g_cublasH == nullptr) { // create cublas handle, bind a stream CUBLAS_CHECK(cublasCreate(&g_cublasH)); - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking)); - CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream)); + // create additional stream and event for synchronization + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking)); + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming)); + // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); } @@ -330,3 +353,13 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, return cudaSuccess; } } + +void * ggml_cuda_host_malloc(size_t size) { + void * ptr; + CUDA_CHECK(cudaMallocHost((void **) &ptr, size)); + return ptr; +} + +void ggml_cuda_host_free(void * ptr) { + CUDA_CHECK(cudaFreeHost(ptr)); +} |