summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r--ggml/src/ggml-cuda.cu12
1 files changed, 7 insertions, 5 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 87d7e17e..ca57efbd 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -1180,19 +1180,22 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const int64_t nb3 = src->nb[3];
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
+ const int64_t rs = ggml_row_size(type, ne0);
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
- if (nb0 == ts && nb1 == ts*ne0/bs) {
+ if (nb0 == ts && nb1 == rs) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
} else if (nb0 == ts) {
+ // TODO: this only works if the row does not contain meta data
return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
} else {
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ void * rd = (void *) (dst_ptr + i1*rs);
// pretend the row is a matrix with cols=1
+ // TODO: this only works if the row does not contain meta data
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
if (r != cudaSuccess) {
return r;
@@ -1441,8 +1444,7 @@ static void ggml_cuda_op_mul_mat(
const int64_t i02_divisor = ne12 / ne02;
- const size_t src0_ts = ggml_type_size(src0->type);
- const size_t src0_bs = ggml_blck_size(src0->type);
+ const size_t src0_rs = ggml_row_size(src0->type, ne00);
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
@@ -1608,7 +1610,7 @@ static void ggml_cuda_op_mul_mat(
}
// for split tensors the data begins at i0 == i0_offset_low
- char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+ char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * ne01*src0_rs;
float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);