diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-31 12:05:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-31 12:05:27 +0100 |
commit | 52874c5d21819bd63cc4c500f2fb1be435d16b5e (patch) | |
tree | bff705dd887d124958d1ad3c224de9ee9732de1a /ggml/src/ggml-cuda.cu | |
parent | 5ad6439486e5bfdd8e34213a36beb56b74842bbe (diff) |
Faster MoE inference (#112)
* multi_sdd: WIP
* multi_sdd: CPU works
* multi_add: CUDA
* multi_add: simplify
* multi_add: Metal
* Metal: speed up mul_mat_id
For the Granite-1B MoE model PP-512 goes from
156 t/s to 890 t/s, so nearly a 6X speedup!
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r-- | ggml/src/ggml-cuda.cu | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 6759e202..e38e9568 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ADD: ggml_cuda_op_add(ctx, dst); break; + case GGML_OP_MULTI_ADD: + ggml_cuda_op_multi_add(ctx, dst); + break; case GGML_OP_ACC: ggml_cuda_op_acc(ctx, dst); break; @@ -2607,6 +2610,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); #endif } + if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) { + // disable CUDA graphs for batch size > 1 for now. + // Changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } if (node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. @@ -2927,6 +2938,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_ADD: + case GGML_OP_MULTI_ADD: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: |