diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d619f5da..eb244f40 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -302,3 +302,31 @@ void ggml_init_cublas(void) { // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); } } + +cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) { + const uint64_t ne0 = src->ne[0]; + const uint64_t ne1 = src->ne[1]; + const uint64_t nb0 = src->nb[0]; + const uint64_t nb1 = src->nb[1]; + const uint64_t nb2 = src->nb[2]; + const uint64_t nb3 = src->nb[3]; + const enum ggml_type type = src->type; + const size_t ts = ggml_type_size(type); + const size_t bs = ggml_blck_size(type); + + const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3); + if (nb0 == ts && nb1 == ts*ne0/bs) { + return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream); + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) ((char *) dst + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream); + if (r != cudaSuccess) return r; + } + return cudaSuccess; + } +} |