summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJiahao Li <liplus17@163.com>2023-09-04 14:53:30 +0800
committerGitHub <noreply@github.com>2023-09-04 08:53:30 +0200
commit35195689cd835464779c247b1c22ab9247418fd1 (patch)
tree108330cdd14f1e9a3aad22d938113e3dfb3ba027
parentcf9b08485c4c2d4d945c6e74fe20f273a38b6104 (diff)
2x faster (rms) norm cuda kernels (3.7% e2e improvement) (#2985)
* 2x faster (rms) norm cuda kernels * Fix code style
-rw-r--r--ggml-cuda.cu89
1 files changed, 66 insertions, 23 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 8357f32f..d2dbf824 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -464,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+ }
+ return a;
+}
+
+template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const float eps = 1e-5f;
- float mean = 0.0f;
- float var = 0.0f;
+ float2 mean_var = make_float2(0.f, 0.f);
- for (int col = tid; col < ncols; col += WARP_SIZE) {
+ for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
- mean += xi;
- var += xi * xi;
+ mean_var.x += xi;
+ mean_var.y += xi * xi;
}
// sum up partial sums
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
- var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+ mean_var = warp_reduce_sum(mean_var);
+ if (block_size > WARP_SIZE) {
+ __shared__ float2 s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = mean_var;
+ }
+ __syncthreads();
+ mean_var = s_sum[lane_id];
+ mean_var = warp_reduce_sum(mean_var);
}
- mean /= ncols;
- var = var / ncols - mean * mean;
- const float inv_var = rsqrtf(var + eps);
+ const float mean = mean_var.x / ncols;
+ const float var = mean_var.y / ncols - mean * mean;
+ const float inv_std = rsqrtf(var + eps);
- for (int col = tid; col < ncols; col += WARP_SIZE) {
- dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
}
}
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+ }
+ return x;
+}
+
+template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp
- for (int col = tid; col < ncols; col += WARP_SIZE) {
+ for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
// sum up partial sums
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
- for (int col = tid; col < ncols; col += WARP_SIZE) {
+ for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}
@@ -4203,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
- const dim3 block_dims(WARP_SIZE, 1, 1);
- norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ }
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ }
}
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {