summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu13
1 files changed, 10 insertions, 3 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 8c271230..2e759d43 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -10039,14 +10039,22 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
}
return false;
} break;
+ case GGML_OP_DUP:
+ case GGML_OP_REPEAT:
+ case GGML_OP_CONCAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ if (src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16) {
+ return true;
+ }
+ return false;
+ } break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
- case GGML_OP_REPEAT:
- case GGML_OP_DUP:
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
@@ -10063,7 +10071,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
- case GGML_OP_CONCAT:
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
case GGML_OP_PAD: