summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r--ggml/src/ggml-cuda.cu16
1 files changed, 12 insertions, 4 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index bc960678..85df0694 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2133,7 +2133,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
@@ -2162,7 +2163,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
@@ -2301,7 +2303,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
@@ -2362,7 +2365,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+ if (row_id_i < 0 || row_id_i >= n_as) continue;
+ //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
@@ -2637,6 +2641,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
+ case GGML_OP_ARGSORT_THRESH:
+ ggml_cuda_op_argsort_thresh(ctx, dst);
+ break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
@@ -3252,6 +3259,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
+ case GGML_OP_ARGSORT_THRESH:
case GGML_OP_ACC:
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE: