summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-12-01 10:51:24 +0200
committerGitHub <noreply@github.com>2023-12-01 10:51:24 +0200
commitef47ec18da469423c276b683dd9b5741cee7023e (patch)
treeec3b4780dbe8f629425de499b298e8eadfd1aa4d
parent1d144112c0fbbb4ecc07dbcf4f05a380148bd6de (diff)
ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext * cuda : implement soft_max_ext * ggml : implement soft_max_ext (CPU) * batched-bench : print threads ggml-ci * metal : simplify soft_max encoding ggml-ci * cuda : use 512 threads for soft_max instead of 32 * ggml : update soft max cpu * cuda : do warp-based block reduce * cuda : increase max block size to 1024 * cuda : fix warp reduction initialization of shared mem * metal : warp-based reduction for soft max kernel * metal : warp-based reduce for rms_norm * metal : simplify soft max kernel ggml-ci * alloc : fix build with debug
-rw-r--r--examples/batched-bench/batched-bench.cpp2
-rw-r--r--ggml-alloc.c2
-rw-r--r--ggml-cuda.cu130
-rw-r--r--ggml-metal.m43
-rw-r--r--ggml-metal.metal186
-rw-r--r--ggml.c77
-rw-r--r--ggml.h8
-rw-r--r--llama.cpp33
8 files changed, 298 insertions, 183 deletions
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 533c55c1..57596ed9 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
}
LOG_TEE("\n");
- LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
+ LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
diff --git a/ggml-alloc.c b/ggml-alloc.c
index cdfe4caf..0d4e12ae 100644
--- a/ggml-alloc.c
+++ b/ggml-alloc.c
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
- size_t cur_max = (char*)addr - (char*)alloc->data + size;
+ size_t cur_max = (char*)addr - (char*)alloc->base + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5b80e4ae..9019a849 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+ }
+ return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+ }
+ return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+ }
+ return x;
+}
+
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
- a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
- }
- return a;
-}
-
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- x += __shfl_xor_sync(0xffffffff, x, mask, 32);
- }
- return x;
-}
-
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
- const int block_size = blockDim.y;
- const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+ const int tid = threadIdx.x;
+ const int rowx = blockIdx.x;
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+ const int block_size = blockDim.x;
+
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+
+ __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
float max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) {
- const int i = row*ncols + col;
- max_val = max(max_val, x[i]);
+ const int ix = rowx*ncols + col;
+ const int iy = rowy*ncols + col;
+ max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
}
// find the max value in the block
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+ max_val = warp_reduce_max(max_val);
+ if (block_size > WARP_SIZE) {
+ if (warp_id == 0) {
+ buf[lane_id] = -INFINITY;
+ }
+ __syncthreads();
+
+ if (lane_id == 0) {
+ buf[warp_id] = max_val;
+ }
+ __syncthreads();
+
+ max_val = buf[lane_id];
+ max_val = warp_reduce_max(max_val);
}
float tmp = 0.f;
for (int col = tid; col < ncols; col += block_size) {
- const int i = row*ncols + col;
- const float val = expf(x[i] - max_val);
+ const int ix = rowx*ncols + col;
+ const int iy = rowy*ncols + col;
+ const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
- dst[i] = val;
+ dst[ix] = val;
}
- // sum up partial sums
-#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ // find the sum of exps in the block
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ if (warp_id == 0) {
+ buf[lane_id] = 0.f;
+ }
+ __syncthreads();
+
+ if (lane_id == 0) {
+ buf[warp_id] = tmp;
+ }
+ __syncthreads();
+
+ tmp = buf[lane_id];
+ tmp = warp_reduce_sum(tmp);
}
const float inv_tmp = 1.f / tmp;
for (int col = tid; col < ncols; col += block_size) {
- const int i = row*ncols + col;
+ const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
- const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+ int nth = WARP_SIZE;
+ while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+ const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
- soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+ soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
const int64_t ne00 = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
- soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+ float scale = 1.0f;
+ memcpy(&scale, dst->op_params, sizeof(float));
+
+ soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
- (void) src1;
(void) dst;
- (void) src1_dd;
}
inline void ggml_cuda_op_scale(
diff --git a/ggml-metal.m b/ggml-metal.m
index d52a1c3c..6cfacf64 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
int nth = 32; // SIMD width
if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth < 256) {
+ nth *= 2;
+ }
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else {
- do {
+ while (nth < ne00 && nth < 1024) {
nth *= 2;
- } while (nth <= ne00 && nth <= 1024);
- nth /= 2;
+ }
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+ const float scale = ((float *) dst->op_params)[0];
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- const int nth = MIN(512, ne00);
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 5d1357cd..9a79f815 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -39,6 +39,8 @@ typedef struct {
int8_t qs[QK8_0]; // quants
} block_q8_0;
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
// general-purpose kernel for addition of two tensors
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
// cons: not very efficient
@@ -180,10 +182,12 @@ kernel void kernel_gelu(
kernel void kernel_soft_max(
device const float * src0,
+ device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+ constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +198,77 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+ float lmax = -INFINITY;
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
- lmax = MAX(lmax, psrc0[i00]);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}
- float max = simd_max(lmax);
- if (tiisg == 0) {
- buf[sgitg] = max;
- }
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
- }
- }
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- max = buf[0];
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- const float exp_psrc0 = exp(psrc0[i00] - max);
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
- // Remember the result of exp here. exp is expensive, so we really do not
- // wish to compute it twice.
pdst[i00] = exp_psrc0;
}
float sum = simd_sum(lsum);
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] += buf[tpitg + i];
- }
- }
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
- sum = buf[0];
+ const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- pdst[i00] /= sum;
+ pdst[i00] *= inv_sum;
}
}
kernel void kernel_soft_max_4(
device const float * src0,
+ device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+ constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +279,68 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+ float4 lmax4 = -INFINITY;
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]);
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
- float max = simd_max(lmax);
- if (tiisg == 0) {
- buf[sgitg] = max;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
- }
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- max = buf[0];
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
float sum = simd_sum(lsum);
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] += buf[tpitg + i];
- }
- }
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
- sum = buf[0];
+ const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- pdst4[i00] /= sum;
+ pdst4[i00] *= inv_sum;
}
}
@@ -435,14 +447,13 @@ kernel void kernel_rms_norm(
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
- threadgroup float * sum [[threadgroup(0)]],
+ threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
- device const float * x_scalar = (device const float *) x;
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
float4 sumf = 0;
float all_sum = 0;
@@ -453,40 +464,30 @@ kernel void kernel_rms_norm(
}
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
all_sum = simd_sum(all_sum);
- if (tiisg == 0) {
- sum[sgitg] = all_sum;
- }
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- }
- if (tpitg == 0) {
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
- sum[0] += x_scalar[i];
+ if (tiisg == 0) {
+ buf[sgitg] = all_sum;
}
- sum[0] /= ne00;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- const float mean = sum[0];
+ all_sum = buf[tiisg];
+ all_sum = simd_sum(all_sum);
+ }
+
+ const float mean = all_sum/ne00;
const float scale = 1.0f/sqrt(mean + eps);
device float4 * y = (device float4 *) (dst + tgpig*ne00);
- device float * y_scalar = (device float *) y;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
y[i00] = x[i00] * scale;
}
- if (tpitg == 0) {
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
- y_scalar[i00] = x_scalar[i00] * scale;
- }
- }
}
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,7 +577,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
//Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// giard against the number of rows not being divisible by
diff --git a/ggml.c b/ggml.c
index c522a101..e2687ef4 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
bool inplace) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ if (mask) {
+ GGML_ASSERT(ggml_is_contiguous(mask));
+ GGML_ASSERT(mask->ne[2] == 1);
+ GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+ }
+
bool is_node = false;
if (a->grad) {
@@ -4835,9 +4845,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ float params[] = { scale };
+ ggml_set_op_params(result, params, sizeof(params));
+
result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
+ result->src[1] = mask;
return result;
}
@@ -4845,13 +4859,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, false);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
}
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, true);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale) {
+ return ggml_soft_max_impl(ctx, a, mask, scale, false);
}
// ggml_soft_max_back
@@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
- struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(ggml_is_contiguous(dst));
+ assert(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
+ float scale = 1.0f;
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
const int nth = params->nth;
+ const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
@@ -10575,29 +10602,40 @@ static void ggml_compute_forward_soft_max_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
for (int i1 = ir0; i1 < ir1; i1++) {
- float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+ // broadcast the mask across rows
+ float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+ ggml_vec_cpy_f32 (nc, wp, sp);
+ ggml_vec_scale_f32(nc, wp, scale);
+ if (mp) {
+ ggml_vec_acc_f32(nc, wp, mp);
+ }
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(sp[i]));
+ assert(!isnan(wp[i]));
}
#endif
float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, sp);
+ ggml_vec_max_f32(nc, &max, wp);
ggml_float sum = 0.0;
uint16_t scvt;
for (int i = 0; i < nc; i++) {
- if (sp[i] == -INFINITY) {
+ if (wp[i] == -INFINITY) {
dp[i] = 0.0f;
} else {
- // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
- ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+ // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
sum += (ggml_float)val;
@@ -10622,11 +10660,12 @@ static void ggml_compute_forward_soft_max_f32(
static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
- struct ggml_tensor * dst) {
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
- ggml_compute_forward_soft_max_f32(params, src0, dst);
+ ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
} break;
default:
{
@@ -13863,7 +13902,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX:
{
- ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
@@ -15899,6 +15938,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
+ case GGML_OP_SOFT_MAX:
+ {
+ n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+
+ cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+ } break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
GGML_ASSERT(node->src[0]->ne[3] == 1);
diff --git a/ggml.h b/ggml.h
index 4d6d4edf..2f6787d4 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1282,6 +1282,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
+ // fused soft_max(a*scale + mask)
+ // mask is optional
+ GGML_API struct ggml_tensor * ggml_soft_max_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale);
+
GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
diff --git a/llama.cpp b/llama.cpp
index 1e00ea4a..e74fd723 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3704,22 +3704,28 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- kq = ggml_scale(ctx, kq, kq_scale);
- cb(kq, "kq_scaled", il);
-
if (max_alibi_bias > 0.0f) {
- // TODO: n_head or n_head_kv
- // TODO: K-shift is likely not working
- // TODO: change to ggml_add
- kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
- cb(kq, "kq_scaled_alibi", il);
- }
+ // temporary branch until we figure out how to handle ggml_alibi through ggml_add
+ kq = ggml_scale(ctx, kq, kq_scale);
+ cb(kq, "kq_scaled", il);
+
+ if (max_alibi_bias > 0.0f) {
+ // TODO: n_head or n_head_kv
+ // TODO: K-shift is likely not working
+ // TODO: change to ggml_add
+ kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
+ cb(kq, "kq_scaled_alibi", il);
+ }
- kq = ggml_add(ctx, kq, kq_mask);
- cb(kq, "kq_masked", il);
+ kq = ggml_add(ctx, kq, kq_mask);
+ cb(kq, "kq_masked", il);
- kq = ggml_soft_max(ctx, kq);
- cb(kq, "kq_soft_max", il);
+ kq = ggml_soft_max(ctx, kq);
+ cb(kq, "kq_soft_max", il);
+ } else {
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
+ cb(kq, "kq_soft_max_ext", il);
+ }
// split cached v into n_head heads
struct ggml_tensor * v =
@@ -5041,6 +5047,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
{ "kq_masked", OFFLOAD_FUNC_KQ },
{ "kq_soft_max", OFFLOAD_FUNC_V },
+ { "kq_soft_max_ext", OFFLOAD_FUNC_V },
{ "v", OFFLOAD_FUNC_V },
{ "kqv", OFFLOAD_FUNC_V },
{ "kqv_merged", OFFLOAD_FUNC_V },