author     Georgi Gerganov <ggerganov@gmail.com>    2024-05-11 10:32:41 +0300
committer  GitHub <noreply@github.com>              2024-05-11 10:32:41 +0300
commit     9cb317f77e53067f7a138cc89ef7657148eae8e6
tree       3ba1d2d80d1d7c8b4ab01f6396a3febaae26e91b /ggml-cuda
parent     e849648888a11de13aaaa4cb2eda3f5a9c7b444d
ggml : full ALiBi support (#7192)
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi

ggml-ci

* convert : fix convert for refact models
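
Every kernel touched by this commit derives the per-head ALiBi slope from the same two quantities, max_bias and the number of heads. The following is a minimal host-side sketch of that slope computation (function and variable names are illustrative, not part of ggml); it mirrors the m0/m1/n_head_log2 formulas that appear in fattn.cu and softmax.cu below and compiles standalone with nvcc or any C++ compiler.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Per-head ALiBi slope, as computed in the kernels below: heads below the
// largest power of two use base m0, the remaining heads use base m1 with an
// odd exponent.
static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? (int) h + 1 : 2*((int) h - (int) n_head_log2) + 1;

    return powf(base, exph);
}

int main(void) {
    // example: 8 heads, max_bias = 8.0f -> slopes 1/2, 1/4, ..., 1/256
    for (uint32_t h = 0; h < 8; ++h) {
        printf("head %u: slope = %g\n", (unsigned) h, alibi_slope(h, 8, 8.0f));
    }
    return 0;
}
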
Diffstat (limited to 'ggml-cuda')
-rw-r--r--  ggml-cuda/alibi.cu    63
-rw-r--r--  ggml-cuda/alibi.cuh    5
-rw-r--r--  ggml-cuda/fattn.cu    72
-rw-r--r--  ggml-cuda/softmax.cu  55
4 files changed, 83 insertions, 112 deletions
diff --git a/ggml-cuda/alibi.cu b/ggml-cuda/alibi.cu
deleted file mode 100644
index 6c7f1fd9..00000000
--- a/ggml-cuda/alibi.cu
+++ /dev/null
@@ -1,63 +0,0 @@
-#include "alibi.cuh"
-
-static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
- const int n_heads_log2_floor, const float m0, const float m1) {
- const int col = blockDim.x*blockIdx.x + threadIdx.x;
-
- if (col >= ncols) {
- return;
- }
-
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
- const int i = row*ncols + col;
-
- const int k = row/k_rows;
-
- float m_k;
- if (k < n_heads_log2_floor) {
- m_k = powf(m0, k + 1);
- } else {
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
-
- dst[i] = col * m_k + x[i];
-}
-
-static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
- const int k_rows, const int n_heads_log2_floor, const float m0,
- const float m1, cudaStream_t stream) {
- const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
- const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
- const dim3 block_nums(num_blocks_x, nrows, 1);
- alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
-}
-
-void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0];
- const float * src0_d = (const float *)src0->data;
- float * dst_d = (float *)dst->data;
- cudaStream_t stream = ctx.stream();
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
- const int64_t nrows = ggml_nrows(src0);
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- //GGML_ASSERT(ne01 + n_past == ne00);
- GGML_ASSERT(n_head == ne02);
-
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream);
-}
diff --git a/ggml-cuda/alibi.cuh b/ggml-cuda/alibi.cuh
deleted file mode 100644
index 630adfc7..00000000
--- a/ggml-cuda/alibi.cuh
+++ /dev/null
@@ -1,5 +0,0 @@
-#include "common.cuh"
-
-#define CUDA_ALIBI_BLOCK_SIZE 32
-
-void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
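
With alibi.cu and alibi.cuh removed, the CUDA backend no longer has a standalone ALiBi kernel; the bias is instead applied where the attention scores are formed, by scaling the attention mask with the per-head slope (the slope*mask terms in the fattn.cu and softmax.cu hunks below). A minimal CUDA sketch of that fused pre-softmax transform follows; it is an illustration only, with hypothetical names and a simplified one-row-per-head layout, not a ggml kernel.

#include <stdint.h>
#include <math.h>

// Illustration of the fused formulation used after this commit:
// dst = x*scale + slope*mask, with the slope chosen per head (blockIdx.y).
__global__ void apply_scaled_mask(
        const float * x, const float * mask, float * dst, const int ncols,
        const float scale, const float max_bias,
        const float m0, const float m1, const uint32_t n_head_log2) {
    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (col >= ncols) {
        return;
    }

    const int h = blockIdx.y; // head index

    float slope = 1.0f;
    if (max_bias > 0.0f) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        slope = powf(base, exph);
    }

    const int i = h*ncols + col; // simplified layout: one row per head

    dst[i] = x[i]*scale + slope*mask[col];
}
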
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 7c486f48..ac5d6672 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);
+ half slopeh = __float2half(1.0f);
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const int h = blockIdx.y;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slopeh = __float2half(powf(base, exph));
+ }
+
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
- sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+ sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
+ half slopeh = __float2half(1.0f);
+ half2 slope2 = make_half2(1.0f, 1.0f);
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const int h = blockIdx.y;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slopeh = __float2half(powf(base, exph));
+ slope2 = make_half2(slopeh, slopeh);
+ }
+
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
- KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
- KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+ KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
- float scale;
- memcpy(&scale, KQV->op_params, sizeof(float));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale,
+ scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;
- float scale;
- memcpy(&scale, KQV->op_params, sizeof(float));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale,
+ scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
- const int32_t precision = KQV->op_params[1];
+ const int32_t precision = KQV->op_params[2];
if (!fp16_mma_available(cc)) {
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
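
The precision flag above moves from op_params[1] to op_params[2] because the launch helpers now read two floats from the front of KQV->op_params: scale in slot 0 and max_bias in slot 1. A sketch of that packing convention (function names are illustrative; ggml stores mixed parameters in an int32_t op_params array and copies floats in and out with memcpy):

#include <stdint.h>
#include <string.h>

// Illustrative layout after this commit:
//   slot 0 = scale (float), slot 1 = max_bias (float), slot 2 = precision (int32_t)
static void pack_fattn_params(int32_t * op_params, float scale, float max_bias, int32_t precision) {
    memcpy((float *) op_params + 0, &scale,    sizeof(float));
    memcpy((float *) op_params + 1, &max_bias, sizeof(float));
    op_params[2] = precision;
}

static void read_fattn_params(const int32_t * op_params, float * scale, float * max_bias, int32_t * precision) {
    memcpy(scale,    (const float *) op_params + 0, sizeof(float));
    memcpy(max_bias, (const float *) op_params + 1, sizeof(float));
    *precision = op_params[2];
}
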
diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu
index 6ed22599..ca85285a 100644
--- a/ggml-cuda/softmax.cu
+++ b/ggml-cuda/softmax.cu
@@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
- float slope = 0.0f;
+ float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const int h = rowx/nrows_y; // head index
const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
- slope = powf(base, exp);
+ slope = powf(base, exph);
}
extern __shared__ float data_soft_max_f32[];
@@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
- const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
+ const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
@@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
}
template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
- const uint32_t n_head_kv = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
- const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *)src0->data;
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
@@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
- // positions tensor
- void * src2_d = nullptr;
-
- const bool use_src2 = src2 != nullptr;
-
- if (use_src2) {
- src2_d = (void *)src2->data;
- }
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
- const half * src2_dd = (const half *)src2_d;
- soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} else {
const float * src1_dd = (const float *)src1_d;
- const float * src2_dd = (const float *)src2_d;
- soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
}
}
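
After this change the CUDA softmax path takes only the activations and the optional mask; the role of the removed pos tensor is absorbed by scaling the mask with the per-head slope. A small CPU reference of the per-row computation, useful for checking the kernel's output (illustrative only, not part of ggml; the slope would come from the same m0/m1/n_head_log2 computation sketched near the top of this page):

#include <math.h>

// Reference for one row of the updated kernel: val = x*scale + slope*mask,
// followed by a numerically stable softmax over the row.
static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                             int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = fmaxf(max_val, dst[i]);
    }

    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = expf(dst[i] - max_val);
        sum += dst[i];
    }

    for (int i = 0; i < ncols; ++i) {
        dst[i] /= sum;
    }
}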