Diffstat (limited to 'ggml')
-rw-r--r--  ggml/include/ggml.h           13
-rw-r--r--  ggml/src/ggml-cuda.cu          4
-rw-r--r--  ggml/src/ggml-cuda/norm.cu    75
-rw-r--r--  ggml/src/ggml-cuda/norm.cuh    2
-rw-r--r--  ggml/src/ggml-metal.m         35
-rw-r--r--  ggml/src/ggml-metal.metal     51
-rw-r--r--  ggml/src/ggml.c              138
7 files changed, 316 insertions, 2 deletions
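
The patch introduces a new GGML_OP_FUSED_RMS_NORM operator that folds the usual rms_norm + multiply-by-norm-weights pair into a single graph node, with forward implementations for the CPU, CUDA and Metal backends. Below is a minimal caller-side sketch of how the new API could be used when building a graph; the function and tensor names (build_norm, cur, norm_weight) are illustrative and not part of this patch:

#include "ggml.h"

// Build one normalization node. Before this patch a caller would emit two nodes:
//     cur = ggml_mul(ctx, ggml_rms_norm(ctx, cur, eps), norm_weight);
// With the fused op this becomes a single node; ggml_fused_rms_norm() itself
// falls back to rms_norm + mul when norm_weight has more than one row or a
// mismatched ne[0].
static struct ggml_tensor * build_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,          // activations, ne[0] = model dim
        struct ggml_tensor  * norm_weight,  // single row of ne[0] scale factors
        float                 eps) {
    return ggml_fused_rms_norm(ctx, cur, norm_weight, eps);
}
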
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 1a4a516c..ab6d172d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -480,6 +480,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_FUSED_RMS_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -1159,6 +1160,18 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 eps);
 
+    GGML_API struct ggml_tensor * ggml_fused_rms_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 eps);
+
     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_group_norm(
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 056ca4a4..cf053559 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2248,6 +2248,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RMS_NORM:
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
+        case GGML_OP_FUSED_RMS_NORM:
+            ggml_cuda_op_fused_rms_norm(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2871,6 +2874,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SOFTCAP:
         case GGML_OP_SQR:
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 133e219f..7e670912 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 }
 
+template <int block_size>
+static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row*ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
+    }
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     }
 }
 
+static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
+        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
 }
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    if (!dst->src[1]) {
+        ggml_cuda_op_rms_norm(ctx, dst);
+        return;
+    }
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+}
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 431a8f74..e4f9ee82 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 83bd76f9..b3f6e60c 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -104,6 +104,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,
    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
    GGML_METAL_KERNEL_TYPE_NORM,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -613,6 +614,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K,   get_rows_iq6_k,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,     get_rows_i32,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,         rms_norm,         ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,   fused_rms_norm,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,       group_norm,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,             norm,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,   mul_mv_f32_f32,   ctx->support_simdgroup_reduction);
@@ -884,6 +886,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
@@ -2608,6 +2611,38 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
+            case GGML_OP_FUSED_RMS_NORM:
+                {
+                    GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));
+                    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+                    float eps;
+                    memcpy(&eps, dst->op_params, sizeof(float));
+
+                    int nth = 32; // SIMD width
+
+                    while (nth < ne00/4 && nth < 1024) {
+                        nth *= 2;
+                    }
+
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:5];
+                    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                    const int64_t nrows = ggml_nrows(src0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
             case GGML_OP_GROUP_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index f9c88a37..d7af1800 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1038,6 +1038,57 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_fused_rms_norm(
+        device const void  * src0,
+        device const void  * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+
+    float4 sumf = 0;
+    float all_sum = 0;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    device float4 * z = (device float4 *)src1;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        y[i00] = x[i00] * z[i00] * scale;
+    }
+}
+
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 45fddca5..d562002e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3144,6 +3144,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM",
     "RMS_NORM_BACK",
     "GROUP_NORM",
+    "FUSED_RMS_NORM",
 
     "MUL_MAT",
     "MUL_MAT_ID",
@@ -3207,7 +3208,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3234,6 +3235,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
     "group_norm(x)",
+    "fused_rms_norm(x)",
 
     "X*Y",
     "X[i]*Y",
@@ -3297,7 +3299,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5737,6 +5739,57 @@ struct ggml_tensor * ggml_rms_norm_inplace(
     return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
+static struct ggml_tensor * ggml_fused_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps,
+        bool                  inplace) {
+
+    if (!b) {
+        return ggml_rms_norm_impl(ctx, a, eps, inplace);
+    }
+
+    if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
+        struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace);
+        result = ggml_mul_impl(ctx, result, b, inplace);
+        return result;
+    }
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
+    result->op     = GGML_OP_FUSED_RMS_NORM;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_fused_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps) {
+    return ggml_fused_rms_norm_impl(ctx, a, b, eps, false);
+}
+
+struct ggml_tensor * ggml_fused_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps) {
+    return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
+}
+
 // ggml_rms_norm_back
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -12455,6 +12508,78 @@ static void ggml_compute_forward_rms_norm(
     }
 }
 
+static void ggml_compute_forward_fused_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (!src1) {
+        ggml_compute_forward_rms_norm_f32(params, dst);
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                const float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                const float scale = 1.0f/sqrtf(mean + eps);
+
+                ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
+                ggml_vec_scale_f32(ne00, y, scale);
+
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_fused_rms_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_fused_rms_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -17708,6 +17833,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rms_norm(params, tensor);
             } break;
+        case GGML_OP_FUSED_RMS_NORM:
+            {
+                ggml_compute_forward_fused_rms_norm(params, tensor);
+            } break;
         case GGML_OP_RMS_NORM_BACK:
             {
                 ggml_compute_forward_rms_norm_back(params, tensor);
@@ -18398,6 +18527,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
             }
         } break;
+        case GGML_OP_FUSED_RMS_NORM:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_OP_RMS_NORM_BACK:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
@@ -19465,6 +19598,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DIV:
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_FUSED_RMS_NORM:
        case GGML_OP_RMS_NORM_BACK:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_CONCAT:
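
For reference, this is a scalar sketch of the math the new forward paths implement on all three backends (dst = rms_norm(src0) * src1, row by row), which can serve as a CPU check when testing; the helper name and the assumption of contiguous row-major f32 data are illustrative only:

#include <math.h>
#include <stddef.h>

// Reference: for each row r, dst[r][c] = src0[r][c] * src1[c] / sqrt(mean(src0[r]^2) + eps)
static void fused_rms_norm_ref(const float * src0, const float * src1, float * dst,
                               size_t nrows, size_t ncols, float eps) {
    for (size_t r = 0; r < nrows; ++r) {
        const float * x = src0 + r*ncols;
        float       * y = dst  + r*ncols;
        double sum = 0.0;                      // accumulate in double, like ggml_float
        for (size_t c = 0; c < ncols; ++c) {
            sum += (double)x[c]*x[c];
        }
        const float scale = 1.0f/sqrtf((float)(sum/ncols) + eps);
        for (size_t c = 0; c < ncols; ++c) {
            y[c] = scale * src1[c] * x[c];     // normalization and weight multiply, fused
        }
    }
}
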