Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  6
1 file changed, 6 insertions(+), 0 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index d277104d..c30554f0 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -14,6 +14,7 @@
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpb = prop.sharedMemPerBlock;
+ info.devices[id].nsm = prop.multiProcessorCount;
}

for (int id = 0; id < info.device_count; ++id) {
@@ -2290,6 +2292,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_cuda_flash_attn_ext(ctx, dst);
+ break;
default:
return false;
}
@@ -2564,6 +2569,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_FLASH_ATTN_EXT:
return true;
default:
return false;
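
The new "ggml-cuda/fattn.cuh" header is included but lies outside this diffstat. Judging from the call site added in ggml_cuda_compute_forward, the header would need to declare roughly the following; this is a sketch inferred from the diff, not the actual file contents:

    // ggml-cuda/fattn.cuh (sketch; inferred from the ggml_cuda_flash_attn_ext(ctx, dst) call above)
    #pragma once

    #include "common.cuh"

    // Assumed declaration: launches the FlashAttention kernels for a
    // GGML_OP_FLASH_ATTN_EXT node and writes the result into dst.
    void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);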