summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r--ggml/src/ggml-cuda.cu4
1 files changed, 3 insertions, 1 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 87f80d0c..ef73ee7d 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3391,6 +3391,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->op == GGML_OP_MOE_FUSED_UP_GATE ? op->src[2] : op->src[1];
if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
+ printf("%s: returning false for GGML_OP_MOE_FUSED_UP_GATE because src0->type != src1->type\n", __func__);
return false;
}
//==================================================================
@@ -3399,6 +3400,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
//}
//==================================================================
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
+ printf("%s: returning false for op %d because (case 1)\n", __func__, (int)op->op);
return false;
}
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
@@ -3621,7 +3623,7 @@ GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const
const int min_batch_size = 32;
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
- (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+ (op->ne[2] >= min_batch_size && (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MOE_FUSED_UP_GATE));
GGML_UNUSED(backend);
}