summaryrefslogtreecommitdiff
path: root/ggml-cuda/fattn-common.cuh
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda/fattn-common.cuh')
-rw-r--r--ggml-cuda/fattn-common.cuh47
1 files changed, 47 insertions, 0 deletions
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
new file mode 100644
index 00000000..33f64069
--- /dev/null
+++ b/ggml-cuda/fattn-common.cuh
@@ -0,0 +1,47 @@
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results(
+ const float * __restrict__ VKQ_parts,
+ const float2 * __restrict__ VKQ_meta,
+ float * __restrict__ dst) {
+ VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+ VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
+ dst += D * gridDim.y*blockIdx.x;
+
+ const int tid = threadIdx.x;
+ __builtin_assume(tid < D);
+
+ __shared__ float2 meta[parallel_blocks];
+ if (tid < 2*parallel_blocks) {
+ ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+ }
+
+ __syncthreads();
+
+ float kqmax = meta[0].x;
+#pragma unroll
+ for (int l = 1; l < parallel_blocks; ++l) {
+ kqmax = max(kqmax, meta[l].x);
+ }
+
+ float VKQ_numerator = 0.0f;
+ float VKQ_denominator = 0.0f;
+#pragma unroll
+ for (int l = 0; l < parallel_blocks; ++l) {
+ const float diff = meta[l].x - kqmax;
+ const float KQ_max_scale = expf(diff);
+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+ VKQ_denominator += KQ_max_scale * meta[l].y;
+ }
+
+ dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}