path: root/ggml-cuda.cu
author     Georgi Gerganov <ggerganov@gmail.com>    2023-12-07 22:26:54 +0200
committer  GitHub <noreply@github.com>              2023-12-07 22:26:54 +0200
commit     fe680e3d1080a765e5d3150ffd7bab189742898d (patch)
tree       cd8be8bf5722d10596923aef7fb44bf8a58378d7 /ggml-cuda.cu
parent     bcc0eb4591bec5ec02fad3f2bdcb1b265052ea56 (diff)
sync : ggml (new ops, tests, backend, etc.) (#4359)
* sync : ggml (part 1)
* sync : ggml (part 2, CUDA)
* sync : ggml (part 3, Metal)
* ggml : build fixes

ggml-ci

* cuda : restore lost changes
* cuda : restore lost changes (StableLM rope)
* cmake : enable separable compilation for CUDA

ggml-ci

* ggml-cuda : remove device side dequantize
* Revert "cmake : enable separable compilation for CUDA"

This reverts commit 09e35d04b1c4ca67f9685690160b35bc885a89ac.

* cuda : remove assert for rope
* tests : add test-backend-ops
* ggml : fix bug in ggml_concat
* ggml : restore `ggml_get_n_tasks()` logic in `ggml_graph_plan()`
* ci : try to fix macOS
* ggml-backend : remove backend self-registration
* ci : disable Metal for macOS cmake build

ggml-ci

* metal : fix "supports family" call
* metal : fix assert
* metal : print resource path

ggml-ci

---------

Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--   ggml-cuda.cu   1317
1 file changed, 975 insertions, 342 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 1200d1c8..85f7a293 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,13 +1,13 @@
#include <algorithm>
-#include <cinttypes>
#include <cstddef>
#include <cstdint>
+#include <cinttypes>
+#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <assert.h>
-#include <float.h>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@@ -70,6 +70,7 @@
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
@@ -191,7 +192,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
fprintf(stderr, "current device: %d\n", id); \
- exit(1); \
+ GGML_ASSERT(!"CUDA error"); \
} \
} while (0)
@@ -205,7 +206,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
fprintf(stderr, "current device: %d\n", id); \
- exit(1); \
+ GGML_ASSERT(!"cuBLAS error"); \
} \
} while (0)
#else
@@ -217,7 +218,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
cudaGetDevice(&id); \
fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
fprintf(stderr, "current device: %d\n", id); \
- exit(1); \
+ GGML_ASSERT(!"cuBLAS error"); \
} \
} while (0)
#endif // CUDART_VERSION >= 11
@@ -434,8 +435,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
-#define CUDA_ADD_BLOCK_SIZE 256
-#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
@@ -528,40 +527,87 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+ return b;
+}
- if (i >= kx) {
- return;
- }
- dst[i] = x[i] + y[i%ky];
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+ return a + b;
}
-static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+ return a * b;
+}
- if (i >= k) {
- return;
- }
- dst[i] = __hadd(x[i], __float2half(y[i]));
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+ return a / b;
}
-static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s10,*/ int s11, int s12, int s13) {
+ const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+ const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+ const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+ const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
- if (i >= k) {
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
- dst[i] = __half2float(x[i]) + y[i];
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i_src0;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
}
-static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s10,*/ int s11, int s12, int s13) {
+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
- if (i >= kx) {
+ const int i3 = i/(ne2*ne1*ne0);
+ const int i2 = (i/(ne1*ne0)) % ne2;
+ const int i1 = (i/ne0) % ne1;
+ const int i0 = i % ne0;
+
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
- dst[i] = x[i] * y[i%ky];
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i_src0;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
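
The broadcast mapping in the two kernels above boils down to indexing src1 with each coordinate taken modulo the corresponding src1 extent. A minimal host-side C++ sketch of the same index arithmetic (the example shapes are made up and not taken from this commit):

#include <cstdio>

// dst[i1][i0] = op(src0[i1][i0], src1[i1 % ne11][i0 % ne10])
static float op_add_ref(float a, float b) { return a + b; }

int main() {
    // example: src0/dst is 4x3, src1 is 1x3 (broadcast along dim 1)
    const int ne0 = 3, ne1 = 4;     // dst/src0 extents
    const int ne10 = 3, ne11 = 1;   // src1 extents
    float src0[ne1][ne0] = {{1,2,3},{4,5,6},{7,8,9},{10,11,12}};
    float src1[ne11][ne10] = {{10,20,30}};
    float dst[ne1][ne0];

    for (int i1 = 0; i1 < ne1; i1++) {
        for (int i0 = 0; i0 < ne0; i0++) {
            const int i11 = i1 % ne11;   // wrap into src1's extent (broadcast)
            const int i10 = i0 % ne10;
            dst[i1][i0] = op_add_ref(src0[i1][i0], src1[i11][i10]);
        }
    }
    printf("dst[3][2] = %.1f\n", dst[3][2]); // 12 + 30 = 42.0
    return 0;
}

The kernels do the same thing, but split the loop across threads and precompute per-row base offsets from the element strides s1..s13.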
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
@@ -605,12 +651,10 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
}
template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- const float eps = 1e-5f;
-
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
@@ -4824,6 +4868,65 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
dst[i] = col * m_k + x[i];
}
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockIdx.y;
+ const int col = threadIdx.x;
+
+ float sum = 0.0f;
+ for (int i = col; i < ncols; i += blockDim.x) {
+ sum += x[row * ncols + i];
+ }
+
+ sum = warp_reduce_sum(sum);
+
+ if (col == 0) {
+ dst[row] = sum;
+ }
+}
+
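
k_sum_rows_f32 assigns one warp per row: each lane accumulates a strided slice of the row and warp_reduce_sum combines the partial sums. A serial C++ equivalent of that pattern (illustrative sketch only, not part of the commit):

#include <cstdio>

int main() {
    const int WARP = 32, ncols = 100;
    float row[ncols];
    for (int i = 0; i < ncols; i++) row[i] = 1.0f;   // row of ones -> sum should be 100

    float partial[WARP] = {0.0f};
    for (int lane = 0; lane < WARP; lane++) {
        for (int i = lane; i < ncols; i += WARP) {   // same stride pattern as the kernel
            partial[lane] += row[i];
        }
    }

    float sum = 0.0f;                                // stands in for warp_reduce_sum
    for (int lane = 0; lane < WARP; lane++) sum += partial[lane];

    printf("sum = %.1f\n", sum);                     // 100.0
    return 0;
}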
+template<typename T>
+static inline __device__ void swap(T & a, T & b) {
+ T tmp = a;
+ a = b;
+ b = tmp;
+}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+ // bitonic sort
+ int col = threadIdx.x;
+ int row = blockIdx.y;
+
+ if (col >= ncols) return;
+
+ const float * x_row = x + row * ncols;
+ int * dst_row = dst + row * ncols;
+
+ // initialize indices
+ if (col < ncols) {
+ dst_row[col] = col;
+ }
+ __syncthreads();
+
+ for (int k = 2; k <= ncols; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ swap(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ swap(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+}
+
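
k_argsort_f32_i32 runs a bitonic compare-and-exchange network over the per-row index array, which is why the launcher further down requires ncols to be a power of two. A plain C++ reference of the same network for the ascending case (sketch only, not from the commit):

#include <cstdio>
#include <utility>

// sorts idx[0..n) so that x[idx[k]] is ascending; n must be a power of two
static void bitonic_argsort_ref(const float * x, int * idx, int n) {
    for (int i = 0; i < n; i++) idx[i] = i;
    for (int k = 2; k <= n; k *= 2) {            // size of the bitonic sequences
        for (int j = k / 2; j > 0; j /= 2) {     // compare distance
            for (int col = 0; col < n; col++) {
                const int ixj = col ^ j;
                if (ixj <= col) continue;        // each pair is handled once
                const bool ascending = (col & k) == 0;
                const bool out_of_order = ascending ? x[idx[col]] > x[idx[ixj]]
                                                    : x[idx[col]] < x[idx[ixj]];
                if (out_of_order) std::swap(idx[col], idx[ixj]);
            }
        }
    }
}

int main() {
    const float x[8] = {3.f, 1.f, 4.f, 1.5f, 5.f, 9.f, 2.f, 6.f};
    int idx[8];
    bitonic_argsort_ref(x, idx, 8);
    for (int i = 0; i < 8; i++) printf("%d ", idx[i]); // 1 3 6 0 2 4 7 5
    printf("\n");
    return 0;
}

In the kernel, col is the thread index, the inner col loop disappears, and __syncthreads() separates the phases.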
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.y*blockIdx.y + threadIdx.y;
const int row = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4833,8 +4936,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
}
const int i = row*ncols + col;
- // dst[i] = col > n_past + row ? -INFINITY : x[i];
- dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+ //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+ //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
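
The new masking line keeps the branch-free form but subtracts FLT_MAX instead of INT_MAX wherever col > n_past + row % rows_per_channel, so masked positions become hugely negative and vanish after the softmax that follows. A small host-side illustration of the effect (made-up values, not part of the commit):

#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
    // row 0 of an attention matrix, n_past = 1: columns 2 and 3 must be masked out
    const int n_past = 1, row = 0, rows_per_channel = 1, ncols = 4;
    float x[ncols] = {0.5f, 1.0f, 2.0f, 3.0f};

    for (int col = 0; col < ncols; col++) {
        // branch-free mask, same expression as the kernel: subtract FLT_MAX where masked
        x[col] = x[col] - (col > n_past + row % rows_per_channel) * FLT_MAX;
    }

    // softmax: masked entries underflow to ~0 after exponentiation
    float max_val = x[0];
    for (int i = 1; i < ncols; i++) max_val = x[i] > max_val ? x[i] : max_val;
    float sum = 0.0f;
    for (int i = 0; i < ncols; i++) sum += std::exp(x[i] - max_val);
    for (int i = 0; i < ncols; i++) {
        printf("p[%d] = %.3f\n", i, std::exp(x[i] - max_val) / sum);  // p[2], p[3] ~ 0
    }
    return 0;
}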
@@ -4956,25 +5060,119 @@ static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
}
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
- const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
- add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
-
-static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
- add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
-}
-
-static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
- add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
-}
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_cuda {
+ template<typename src0_t, typename src1_t, typename dst_t>
+ void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+
+ int nr0 = ne10/ne0;
+ int nr1 = ne11/ne1;
+ int nr2 = ne12/ne2;
+ int nr3 = ne13/ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ // collapse dimensions until first broadcast dimension
+ int64_t cne0[] = {ne0, ne1, ne2, ne3};
+ int64_t cne1[] = {ne10, ne11, ne12, ne13};
+ size_t cnb0[] = {nb0, nb1, nb2, nb3};
+ size_t cnb1[] = {nb10, nb11, nb12, nb13};
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ {
+ int64_t ne0 = cne0[0];
+ int64_t ne1 = cne0[1];
+ int64_t ne2 = cne0[2];
+ int64_t ne3 = cne0[3];
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ //size_t nb0 = cnb0[0];
+ size_t nb1 = cnb0[1];
+ size_t nb2 = cnb0[2];
+ size_t nb3 = cnb0[3];
+
+ //size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ //size_t s0 = nb0 / sizeof(src1_t);
+ size_t s1 = nb1 / sizeof(src1_t);
+ size_t s2 = nb2 / sizeof(src1_t);
+ size_t s3 = nb3 / sizeof(src1_t);
+
+ //size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+ dim3 block_dims;
+ block_dims.x = std::min<unsigned int>(hne0, block_size);
+ block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+ block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+ dim3 block_nums(
+ (hne0 + block_dims.x - 1) / block_dims.x,
+ (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2*ne3 + block_dims.z - 1) / block_dims.z
+ );
-static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
- const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
- mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
+ if (block_nums.z > 65535) {
+ // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+ int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+ k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s10, */ s11, s12, s13);
+ } else {
+ k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s10, */ s11, s12, s13);
+ }
+ }
+ }
+};
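
Before launching, the operator() above collapses leading dimensions whose extents match in both operands, so the kernel sees fewer, larger dimensions and does less index arithmetic per element. A tiny host-side illustration of one collapse step with assumed shapes (not taken from the commit):

#include <cstdint>
#include <cstdio>

// fuse the two innermost dims of a 4-D shape into one, as the launcher above
// does repeatedly until it reaches the first broadcast dimension
static void collapse_ref(int64_t ne[4]) {
    ne[0] *= ne[1];
    ne[1]  = ne[2];
    ne[2]  = ne[3];
    ne[3]  = 1;
}

int main() {
    // hypothetical shapes: dst/src0 is 16x8x4x2, src1 is 16x8x1x1
    int64_t cne0[4] = {16, 8, 4, 2};
    int64_t cne1[4] = {16, 8, 1, 1};

    // dims 0 and 1 match, so they can be fused; dim 2 is the first broadcast dim
    collapse_ref(cne0);
    collapse_ref(cne1);

    printf("dst  collapsed: %lldx%lldx%lldx%lld\n",
        (long long)cne0[0], (long long)cne0[1], (long long)cne0[2], (long long)cne0[3]);  // 128x4x2x1
    printf("src1 collapsed: %lldx%lldx%lldx%lld\n",
        (long long)cne1[0], (long long)cne1[1], (long long)cne1[2], (long long)cne1[3]);  // 128x1x1x1
    return 0;
}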
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
@@ -4996,14 +5194,14 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
@@ -5025,34 +5223,10 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
-template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
- dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
- dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
- dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
- dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
- dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+ dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
template<typename dst_t>
@@ -5101,6 +5275,64 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
#endif
}
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_cuda;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_cuda;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_cuda;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_F32:
+ return dequantize_block_cuda<1, 1, convert_f32>;
+ default:
+ return nullptr;
+ }
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_cuda;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_cuda;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_cuda;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_F16:
+ return dequantize_block_cuda<1, 1, convert_f16>;
+ default:
+ return nullptr;
+ }
+}
+
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -5189,6 +5421,15 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<1, 1, convert_f16>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK4_0 == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -5279,83 +5520,6 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
-static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
- dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
- dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
- GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
- const dim3 block_nums(block_num_y, 1, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- dequantize_mul_mat_vec<1, 1, convert_f16>
- <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
- switch (type) {
- case GGML_TYPE_Q4_0:
- return dequantize_row_q4_0_cuda;
- case GGML_TYPE_Q4_1:
- return dequantize_row_q4_1_cuda;
- case GGML_TYPE_Q5_0:
- return dequantize_row_q5_0_cuda;
- case GGML_TYPE_Q5_1:
- return dequantize_row_q5_1_cuda;
- case GGML_TYPE_Q8_0:
- return dequantize_row_q8_0_cuda;
- case GGML_TYPE_Q2_K:
- return dequantize_row_q2_K_cuda;
- case GGML_TYPE_Q3_K:
- return dequantize_row_q3_K_cuda;
- case GGML_TYPE_Q4_K:
- return dequantize_row_q4_K_cuda;
- case GGML_TYPE_Q5_K:
- return dequantize_row_q5_K_cuda;
- case GGML_TYPE_Q6_K:
- return dequantize_row_q6_K_cuda;
- case GGML_TYPE_F32:
- return convert_fp32_to_fp16_cuda;
- default:
- return nullptr;
- }
-}
-
-static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
- switch (type) {
- case GGML_TYPE_Q4_0:
- return dequantize_row_q4_0_cuda;
- case GGML_TYPE_Q4_1:
- return dequantize_row_q4_1_cuda;
- case GGML_TYPE_Q5_0:
- return dequantize_row_q5_0_cuda;
- case GGML_TYPE_Q5_1:
- return dequantize_row_q5_1_cuda;
- case GGML_TYPE_Q8_0:
- return dequantize_row_q8_0_cuda;
- case GGML_TYPE_Q2_K:
- return dequantize_row_q2_K_cuda;
- case GGML_TYPE_Q3_K:
- return dequantize_row_q3_K_cuda;
- case GGML_TYPE_Q4_K:
- return dequantize_row_q4_K_cuda;
- case GGML_TYPE_Q5_K:
- return dequantize_row_q5_K_cuda;
- case GGML_TYPE_Q6_K:
- return dequantize_row_q6_K_cuda;
- case GGML_TYPE_F16:
- return convert_fp16_to_fp32_cuda;
- default:
- return nullptr;
- }
-}
-
static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -5967,6 +6131,27 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
}
+static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(1, nrows, 1);
+ k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+ // bitonic sort requires ncols to be power of 2
+ GGML_ASSERT((ncols & (ncols - 1)) == 0);
+
+ const dim3 block_dims(ncols, 1, 1);
+ const dim3 block_nums(1, nrows, 1);
+ if (order == GGML_SORT_ASC) {
+ k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+ } else if (order == GGML_SORT_DESC) {
+ k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+ } else {
+ GGML_ASSERT(false);
+ }
+}
+
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -6059,7 +6244,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
return ptr;
}
#ifdef DEBUG_CUDA_MALLOC
- fprintf(stderr, "%s: %d buffers, max_size = %u MiB, tot_size = %u MiB, requested %u MiB\n", __func__, nnz,
+ fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
#endif
void * ptr;
@@ -6197,7 +6382,7 @@ void * ggml_cuda_host_malloc(size_t size) {
// The allocation error can be bypassed. A null ptr will assigned out of this function.
// This can fixed the OOM error in WSL.
cudaGetLastError();
- fprintf(stderr, "WARNING: failed to allocate %.2f MiB of pinned memory: %s\n",
+ fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
size/1024.0/1024.0, cudaGetErrorString(err));
return nullptr;
}
@@ -6237,81 +6422,23 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type);
- const int64_t i1_diff = i1_high - i1_low;
+ int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
- if (nb0 == ts && nb1 == ts*(ne0/bs)) {
+ if (nb0 == ts && nb1 == ts*ne0/bs) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
- }
- if (nb0 == ts) {
- return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream);
- }
- GGML_ASSERT(bs == 1 && "TODO: implement bs != 1");
- for (int64_t i1 = 0; i1 < i1_diff; i1++) {
- const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_ptr + i1*ts*ne0);
- // pretend the row is a matrix with cols=1
- cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream);
- if (r != cudaSuccess) { return r; }
- }
- return cudaSuccess;
-}
-
-static void ggml_cuda_op_repeat(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
- const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
- // guaranteed to be an integer due to the check in ggml_can_repeat
- const int64_t ne0 = dst->ne[0];
- const int64_t ne1 = dst->ne[1];
- const int64_t ne2 = dst->ne[2];
- const int64_t ne3 = dst->ne[3];
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
- const int64_t ne03 = src0->ne[3];
-
- const size_t nb0 = dst->nb[0];
- const size_t nb1 = dst->nb[1];
- const size_t nb2 = dst->nb[2];
- const size_t nb3 = dst->nb[3];
-
- const size_t nb00 = src0->nb[0];
- const size_t nb01 = src0->nb[1];
- const size_t nb02 = src0->nb[2];
- const size_t nb03 = src0->nb[3];
-
- const int nr0 = (int)(ne0/ne00);
- const int nr1 = (int)(ne1/ne01);
- const int nr2 = (int)(ne2/ne02);
- const int nr3 = (int)(ne3/ne03);
-
- // TODO: support for transposed / permuted tensors
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
- for (int i3 = 0; i3 < nr3; i3++) {
- for (int k3 = 0; k3 < ne03; k3++) {
- for (int i2 = 0; i2 < nr2; i2++) {
- for (int k2 = 0; k2 < ne02; k2++) {
- for (int i1 = 0; i1 < nr1; i1++) {
- for (int k1 = 0; k1 < ne01; k1++) {
- for (int i0 = 0; i0 < nr0; i0++) {
- CUDA_CHECK(cudaMemcpyAsync(
- (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
- (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
- ne00*nb0, cudaMemcpyDeviceToDevice, stream));
- }
- }
- }
- }
- }
+ } else if (nb0 == ts) {
+ return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
+ } else {
+ for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+ const void * rx = (const void *) ((const char *) x + i1*nb1);
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ // pretend the row is a matrix with cols=1
+ cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+ if (r != cudaSuccess) return r;
}
+ return cudaSuccess;
}
-
- (void) src1;
- (void) src1_d;
}
static void ggml_cuda_op_get_rows(
@@ -6358,44 +6485,55 @@ static void ggml_cuda_op_get_rows(
}
}
-inline void ggml_cuda_op_add(
+template<class op>
+inline void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
+ op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+ op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
- add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
+ op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream);
} else {
- fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type);
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ASSERT(false);
}
+}
+
+static void ggml_cuda_op_repeat(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
+
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
(void) src1;
- (void) dst;
+ (void) src1_d;
}
-inline void ggml_cuda_op_mul(
+inline void ggml_cuda_op_add(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
- const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
+inline void ggml_cuda_op_mul(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
- (void) dst;
+inline void ggml_cuda_op_div(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
inline void ggml_cuda_op_gelu(
@@ -6464,7 +6602,10 @@ inline void ggml_cuda_op_norm(
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
- norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
(void) src1;
(void) dst;
@@ -7007,6 +7148,42 @@ inline void ggml_cuda_op_im2col(
(void) src0_dd;
}
+inline void ggml_cuda_op_sum_rows(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_cuda_op_argsort(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
inline void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7215,7 +7392,7 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
- // const int64_t nrows0 = ggml_nrows(src0);
+ const int64_t nrows0 = ggml_nrows(src0);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
@@ -7523,6 +7700,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
}
+static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div);
+}
+
static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
}
@@ -7548,7 +7729,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
}
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
- if (!g_cublas_loaded) { return false; }
+ if (!g_cublas_loaded) return false;
const int64_t ne10 = src1->ne[0];
@@ -7626,7 +7807,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
}
-__global__ static void k_compute_batched_ptrs(
+static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
@@ -7682,9 +7863,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -7741,7 +7920,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
- cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
@@ -7775,7 +7954,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
- cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
@@ -7874,6 +8053,219 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
}
}
+#if 0
+template<typename ... Srcs>
+static __global__ void k_compute_batched_ptrs_id(
+ const void ** ptrs_src, void ** ptrs_dst,
+ int ne12, int ne13,
+ int ne23,
+ int nb02, int nb03,
+ int nb12, int nb13,
+ int nb2, int nb3,
+ int r2, int r3,
+ ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
+ const half * src1_f16, half * dst_f16,
+ const int32_t * ids, const int id,
+ Srcs... src0s) {
+
+ int i = ids[id];
+
+ half * src0_f16;
+ const void * srcs_ar[] = { (const half *) src0s... };
+ if (src0_type == GGML_TYPE_F16) {
+ src0_f16 = (half *) srcs_ar[i];
+ } else {
+ src0_f16 = src0_as_f16;
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
+ to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
+ }
+ }
+
+ int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+ int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+ if (i13 >= ne13 || i12 >= ne12) {
+ return;
+ }
+
+ int i03 = i13 / r3;
+ int i02 = i12 / r2;
+
+ ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
+ ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
+static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
+ const struct ggml_tensor * ids = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src00 = dst->src[2];
+
+ const int id = dst->op_params[0];
+
+ GGML_ASSERT(!ggml_is_transposed(src00));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
+ const int64_t ne01 = src00->ne[1];
+ const int64_t ne02 = src00->ne[2];
+ const int64_t ne03 = src00->ne[3];
+
+ //const int64_t nb01 = src00->nb[1];
+ const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
+ const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
+
+ //const int64_t nb11 = src1->nb[1];
+ const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+ const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+ const int64_t ne1 = ggml_nelements(src1);
+ const int64_t ne = ggml_nelements(dst);
+
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
+
+ //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ //void * src0_ddq = src0_extra->data_device[g_main_device];
+ //half * src0_as_f16 = (half *) src0_ddq;
+
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+ // convert src1 to fp16
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+
+ size_t src1_as = 0;
+ half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+ to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+ size_t dst_as = 0;
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ // use cublasGemmBatchedEx
+ const int ne23 = ne12*ne13;
+
+ const void ** ptrs_src = nullptr;
+ void ** ptrs_dst = nullptr;
+
+ size_t ptrs_src_s = 0;
+ size_t ptrs_dst_s = 0;
+
+ ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+ ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+
+ int64_t src0_ne = ggml_nelements(src00);
+ half * src0_as_f16 = nullptr;
+ size_t src0_as = 0;
+ if (src00->type != GGML_TYPE_F16) {
+ src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
+ }
+
+ static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
+ dim3 block_dims(ne13, ne12);
+ k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
+ ptrs_src, ptrs_dst,
+ ne12, ne13,
+ ne23,
+ ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
+ nb12, nb13,
+ dst->nb[2], dst->nb[3],
+ r2, r3,
+ src00->type, src0_as_f16, src0_ne,
+ src1_as_f16, dst_f16,
+ (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
+ dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
+ dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
+ dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
+ dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
+ );
+ CUDA_CHECK(cudaGetLastError());
+
+ CUBLAS_CHECK(
+ cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
+ (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
+ &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+ ne23,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ if (src0_as != 0) {
+ ggml_cuda_pool_free(src0_as_f16, src0_as);
+ }
+ if (ptrs_src_s != 0) {
+ ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
+ }
+ if (ptrs_dst_s != 0) {
+ ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
+ }
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
+ ggml_cuda_pool_free(dst_f16, dst_as);
+}
+#endif
+
+static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+#if 0
+//#ifdef CUDA_USE_TENSOR_CORES
+// const bool use_tensor_cores = true;
+//#else
+// const bool use_tensor_cores = false;
+//#endif
+
+ ggml_cuda_mul_mat_id_cublas(dst);
+
+ // TODO: mmq/mmv support
+#else
+ const struct ggml_tensor * ids = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const int id = dst->op_params[0];
+
+ int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+
+ int32_t a_id;
+ CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+ GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+ const struct ggml_tensor * src0 = dst->src[a_id + 2];
+
+ ggml_cuda_mul_mat(src0, src1, dst);
+#endif
+
+ (void) _src0;
+ (void) _src1;
+}
+
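
In the enabled #else path, ggml_cuda_mul_mat_id copies one expert index back from the ids tensor and uses it to pick the matching weight tensor from dst->src[a_id + 2]. A host-side sketch of just that selection arithmetic, with made-up routing values (not part of the commit):

#include <cassert>
#include <cstdio>

int main() {
    // hypothetical routing data: ids holds the expert indices chosen by the router
    const int ids[3] = {2, 0, 1};   // contents of the ids tensor (device->host copy in the real code)
    const int ne0_ids = 3;          // ids->ne[0]
    const int id = 0;               // dst->op_params[0]: which slot of ids to use

    const int a_id = ids[id];
    assert(a_id >= 0 && a_id < ne0_ids);  // same bound check as the code above

    // dst->src[0] = ids, dst->src[1] = src1; expert weight tensors start at dst->src[2]
    printf("routing to expert %d -> src0 = dst->src[%d]\n", a_id, a_id + 2);
    return 0;
}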
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
@@ -7965,6 +8357,16 @@ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
+static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows);
+}
+
+static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort);
+}
+
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
@@ -8220,8 +8622,9 @@ void ggml_cuda_set_main_device(const int main_device) {
main_device, g_device_count, g_main_device);
return;
}
- g_main_device = main_device;
- if (g_device_count > 1) {
+
+ if (g_main_device != main_device && g_device_count > 1) {
+ g_main_device = main_device;
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
@@ -8247,7 +8650,7 @@ void ggml_cuda_free_scratch() {
}
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
- if (!g_cublas_loaded) { return false; }
+ if (!g_cublas_loaded) return false;
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -8283,6 +8686,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_MUL:
func = ggml_cuda_mul;
break;
+ case GGML_OP_DIV:
+ func = ggml_cuda_div;
+ break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
@@ -8296,7 +8702,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
break;
default:
return false;
- } break;
+ }
+ break;
case GGML_OP_NORM:
func = ggml_cuda_norm;
break;
@@ -8309,6 +8716,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_mul_mat;
break;
+ case GGML_OP_MUL_MAT_ID:
+ if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
+ return false;
+ }
+ func = ggml_cuda_mul_mat_id;
+ break;
case GGML_OP_SCALE:
func = ggml_cuda_scale;
break;
@@ -8348,6 +8761,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
+ case GGML_OP_SUM_ROWS:
+ func = ggml_cuda_sum_rows;
+ break;
+ case GGML_OP_ARGSORT:
+ func = ggml_cuda_argsort;
+ break;
default:
return false;
}
@@ -8364,7 +8783,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
int ggml_cuda_get_device_count() {
int device_count;
- CUDA_CHECK(cudaGetDeviceCount(&device_count));
+ if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
+ return 0;
+ }
return device_count;
}
@@ -8380,27 +8801,16 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des
#define UNUSED GGML_UNUSED
-struct ggml_backend_context_cuda {
-};
-
-static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
- return GGML_CUDA_NAME;
-
- UNUSED(backend);
-}
-
-static void ggml_backend_cuda_free(ggml_backend_t backend) {
- ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
- delete cuda_ctx;
- delete backend;
-}
+// cuda buffer
struct ggml_backend_buffer_context_cuda {
- void * device;
-
+ int device;
+ void * dev_ptr = nullptr;
ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
size_t temp_tensor_extra_index = 0;
+ ggml_backend_buffer_context_cuda(int device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
+
~ggml_backend_buffer_context_cuda() {
delete[] temp_tensor_extras;
}
@@ -8421,41 +8831,20 @@ struct ggml_backend_buffer_context_cuda {
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
- CUDA_CHECK(cudaFree(ctx->device));
+ CUDA_CHECK(cudaFree(ctx->dev_ptr));
delete ctx;
}
static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
- return ctx->device;
-}
-
-static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
- int64_t row_low = 0;
- int64_t row_high = ggml_nrows(tensor);
- int64_t nrows_split = row_high - row_low;
-
- size_t size = ggml_nbytes_split(tensor, nrows_split);
-
- int64_t ne0 = tensor->ne[0];
-
- if (ggml_is_quantized(tensor->type)) {
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
- * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
- }
- }
-
- return size;
-
- UNUSED(buffer);
+ return ctx->dev_ptr;
}
static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
if (tensor->view_src != NULL && tensor->view_offs == 0) {
- assert(tensor->view_src->buffer->backend == buffer->backend);
+ assert(tensor->view_src->buffer->buft == buffer->buft); // TODO
tensor->backend = tensor->view_src->backend;
tensor->extra = tensor->view_src->extra;
return;
@@ -8463,7 +8852,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
- extra->data_device[g_main_device] = tensor->data;
+ extra->data_device[ctx->device] = tensor->data;
tensor->backend = GGML_BACKEND_GPU;
tensor->extra = extra;
@@ -8475,64 +8864,208 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
int64_t nrows_split = row_high - row_low;
size_t original_size = ggml_nbytes_split(tensor, nrows_split);
- size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor);
+ size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
if (padded_size > original_size && tensor->view_src == nullptr) {
- CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0]));
+ CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0]));
}
}
UNUSED(buffer);
}
+static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+ CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
+
+ UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+ CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
+
+ UNUSED(buffer);
+}
+
static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
- /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
- /* .get_base = */ ggml_backend_cuda_buffer_get_base,
- /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size,
- /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
- /* .free_tensor = */ NULL,
+ /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cuda_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
+ /* .cpy_tensor_from = */ NULL,
+ /* .cpy_tensor_to = */ NULL,
};
-static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
- ggml_cuda_set_device(g_main_device);
+// cuda buffer type
+
+static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ int device = (int) (intptr_t) buft->context;
- ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
+ ggml_cuda_set_device(device);
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
- ggml_cuda_set_device(g_main_device);
- CUDA_CHECK(cudaMalloc(&ctx->device, size));
+ void * dev_ptr;
+ CUDA_CHECK(cudaMalloc(&dev_ptr, size));
+
+ ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(device, dev_ptr);
- return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
+ return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size);
}
-static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
+
+ UNUSED(buft);
+}
+
+static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) {
+ int64_t row_low = 0;
+ int64_t row_high = ggml_nrows(tensor);
+ int64_t nrows_split = row_high - row_low;
+
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+ int64_t ne0 = tensor->ne[0];
+
+ if (ggml_is_quantized(tensor->type)) {
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
+ * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+ }
+ }
+
+ return size;
+
+ UNUSED(buft);
+}
+
+static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ return ggml_backend_is_cuda(backend);
+
+ UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = {
+ /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
+ /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+ /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES];
+ static bool ggml_backend_buffer_type_cuda_initialized = false;
+ if (!ggml_backend_buffer_type_cuda_initialized) {
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
+ ggml_backend_buffer_type_cuda[i] = {
+ /* .iface = */ cuda_backend_buffer_type_interface,
+ /* .context = */ (ggml_backend_buffer_type_context_t) (intptr_t) i,
+ };
+ }
+ ggml_backend_buffer_type_cuda_initialized = true;
+ }
+
+ return &ggml_backend_buffer_type_cuda[device];
+}
+
+// host buffer type
+
+static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+ CUDA_CHECK(cudaFreeHost(ctx->dev_ptr));
+ delete ctx;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ void * ptr;
+ CUDA_CHECK(cudaMallocHost(&ptr, size));
+
+ // FIXME: this is a hack to avoid having to implement a new buffer type
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+ return buffer;
+
+ UNUSED(buft);
+}
+
+struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = {
+ /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+ /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = {
+ /* .iface = */ cuda_backend_host_buffer_type_interface,
+ /* .context = */ nullptr,
+ };
+
+ return &ggml_backend_buffer_type_cuda_host;
+}
+
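
Usage sketch (not part of the diff): the host buffer type hands out pinned memory (cudaMallocHost) that otherwise behaves like a CPU buffer — only free_buffer and the buft pointer are overridden — so staging data in it lets host-to-device copies take the faster pinned path. Again assuming the ggml_backend_buft_alloc_buffer helper:

// 32 MiB of pinned staging memory; the size is chosen purely for illustration
ggml_backend_buffer_t host_buf = ggml_backend_buft_alloc_buffer(ggml_backend_cuda_host_buffer_type(), 32u * 1024 * 1024);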
+// backend
+
+struct ggml_backend_context_cuda {
+ int device;
+};
+
+static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
+ return GGML_CUDA_NAME;
+
UNUSED(backend);
}
+static void ggml_backend_cuda_free(ggml_backend_t backend) {
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+ delete cuda_ctx;
+ delete backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+ return ggml_backend_cuda_buffer_type(cuda_ctx->device);
+}
+
static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+ GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0]));
-
- UNUSED(backend);
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
}
static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+ GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-
- UNUSED(backend);
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
}
static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[cuda_ctx->device][0]));
UNUSED(backend);
}
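
Usage sketch (not part of the diff): the tensor copies and the synchronize call now pick their stream from the backend's own device instead of g_main_device, so two CUDA backend instances no longer share a stream. A hedged sketch of the calling pattern, using the backend-agnostic wrappers ggml_backend_tensor_set / ggml_backend_synchronize from ggml-backend.h (signatures assumed from the same sync):

#include "ggml-backend.h"

// copy host data into a GPU tensor, then wait until the backend has no pending work;
// `backend` comes from ggml_backend_cuda_init and `t` lives in a buffer of the matching
// ggml_backend_cuda_buffer_type (both assumed to be set up elsewhere)
static void upload_and_sync(ggml_backend_t backend, struct ggml_tensor * t, const void * host_data) {
    ggml_backend_tensor_set(t, host_data, 0, ggml_nbytes(t));
    ggml_backend_synchronize(backend);
}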
@@ -8546,14 +9079,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
UNUSED(cgraph);
}
-[[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(!"not implemented");
UNUSED(backend);
UNUSED(plan);
}
-[[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(!"not implemented");
UNUSED(backend);
@@ -8561,7 +9094,9 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
}
static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
- ggml_cuda_set_device(g_main_device);
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+ ggml_cuda_set_main_device(cuda_ctx->device);
ggml_compute_params params = {};
params.type = GGML_TASK_COMPUTE;
@@ -8569,13 +9104,18 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
- if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) {
+ if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
continue;
- }
+
assert(node->backend == GGML_BACKEND_GPU);
+ assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+ assert(node->extra != nullptr);
+
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->backend == GGML_BACKEND_GPU);
+ assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+ assert(node->src[j]->extra != nullptr);
}
}
@@ -8612,27 +9152,98 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
UNUSED(backend);
}
+static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ return true;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ {
+ struct ggml_tensor * a;
+ struct ggml_tensor * b;
+ if (op->op == GGML_OP_MUL_MAT) {
+ a = op->src[0];
+ b = op->src[1];
+ } else {
+ a = op->src[2];
+ b = op->src[1];
+ }
+ if (a->ne[3] != b->ne[3]) {
+ return false;
+ }
+ return true;
+ } break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NORM:
+ case GGML_OP_REPEAT:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_DUP:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
+ case GGML_OP_ALIBI:
+ case GGML_OP_IM2COL:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGSORT:
+ return true;
+ default:
+ return false;
+ }
+
+ UNUSED(backend);
+}
+
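
Usage sketch (not part of the diff): supports_op replaces the previous nullptr entry, so a scheduler (or test-backend-ops, added in this same sync) can ask the backend whether it implements an op before assigning work to it. Assuming the public ggml_backend_supports_op wrapper:

// run a node on CUDA when supported, otherwise fall back to a CPU backend;
// cuda_backend and cpu_backend are assumed to have been created elsewhere
static ggml_backend_t pick_backend(ggml_backend_t cuda_backend, ggml_backend_t cpu_backend, const struct ggml_tensor * node) {
    return ggml_backend_supports_op(cuda_backend, node) ? cuda_backend : cpu_backend;
}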
static ggml_backend_i cuda_backend_i = {
- /* .get_name = */ ggml_backend_cuda_name,
- /* .free = */ ggml_backend_cuda_free,
- /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cuda_get_alignment,
- /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
- /* .synchronize = */ ggml_backend_cuda_synchronize,
- /* .cpy_tensor_from = */ nullptr,
- /* .cpy_tensor_to = */ nullptr,
- /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create,
- /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free,
- /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute,
- /* .graph_compute = */ ggml_backend_cuda_graph_compute,
- /* .supports_op = */ nullptr,
+ /* .get_name = */ ggml_backend_cuda_name,
+ /* .free = */ ggml_backend_cuda_free,
+ /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
+ /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ ggml_backend_cuda_synchronize,
+ /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create,
+ /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free,
+ /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute,
+ /* .graph_compute = */ ggml_backend_cuda_graph_compute,
+ /* .supports_op = */ ggml_backend_cuda_supports_op,
};
-ggml_backend_t ggml_backend_cuda_init() {
+ggml_backend_t ggml_backend_cuda_init(int device) {
ggml_init_cublas(); // TODO: remove from ggml.c
- ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
+ if (device < 0 || device >= ggml_cuda_get_device_count()) {
+ fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
+ return nullptr;
+ }
+
+ // not strictly necessary, but it may reduce the overhead of the first graph_compute
+ ggml_cuda_set_main_device(device);
+
+ ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
+ /* .device = */ device
+ };
ggml_backend_t cuda_backend = new ggml_backend {
/* .interface = */ cuda_backend_i,
@@ -8641,3 +9252,25 @@ ggml_backend_t ggml_backend_cuda_init() {
return cuda_backend;
}
+
+bool ggml_backend_is_cuda(ggml_backend_t backend) {
+ return backend->iface.get_name == ggml_backend_cuda_name;
+}
+
+static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
+ ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
+ return cuda_backend;
+
+ UNUSED(params);
+}
+
+extern "C" int ggml_backend_cuda_reg_devices() {
+ int device_count = ggml_cuda_get_device_count();
+ //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
+ for (int i = 0; i < device_count; i++) {
+ char name[128];
+ snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
+ ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i);
+ }
+ return device_count;
+}
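
Usage sketch (not part of the diff): putting the new entry points together — per-device initialization, the matching buffer type, graph compute, teardown. The generic ggml_backend_graph_compute / ggml_backend_free helpers and an already-built graph `gf` are assumed:

ggml_backend_t backend = ggml_backend_cuda_init(0); // returns nullptr for an invalid device id
if (backend != nullptr) {
    // gf is assumed to be a ggml_cgraph whose tensors were allocated from
    // ggml_backend_cuda_buffer_type(0), matching the asserts in graph_compute above
    ggml_backend_graph_compute(backend, gf);
    ggml_backend_free(backend);
}

With GGML_CUDA_NAME defined as "CUDA" in the non-HIP build, ggml_backend_cuda_reg_devices() registers one entry per visible device under the names CUDA0, CUDA1, and so on.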