summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu45
1 files changed, 39 insertions, 6 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index eb244f40..5a2701cf 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -227,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_row_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_row_q4_1_cuda;
+ case GGML_TYPE_Q4_2:
+ return dequantize_row_q4_2_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_row_q5_0_cuda;
+ case GGML_TYPE_Q5_1:
+ return dequantize_row_q5_1_cuda;
+ case GGML_TYPE_Q8_0:
+ return dequantize_row_q8_0_cuda;
+ default:
+ return nullptr;
+ }
+}
+
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 16
@@ -286,18 +305,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
CUDA_CHECK(cudaFree(ptr));
}
-cublasHandle_t g_cublasH = NULL;
-cudaStream_t g_cudaStream = NULL;
+cublasHandle_t g_cublasH = nullptr;
+cudaStream_t g_cudaStream = nullptr;
+cudaStream_t g_cudaStream2 = nullptr;
+cudaEvent_t g_cudaEvent = nullptr;
-void ggml_init_cublas(void) {
- if (g_cublasH == NULL) {
+void ggml_init_cublas() {
+ if (g_cublasH == nullptr) {
// create cublas handle, bind a stream
CUBLAS_CHECK(cublasCreate(&g_cublasH));
-
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
-
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+ // create additional stream and event for synchronization
+ CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
+ CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
+
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
}
@@ -330,3 +353,13 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
return cudaSuccess;
}
}
+
+void * ggml_cuda_host_malloc(size_t size) {
+ void * ptr;
+ CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
+ return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+ CUDA_CHECK(cudaFreeHost(ptr));
+}