summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/ggml-cuda.cu5
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu19
-rw-r--r--ggml/src/ggml.c35
3 files changed, 39 insertions, 20 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index ff6e064c..87f80d0c 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2505,11 +2505,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
dst_padded_col_size, next->src[0]->type, stream);
CUDA_CHECK(cudaGetLastError());
- std::vector<char> ids_host(ggml_nbytes(ids));
- const char * ids_dev = (const char *) ids->data;
- CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
local_dst.ne[2] = 1;
auto local_next = *next;
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index f87ebb96..bc26cce4 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -147,10 +147,27 @@ static __global__ void mul_mat_vec_q(
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
int i2 = blockIdx.y;
+ char * cdst = (char *)dst + i2*nb2;
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
+ if (i02 < 0) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+ constexpr int rows_per_cuda_block = 1;
+#else
+ constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+ const int row0 = rows_per_cuda_block*blockIdx.x;
+ if (threadIdx.y == 0) {
+ dst = (float *)cdst;
+ for (int j = 0; j < ncols_y; ++j) {
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+ dst[j*nrows_dst + row0 + threadIdx.x] = 0;
+ }
+ }
+ }
+ return;
+ }
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
- char * cdst = (char *)dst + i2*nb2;
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4cd18a28..d82466e0 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -15911,11 +15911,14 @@ static void ggml_compute_forward_get_rows_f16(
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- assert(i01 >= 0 && i01 < ne01);
+ if (i01 >= 0 && i01 < ne01) {
+ ggml_fp16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ } else {
+ memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
+ }
- ggml_fp16_to_fp32_row(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}
@@ -15952,11 +15955,13 @@ static void ggml_compute_forward_get_rows_bf16(
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- assert(i01 >= 0 && i01 < ne01);
-
- ggml_bf16_to_fp32_row(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ if (i01 >= 0 && i01 < ne01) {
+ ggml_bf16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ } else {
+ memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
+ }
}
}
@@ -15993,11 +15998,13 @@ static void ggml_compute_forward_get_rows_f32(
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- assert(i01 >= 0 && i01 < ne01);
-
- ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
- (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+ if (i01 >= 0 && i01 < ne01) {
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+ } else {
+ memset((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
+ }
}
}