summaryrefslogtreecommitdiff
path: root/ggml-cuda/cpy.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda/cpy.cu')
-rw-r--r--ggml-cuda/cpy.cu29
1 files changed, 29 insertions, 0 deletions
diff --git a/ggml-cuda/cpy.cu b/ggml-cuda/cpy.cu
index 16d9c8ff..12d741f0 100644
--- a/ggml-cuda/cpy.cu
+++ b/ggml-cuda/cpy.cu
@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
ggml_cuda_cpy(ctx, src0, dst);
}
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+ return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+ } else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+}
+