summaryrefslogtreecommitdiff
path: root/ggml-cuda
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-28 11:04:19 +0300
committerGitHub <noreply@github.com>2024-05-28 11:04:19 +0300
commit0548a4187f2e53b8fc6d9ff0f4c71988f708ff42 (patch)
tree35ae0e19ecc36169939620b2702fd853c8e8c116 /ggml-cuda
parent9335b969e86a222e247adacedf814d8abfff8847 (diff)
ggml : generalize GGML_OP_CONCAT (#7563)
* ggml : generalize GGML_OP_CONCAT (WIP) ggml-ci * tests : add dim != 2 tests * metal : generalize concat kernel * tests : naming * cuda : generalize concat kernel ggml-ci * sycl : add warning and assert * ggml : fix op params handling * metal : bugfix kernel ggml-ci * ggml : reimplement CPU and Metal * cuda : add asserts ggml-ci * ggml : fix ptrs ggml-ci
Diffstat (limited to 'ggml-cuda')
-rw-r--r--ggml-cuda/concat.cu93
1 files changed, 87 insertions, 6 deletions
diff --git a/ggml-cuda/concat.cu b/ggml-cuda/concat.cu
index 2941d2f1..fb9dee8f 100644
--- a/ggml-cuda/concat.cu
+++ b/ggml-cuda/concat.cu
@@ -1,15 +1,68 @@
#include "concat.cuh"
-static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
- // operation
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (nidx < ne00) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne00 +
+ blockIdx.z * ne00 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ (nidx - ne00) +
+ blockIdx.y * (ne0 - ne00) +
+ blockIdx.z * (ne0 - ne00) * gridDim.y;
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (blockIdx.y < ne01) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx +
+ (blockIdx.y - ne01) * ne0 +
+ blockIdx.z * ne0 * (gridDim.y - ne01);
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
+
if (blockIdx.z < ne02) { // src0
int offset_src =
nidx +
@@ -25,25 +78,53 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
}
}
-static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2);
- concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+ if (dim == 0) {
+ concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+ return;
+ }
+ if (dim == 1) {
+ concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+ return;
+ }
+ concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
}
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
+
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- for (int i3 = 0; i3 < dst->ne[3]; i3++) {
- concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream);
+ if (dim != 3) {
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_cuda(
+ src0_d + i3 * (src0->nb[3] / 4),
+ src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * ( dst->nb[3] / 4),
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ }
+ } else {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+
+ CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
}
}