summaryrefslogtreecommitdiff
path: root/ggml
diff options
context:
space:
mode:
Diffstat (limited to 'ggml')
-rw-r--r--ggml/src/ggml-cuda/concat.cu30
-rw-r--r--ggml/src/ggml.c20
2 files changed, 43 insertions, 7 deletions
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index dac10ec3..4bde6d69 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
- if (dim != 3) {
+ if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+ CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+ } else {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
@@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
- } else {
- const size_t size0 = ggml_nbytes(src0);
- const size_t size1 = ggml_nbytes(src1);
-
- CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
- CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
}
+
+ //if (dim != 3) {
+ // for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ // concat_f32_cuda(
+ // src0_d + i3 * (src0->nb[3] / 4),
+ // src1_d + i3 * (src1->nb[3] / 4),
+ // dst_d + i3 * ( dst->nb[3] / 4),
+ // src0->ne[0], src0->ne[1], src0->ne[2],
+ // dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ // }
+ //} else {
+ // const size_t size0 = ggml_nbytes(src0);
+ // const size_t size1 = ggml_nbytes(src1);
+
+ // CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ // CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+ //}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 80dd25ff..91c0c5db 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -12627,6 +12627,26 @@ static void ggml_compute_forward_concat_f32(
GGML_ASSERT(dim >= 0 && dim < 4);
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) &&
+ (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
+ // simply copy the data
+ const int64_t size_src_0 = ggml_nbytes(src0);
+ const int64_t size_src_1 = ggml_nbytes(src1);
+ const int64_t block_size = 4096;
+ const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size;
+ for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) {
+ const int64_t start = i_block*block_size;
+ if (start < size_src_0) {
+ int64_t copy_size = MIN(block_size, size_src_0 - start);
+ memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size);
+ } else {
+ int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start);
+ memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size);
+ }
+ }
+ return;
+ }
+
int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];