summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu535
1 files changed, 457 insertions, 78 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 989c419c..7e92c519 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -62,6 +62,7 @@
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
+#define cudaMemsetAsync hipMemsetAsync
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
@@ -419,6 +420,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#define CUDA_GET_ROWS_BLOCK_SIZE 256
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
@@ -1574,6 +1576,34 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
+ const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+ const int row = blockDim.y*blockIdx.y + threadIdx.y;
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int r = y[row];
+
+ // copy x[r*ncols + col] to dst[row*ncols + col]
+ const int xi = r*ncols + col;
+ const int di = row*ncols + col;
+
+ const int ib = xi/qk; // block index
+ const int iqs = (xi%qk)/qr; // quant index
+ const int iybs = di - di%qk; // y block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ dfloat2 v;
+ dequantize_kernel(x, ib, iqs, v);
+
+ dst[iybs + iqs + 0] = v.x;
+ dst[iybs + iqs + y_offset] = v.y;
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
@@ -4555,6 +4585,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i];
}
+
+template<int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(block_num_x, nrows, 1);
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+}
+
static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -5703,7 +5742,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
} else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
kind = cudaMemcpyDeviceToDevice;
- struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
int id;
CUDA_CHECK(cudaGetDevice(&id));
src_ptr = (char *) extra->data_device[id];
@@ -5739,6 +5778,107 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
}
}
+static void ggml_cuda_op_repeat(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+ const int64_t ne2 = dst->ne[2];
+ const int64_t ne3 = dst->ne[3];
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const size_t nb0 = dst->nb[0];
+ const size_t nb1 = dst->nb[1];
+ const size_t nb2 = dst->nb[2];
+ const size_t nb3 = dst->nb[3];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ const int nr0 = (int)(ne0/ne00);
+ const int nr1 = (int)(ne1/ne01);
+ const int nr2 = (int)(ne2/ne02);
+ const int nr3 = (int)(ne3/ne03);
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
+ for (int i3 = 0; i3 < nr3; i3++) {
+ for (int k3 = 0; k3 < ne03; k3++) {
+ for (int i2 = 0; i2 < nr2; i2++) {
+ for (int k2 = 0; k2 < ne02; k2++) {
+ for (int i1 = 0; i1 < nr1; i1++) {
+ for (int k1 = 0; k1 < ne01; k1++) {
+ for (int i0 = 0; i0 < nr0; i0++) {
+ CUDA_CHECK(cudaMemcpyAsync(
+ (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
+ (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
+ ne00*nb0, cudaMemcpyDeviceToDevice, stream));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ (void) src1;
+ (void) src1_d;
+}
+
+static void ggml_cuda_op_get_rows(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
+
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ const int ncols = src0->ne[0];
+ const int nrows = ggml_nelements(src1);
+
+ const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ case GGML_TYPE_F32:
+ get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ case GGML_TYPE_Q4_0:
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ break;
+ default:
+ // TODO: k-quants
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
inline void ggml_cuda_op_add(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6343,7 +6483,14 @@ inline void ggml_cuda_op_scale(
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
- const float scale = ((float *) src1->data)[0];
+ float scale;
+ // HACK: support for ggml backend interface
+ if (src1->backend == GGML_BACKEND_CPU) {
+ scale = ((float *) src1->data)[0];
+ } else {
+ // TODO: pass pointer to kernel instead of copying to host
+ CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
+ }
scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());
@@ -6362,9 +6509,9 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
- struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
@@ -6505,9 +6652,9 @@ static void ggml_cuda_op_mul_mat(
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
- struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
@@ -6585,7 +6732,7 @@ static void ggml_cuda_op_mul_mat(
if (convert_src1_to_q8_1) {
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
- if (split && src1_on_device && src1_is_contiguous) {
+ if (src1_on_device && src1_is_contiguous) {
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
}
@@ -6667,7 +6814,7 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(false);
}
- if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) {
+ if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) {
quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
}
@@ -6758,6 +6905,14 @@ static void ggml_cuda_op_mul_mat(
}
}
+static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_repeat);
+}
+
+static void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_get_rows);
+}
+
static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
}
@@ -6812,13 +6967,13 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
- struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
@@ -6843,13 +6998,13 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
- struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
const int64_t row_stride_x = nb01 / sizeof(half);
@@ -6870,11 +7025,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
}
}
- if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
- }else if (src0->type == GGML_TYPE_F32) {
+ } else if (src0->type == GGML_TYPE_F32) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
@@ -6935,8 +7090,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
- const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
- const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
@@ -6991,8 +7146,8 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
const size_t nb1 = tensor->nb[1];
- ggml_backend backend = tensor->backend;
- struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
+ ggml_backend_type backend = tensor->backend;
+ ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra));
for (int64_t id = 0; id < g_device_count; ++id) {
@@ -7046,7 +7201,6 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
}
-
CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice));
extra->data_device[id] = buf;
@@ -7085,17 +7239,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
delete extra;
}
-static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
+static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
static size_t g_temp_tensor_extra_index = 0;
-static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
+static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
if (g_temp_tensor_extras == nullptr) {
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
}
size_t alloc_index = g_temp_tensor_extra_index;
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
- struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
+ ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
memset(extra, 0, sizeof(*extra));
return extra;
@@ -7123,7 +7277,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
return;
}
- struct ggml_tensor_extra_gpu * extra;
+ ggml_tensor_extra_gpu * extra;
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->op == GGML_OP_VIEW ||
@@ -7132,7 +7286,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
size_t offset = 0;
if (tensor->op == GGML_OP_VIEW) {
@@ -7141,7 +7295,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
extra = ggml_cuda_alloc_temp_tensor_extra();
extra->data_device[g_main_device] = src0_ddc + offset;
} else if (tensor->op == GGML_OP_CPY) {
- struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
void * src1_ddv = src1_extra->data_device[g_main_device];
extra = ggml_cuda_alloc_temp_tensor_extra();
extra->data_device[g_main_device] = src1_ddv;
@@ -7183,13 +7337,13 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
}
- struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+ ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->op == GGML_OP_VIEW;
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
size_t view_offset = 0;
if (tensor->op == GGML_OP_VIEW) {
@@ -7207,7 +7361,7 @@ void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(ggml_is_contiguous(tensor));
- struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
}
@@ -7264,58 +7418,47 @@ void ggml_cuda_free_scratch() {
g_scratch_buffer = nullptr;
}
-bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
+bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
+ if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
+ return false;
+ }
+
switch (tensor->op) {
+ case GGML_OP_REPEAT:
+ func = ggml_cuda_repeat;
+ break;
+ case GGML_OP_GET_ROWS:
+ func = ggml_cuda_get_rows;
+ break;
case GGML_OP_DUP:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_dup;
break;
case GGML_OP_ADD:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_add;
break;
case GGML_OP_MUL:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_mul;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_gelu;
break;
case GGML_UNARY_OP_SILU:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_silu;
break;
default:
return false;
} break;
case GGML_OP_NORM:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_norm;
break;
case GGML_OP_RMS_NORM:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_rms_norm;
break;
case GGML_OP_MUL_MAT:
@@ -7325,54 +7468,30 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
func = ggml_cuda_mul_mat;
break;
case GGML_OP_SCALE:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_scale;
break;
case GGML_OP_CPY:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_cpy;
break;
case GGML_OP_CONT:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_dup;
break;
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_nop;
break;
case GGML_OP_DIAG_MASK_INF:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_diag_mask_inf;
break;
case GGML_OP_SOFT_MAX:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_soft_max;
break;
case GGML_OP_ROPE:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_rope;
break;
case GGML_OP_ALIBI:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_alibi;
break;
default:
@@ -7400,3 +7519,263 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
snprintf(description, description_size, "%s", prop.name);
}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+#define UNUSED GGML_UNUSED
+
+struct ggml_backend_context_cuda {
+};
+
+static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
+ return GGML_CUDA_NAME;
+
+ UNUSED(backend);
+}
+
+static void ggml_backend_cuda_free(ggml_backend_t backend) {
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+ delete cuda_ctx;
+ delete backend;
+}
+
+struct ggml_backend_buffer_context_cuda {
+ void * device;
+
+ ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
+ size_t temp_tensor_extra_index = 0;
+
+ ~ggml_backend_buffer_context_cuda() {
+ delete[] temp_tensor_extras;
+ }
+
+ ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
+ if (temp_tensor_extras == nullptr) {
+ temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+ }
+
+ size_t alloc_index = temp_tensor_extra_index;
+ temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+ ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
+ memset(extra, 0, sizeof(*extra));
+
+ return extra;
+ }
+};
+
+static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+ CUDA_CHECK(cudaFree(ctx->device));
+ delete ctx;
+}
+
+static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
+ ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+ return ctx->device;
+}
+
+static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ int64_t row_low = 0;
+ int64_t row_high = ggml_nrows(tensor);
+ int64_t nrows_split = row_high - row_low;
+
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+ int64_t ne0 = tensor->ne[0];
+
+ if (ggml_is_quantized(tensor->type)) {
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
+ * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+ }
+ }
+
+ return size;
+
+ UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ assert(tensor->view_src->buffer->backend == buffer->backend);
+ tensor->backend = tensor->view_src->backend;
+ tensor->extra = tensor->view_src->extra;
+ return;
+ }
+
+ ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
+
+ extra->data_device[g_main_device] = tensor->data;
+
+ tensor->backend = GGML_BACKEND_GPU;
+ tensor->extra = extra;
+
+ if (ggml_is_quantized(tensor->type)) {
+ // initialize padding to 0 to avoid possible NaN values
+ int64_t row_low = 0;
+ int64_t row_high = ggml_nrows(tensor);
+ int64_t nrows_split = row_high - row_low;
+
+ size_t original_size = ggml_nbytes_split(tensor, nrows_split);
+ size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor);
+
+ if (padded_size > original_size && tensor->view_src == nullptr) {
+ CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0]));
+ }
+ }
+
+ UNUSED(buffer);
+}
+
+static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
+ /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cuda_buffer_get_base,
+ /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size,
+ /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
+ /* .free_tensor = */ NULL,
+};
+
+static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
+ ggml_cuda_set_device(g_main_device);
+
+ ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
+ CUDA_CHECK(cudaMalloc(&ctx->device, size));
+ return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) {
+ return 128;
+ UNUSED(backend);
+}
+
+static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0]));
+
+ UNUSED(backend);
+}
+
+static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+
+ UNUSED(backend);
+}
+
+static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+ UNUSED(backend);
+}
+
+static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ GGML_ASSERT(!"not implemented");
+
+ return nullptr;
+
+ UNUSED(backend);
+ UNUSED(cgraph);
+}
+
+static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(!"not implemented");
+
+ UNUSED(backend);
+ UNUSED(plan);
+}
+
+static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(!"not implemented");
+
+ UNUSED(backend);
+ UNUSED(plan);
+}
+
+static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_cuda_set_device(g_main_device);
+
+ ggml_compute_params params = {};
+ params.type = GGML_TASK_COMPUTE;
+ params.ith = 0;
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ assert(node->backend == GGML_BACKEND_GPU);
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j] != nullptr) {
+ assert(node->src[j]->backend == GGML_BACKEND_GPU);
+ }
+ }
+
+ bool ok = ggml_cuda_compute_forward(&params, node);
+ if (!ok) {
+ fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
+
+#if 0
+ if (node->type == GGML_TYPE_F32) {
+ cudaDeviceSynchronize();
+ std::vector<float> tmp(ggml_nelements(node), 0.0f);
+ cudaMemcpy(tmp.data(), node->data, ggml_nelements(node)*sizeof(float), cudaMemcpyDeviceToHost);
+ printf("\n%s (%s) (%s %s) (%s %s): ", node->name, ggml_op_name(node->op),
+ ggml_type_name(node->src[0]->type),
+ node->src[1] ? ggml_type_name(node->src[1]->type) : "none",
+ node->src[0]->name,
+ node->src[1] ? node->src[1]->name : "none");
+ double sum = 0.0;
+ double sq_sum = 0.0;
+ for (int i = 0; i < ggml_nelements(node); i++) {
+ printf("%f ", tmp[i]);
+ sum += tmp[i];
+ sq_sum += tmp[i]*tmp[i];
+ }
+ printf("\n");
+ printf("sum: %f, ", sum);
+ printf("sq_sum: %f\n", sq_sum);
+ }
+#endif
+ }
+
+ UNUSED(backend);
+}
+
+static ggml_backend_i cuda_backend_i = {
+ /* .get_name = */ ggml_backend_cuda_name,
+ /* .free = */ ggml_backend_cuda_free,
+ /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cuda_get_alignment,
+ /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
+ /* .synchronize = */ ggml_backend_cuda_synchronize,
+ /* .cpy_tensor_from = */ nullptr,
+ /* .cpy_tensor_to = */ nullptr,
+ /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create,
+ /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free,
+ /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute,
+ /* .graph_compute = */ ggml_backend_cuda_graph_compute,
+ /* .supports_op = */ nullptr,
+};
+
+ggml_backend_t ggml_backend_cuda_init() {
+ ggml_init_cublas(); // TODO: remove from ggml.c
+
+ ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
+
+ ggml_backend_t cuda_backend = new ggml_backend {
+ /* .interface = */ cuda_backend_i,
+ /* .context = */ ctx
+ };
+
+ return cuda_backend;
+}