summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-12-01 10:51:24 +0200
committerGitHub <noreply@github.com>2023-12-01 10:51:24 +0200
commitef47ec18da469423c276b683dd9b5741cee7023e (patch)
treeec3b4780dbe8f629425de499b298e8eadfd1aa4d /ggml-cuda.cu
parent1d144112c0fbbb4ecc07dbcf4f05a380148bd6de (diff)
ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext * cuda : implement soft_max_ext * ggml : implement soft_max_ext (CPU) * batched-bench : print threads ggml-ci * metal : simplify soft_max encoding ggml-ci * cuda : use 512 threads for soft_max instead of 32 * ggml : update soft max cpu * cuda : do warp-based block reduce * cuda : increase max block size to 1024 * cuda : fix warp reduction initialization of shared mem * metal : warp-based reduction for soft max kernel * metal : warp-based reduce for rms_norm * metal : simplify soft max kernel ggml-ci * alloc : fix build with debug
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu130
1 files changed, 87 insertions, 43 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5b80e4ae..9019a849 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+ }
+ return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+ }
+ return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+ }
+ return x;
+}
+
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
- a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
- }
- return a;
-}
-
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- x += __shfl_xor_sync(0xffffffff, x, mask, 32);
- }
- return x;
-}
-
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
- const int block_size = blockDim.y;
- const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+ const int tid = threadIdx.x;
+ const int rowx = blockIdx.x;
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+ const int block_size = blockDim.x;
+
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+
+ __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
float max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) {
- const int i = row*ncols + col;
- max_val = max(max_val, x[i]);
+ const int ix = rowx*ncols + col;
+ const int iy = rowy*ncols + col;
+ max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
}
// find the max value in the block
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+ max_val = warp_reduce_max(max_val);
+ if (block_size > WARP_SIZE) {
+ if (warp_id == 0) {
+ buf[lane_id] = -INFINITY;
+ }
+ __syncthreads();
+
+ if (lane_id == 0) {
+ buf[warp_id] = max_val;
+ }
+ __syncthreads();
+
+ max_val = buf[lane_id];
+ max_val = warp_reduce_max(max_val);
}
float tmp = 0.f;
for (int col = tid; col < ncols; col += block_size) {
- const int i = row*ncols + col;
- const float val = expf(x[i] - max_val);
+ const int ix = rowx*ncols + col;
+ const int iy = rowy*ncols + col;
+ const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
- dst[i] = val;
+ dst[ix] = val;
}
- // sum up partial sums
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ // find the sum of exps in the block
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ if (warp_id == 0) {
+ buf[lane_id] = 0.f;
+ }
+ __syncthreads();
+
+ if (lane_id == 0) {
+ buf[warp_id] = tmp;
+ }
+ __syncthreads();
+
+ tmp = buf[lane_id];
+ tmp = warp_reduce_sum(tmp);
}
const float inv_tmp = 1.f / tmp;
for (int col = tid; col < ncols; col += block_size) {
- const int i = row*ncols + col;
+ const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
- const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+ int nth = WARP_SIZE;
+ while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+ const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
- soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+ soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
const int64_t ne00 = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
- soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+ float scale = 1.0f;
+ memcpy(&scale, dst->op_params, sizeof(float));
+
+ soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
- (void) src1;
(void) dst;
- (void) src1_dd;
}
inline void ggml_cuda_op_scale(