diff options
Diffstat (limited to 'ggml/src')
-rw-r--r-- | ggml/src/ggml-cuda.cu | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 19 | ||||
-rw-r--r-- | ggml/src/ggml.c | 35 |
3 files changed, 39 insertions, 20 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ff6e064c..87f80d0c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2505,11 +2505,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dst_padded_col_size, next->src[0]->type, stream); CUDA_CHECK(cudaGetLastError()); - std::vector<char> ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - local_dst.ne[2] = 1; auto local_next = *next; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index f87ebb96..bc26cce4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -147,10 +147,27 @@ static __global__ void mul_mat_vec_q( const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; + char * cdst = (char *)dst + i2*nb2; int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int rows_per_cuda_block = 1; +#else + constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + const int row0 = rows_per_cuda_block*blockIdx.x; + if (threadIdx.y == 0) { + dst = (float *)cdst; + for (int j = 0; j < ncols_y; ++j) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + dst[j*nrows_dst + row0 + threadIdx.x] = 0; + } + } + } + return; + } const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; - char * cdst = (char *)dst + i2*nb2; mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4cd18a28..d82466e0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -15911,11 +15911,14 @@ static void ggml_compute_forward_get_rows_f16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 >= 0 && i01 < ne01) { + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } else { + memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float)); + } - ggml_fp16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } @@ -15952,11 +15955,13 @@ static void ggml_compute_forward_get_rows_bf16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); - - ggml_bf16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + if (i01 >= 0 && i01 < ne01) { + ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } else { + memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float)); + } } } @@ -15993,11 +15998,13 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), - (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + if (i01 >= 0 && i01 < ne01) { + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } else { + memset((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + } } } |