Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  92
1 file changed, 49 insertions(+), 43 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index fe7332b2..dbe53cee 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -61,7 +61,7 @@
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#else
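
Note: under GGML_USE_HIPBLAS the CUDA runtime names above are remapped to HIP. The old wrapper took only two arguments and hard-coded the flags to 0, so it could not satisfy call sites written against the three-argument CUDA signature; the new form forwards the flags verbatim. A minimal sketch of the remap, assuming the surrounding #if defined(GGML_USE_HIPBLAS) block from this file (the real alias table is much longer):

    // Hedged sketch, not the full alias table from ggml-cuda.cu: just
    // enough of the remap to show why the flags parameter is forwarded.
    #if defined(GGML_USE_HIPBLAS)
    #include <hip/hip_runtime.h>
    #define cudaStream_t hipStream_t
    #define cudaEvent_t  hipEvent_t
    #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
    #else
    #include <cuda_runtime.h>
    #endif

    // The same source line now compiles against either backend;
    // 0 selects the default wait semantics (error checking elided).
    static void wait_on(cudaStream_t stream, cudaEvent_t event) {
        cudaStreamWaitEvent(stream, event, 0);
    }
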
@@ -190,6 +190,12 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} while (0)
#endif // CUDART_VERSION >= 11
+#if CUDART_VERSION >= 11100
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
#ifdef GGML_CUDA_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
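
Note: GGML_CUDA_ASSUME wraps __builtin_assume behind a CUDART_VERSION >= 11100 (CUDA 11.1) guard; on older toolkits the macro expands to nothing, so the optimizer hint is dropped instead of breaking the build. A hypothetical kernel showing what the hint buys (names and values here are illustrative, not from ggml):

    // Hypothetical example: telling the compiler the index range lets it
    // simplify arithmetic it would otherwise have to compute at runtime.
    __global__ void gather32(const int * in, int * out, int k) {
        GGML_CUDA_ASSUME(k >= 0);
        GGML_CUDA_ASSUME(k < 32);
        // With 0 <= k < 32 known, k % 32 == k, so the compiler may drop
        // the modulo entirely instead of emitting a divide sequence.
        out[threadIdx.x] = in[k % 32];
    }

Violating an assumption at runtime is undefined behavior, which is why the bounds asserted in the loaders below mirror the actual launch configuration.
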
@@ -2145,10 +2151,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI4_0;
const int kqsx = k % QI4_0;
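
Note: the same four-line substitution repeats in the nine loaders below, with only the per-type divisor changing (QI4_1, QI5_0, QI5_1, QI8_0, QI2_K, QI3_K, QI4_K, QI5_K, QI6_K). The bounds matter for the division and modulo that immediately follow each one. A hypothetical helper, with Q standing in for the per-type constant, sketches the strength reduction the hint enables:

    // Hypothetical helper; Q stands in for the per-type power-of-two
    // constant (QI4_0 etc.), not a value taken from ggml.
    template <int Q>
    __device__ __forceinline__ void split_index(const int k, int & kbx, int & kqsx) {
        GGML_CUDA_ASSUME(k >= 0);  // k known non-negative, so:
        kbx  = k / Q;              // signed divide lowers to a shift
        kqsx = k % Q;              // signed modulo lowers to a mask
    }

Without the non-negativity hint, signed division by a power of two needs an extra sign-correction sequence on the device.
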
@@ -2239,10 +2245,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI4_1;
const int kqsx = k % QI4_1;
@@ -2331,10 +2337,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI5_0;
const int kqsx = k % QI5_0;
@@ -2445,10 +2451,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI5_1;
const int kqsx = k % QI5_1;
@@ -2551,10 +2557,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI8_0;
const int kqsx = k % QI8_0;
@@ -2642,10 +2648,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI2_K;
const int kqsx = k % QI2_K;
@@ -2763,10 +2769,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI3_K;
const int kqsx = k % QI3_K;
@@ -2981,10 +2987,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI4_K; // == 0 if QK_K == 256
const int kqsx = k % QI4_K; // == k if QK_K == 256
@@ -3162,10 +3168,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI5_K; // == 0 if QK_K == 256
const int kqsx = k % QI5_K; // == k if QK_K == 256
@@ -3291,10 +3297,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
- __builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < nwarps);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
+ GGML_CUDA_ASSUME(i_offset >= 0);
+ GGML_CUDA_ASSUME(i_offset < nwarps);
+ GGML_CUDA_ASSUME(k >= 0);
+ GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI6_K; // == 0 if QK_K == 256
const int kqsx = k % QI6_K; // == k if QK_K == 256
@@ -6408,7 +6414,7 @@ static void ggml_cuda_op_mul_mat(
// wait for main GPU data if necessary
if (split && (id != g_main_device || is != 0)) {
- CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
+ CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
}
for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
@@ -6530,7 +6536,7 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
for (int64_t id = 0; id < g_device_count; ++id) {
for (int64_t is = 0; is < is_max; ++is) {
- CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
+ CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
}
}
}
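
Note: both call sites above now spell out the third argument (0, the default wait) to match the three-parameter macro on the HIP path. The pattern they implement is a record-then-wait handshake across streams. A hedged sketch, reusing this file's CUDA_CHECK macro and assuming a single producer/consumer pair; in ggml_cuda_op_mul_mat the producer is the main GPU's stream and the events live in src0_extra->events:

    // Hedged sketch of the event handshake between two streams.
    static void handshake(cudaStream_t producer, cudaStream_t consumer) {
        cudaEvent_t ev;
        CUDA_CHECK(cudaEventCreateWithFlags(&ev, cudaEventDisableTiming));
        // ... enqueue the producing copy/kernel on `producer` ...
        CUDA_CHECK(cudaEventRecord(ev, producer));        // mark completion
        CUDA_CHECK(cudaStreamWaitEvent(consumer, ev, 0)); // 0 = default wait
        // anything enqueued on `consumer` from here on runs after ev fires
        CUDA_CHECK(cudaEventDestroy(ev));
    }

cudaStreamWaitEvent only orders work enqueued after the call, which is why the loops above wait once per stream before queuing the dependent copies.
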