diff options
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r-- | ggml/src/ggml-cuda.cu | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 87d7e17e..ca57efbd 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1180,19 +1180,22 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const int64_t nb3 = src->nb[3]; const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); + const int64_t rs = ggml_row_size(type, ne0); const int64_t bs = ggml_blck_size(type); int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*ne0/bs) { + if (nb0 == ts && nb1 == rs) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream); } else if (nb0 == ts) { + // TODO: this only works if the row does not contain meta data return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream); } else { for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + void * rd = (void *) (dst_ptr + i1*rs); // pretend the row is a matrix with cols=1 + // TODO: this only works if the row does not contain meta data cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream); if (r != cudaSuccess) { return r; @@ -1441,8 +1444,7 @@ static void ggml_cuda_op_mul_mat( const int64_t i02_divisor = ne12 / ne02; - const size_t src0_ts = ggml_type_size(src0->type); - const size_t src0_bs = ggml_blck_size(src0->type); + const size_t src0_rs = ggml_row_size(src0->type, ne00); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; @@ -1608,7 +1610,7 @@ static void ggml_cuda_op_mul_mat( } // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * ne01*src0_rs; float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); |