summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu293
1 files changed, 229 insertions, 64 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 85f7a293..9e1acd3f 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,13 +1,15 @@
#include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
#include <cstddef>
#include <cstdint>
-#include <cinttypes>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@@ -1684,31 +1686,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
}
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
- const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
- if (col >= ncols) {
+static __global__ void k_get_rows(
+ const void * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
return;
}
- const int r = y[row];
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
- // copy x[r*ncols + col] to dst[row*ncols + col]
- const int xi = r*ncols + col;
- const int di = row*ncols + col;
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
- const int ib = xi/qk; // block index
- const int iqs = (xi%qk)/qr; // quant index
- const int iybs = di - di%qk; // y block start index
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
- dequantize_kernel(x, ib, iqs, v);
+ dequantize_kernel(src0_row, ib, iqs, v);
- dst[iybs + iqs + 0] = v.x;
- dst[iybs + iqs + y_offset] = v.y;
+ dst_row[iybs + iqs + 0] = v.x;
+ dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = src0_row[i00];
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5053,11 +5089,69 @@ static __global__ void im2col_f32_f16(
}
template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ GGML_ASSERT(ne00 % 2 == 0);
+
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
- const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
- const dim3 block_nums(block_num_x, nrows, 1);
- k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
}
template<float (*bin_op)(const float, const float)>
@@ -5069,7 +5163,6 @@ struct bin_bcast_cuda {
GGML_TENSOR_BINARY_OP_LOCALS
-
int nr0 = ne10/ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
@@ -5117,26 +5210,28 @@ struct bin_bcast_cuda {
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
- //size_t nb0 = cnb0[0];
+ size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
- //size_t nb10 = cnb1[0];
+ size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
- //size_t s0 = nb0 / sizeof(src1_t);
+ size_t s0 = nb0 / sizeof(src1_t);
size_t s1 = nb1 / sizeof(src1_t);
size_t s2 = nb2 / sizeof(src1_t);
size_t s3 = nb3 / sizeof(src1_t);
- //size_t s10 = nb10 / sizeof(src1_t);
+ size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s10 == 1);
const int block_size = 128;
@@ -6447,36 +6542,34 @@ static void ggml_cuda_op_get_rows(
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(dst));
- const int ncols = src0->ne[0];
- const int nrows = ggml_nelements(src1);
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) {
case GGML_TYPE_F16:
- get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_F32:
- get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
default:
// TODO: k-quants
@@ -8234,36 +8327,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
}
#endif
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-// const bool use_tensor_cores = true;
-//#else
-// const bool use_tensor_cores = false;
-//#endif
-
ggml_cuda_mul_mat_id_cublas(dst);
-
// TODO: mmq/mmv support
-#else
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const int id = dst->op_params[0];
+#endif
- int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
- int32_t a_id;
- CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ const struct ggml_tensor * ids = src0;
+ const int32_t id = ((int32_t *) dst->op_params)[0];
+ const int32_t n_as = ((int32_t *) dst->op_params)[1];
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
+ std::vector<char> ids_host(ggml_nbytes(ids));
+
+ if (ids->backend == GGML_BACKEND_GPU) {
+ const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ } else {
+ memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+ }
+
+ const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+ const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+ ggml_tensor_extra_gpu src1_row_extra;
+ ggml_tensor_extra_gpu dst_row_extra;
+
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ src1_row.ne[1] = 1;
+ dst_row.ne[1] = 1;
+
+ src1_row.nb[2] = src1_row.nb[1];
+ dst_row.nb[2] = dst_row.nb[1];
+
+ src1_row.nb[3] = src1_row.nb[1];
+ dst_row.nb[3] = dst_row.nb[1];
+
+ src1_row.extra = &src1_row_extra;
+ dst_row.extra = &dst_row_extra;
- ggml_cuda_mul_mat(src0, src1, dst);
-#endif
- (void) _src0;
- (void) _src1;
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+ //int32_t row_id;
+ //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+ const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
+
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+ src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+ src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+ dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+ dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+ ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+ }
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9181,6 +9307,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
}
return true;
} break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ return false;
+ } break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
@@ -9188,7 +9353,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_REPEAT:
- case GGML_OP_GET_ROWS:
case GGML_OP_DUP:
case GGML_OP_ADD:
case GGML_OP_MUL:
@@ -9197,7 +9361,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
- case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -9264,7 +9427,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
UNUSED(params);
}
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
int device_count = ggml_cuda_get_device_count();
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
for (int i = 0; i < device_count; i++) {