author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>   2024-08-20 17:15:47 +0300
committer  GitHub <noreply@github.com>                              2024-08-20 17:15:47 +0300
commit     d259a50ca6fd3a0821abe6a16b73c0b19c5b4651 (patch)
tree       4f83bbbbbbd9323192d8c0bceb51de5b0fb620c2
parent     a325745000114a43c1546323f91720db503ed0a9 (diff)
Fused soft cap and SIMD-ified GeLU (#9)
* Softcap: WIP

Fuses scale + tanh + scale as used for softcapping in some
models (a scalar sketch of the fused operation follows the
commit message).

CPU only for now. Gives a ~1.4% speedup for PP-512 on Gemma2-9b,
no effect on TG.

Somewhat surprisingly, the improvement does not grow as I go to
longer contexts. Gemma2 applies the softcap to K*Q, which grows
quadratically with context length, so I would have expected the
benefit of fusing scale, tanh, scale to increase. But no, no luck.
* softcap: CUDA

~1% speedup for Gemma2-9b
* softcap: Metal and NEON
About 1% speedup.
* SIMD-ified gelu

Gives ~1% speedup for Gemma2-9b prompt processing on AVX512/AVX2.

It looks like the gelu operation is memory bound on my CPUs
after SIMD-ifying it. By not using the 128 kB gelu lookup table
we gain a small advantage.

On the M2-Max the lookup table is slightly faster than the SIMD
version, so the lookup table is kept for ARM_NEON.
* softcap, tanh: avoid NaNs for large arguments (AVX2, AVX512)

Not that I have encountered this in practice, but just to be sure.
This covers AVX512 and AVX2; a guard is still needed for ARM_NEON.
* llama-bench: add ability to turn off warmup runs
So we don't need to wait forever on, e.g., benchmarks involving
long contexts.
* softcap, tanh: avoid NaNs for large arguments (NEON)
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
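A minimal scalar sketch of what the fused soft-cap op computes (the SIMD, CUDA, and Metal kernels in the diff below vectorize exactly this; the function name here is illustrative, not part of the commit):

#include <math.h>

// y = s_after * tanh(s_before * x), done in one pass over the data instead of
// the previous scale -> tanh -> scale chain of three graph ops.
// In the llama.cpp hunks below, s_before = 1/cap and s_after = cap, with cap being
// f_attn_logit_softcapping or f_final_logit_softcapping for Gemma2.
static void softcap_f32_ref(int n, float * x, float s_before, float s_after) {
    for (int i = 0; i < n; ++i) {
        x[i] = s_after * tanhf(s_before * x[i]);
    }
}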
-rw-r--r--  examples/llama-bench/llama-bench.cpp |  24
-rw-r--r--  ggml/include/ggml.h                  |  14
-rw-r--r--  ggml/src/ggml-cuda.cu                |   5
-rw-r--r--  ggml/src/ggml-cuda/softcap.cu        |  32
-rw-r--r--  ggml/src/ggml-cuda/softcap.cuh       |   5
-rw-r--r--  ggml/src/ggml-metal.m                |  32
-rw-r--r--  ggml/src/ggml-metal.metal            |  18
-rw-r--r--  ggml/src/ggml.c                      | 337
-rw-r--r--  src/llama.cpp                        |  20
9 files changed, 437 insertions, 50 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 42918bfc..813d7bae 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -237,6 +237,7 @@ struct cmd_params {
     ggml_numa_strategy numa;
     int reps;
     bool verbose;
+    bool warmup;
     output_formats output_format;
     output_formats output_format_stderr;
 };
@@ -263,6 +264,7 @@ static const cmd_params cmd_params_defaults = {
     /* numa                 */ GGML_NUMA_STRATEGY_DISABLED,
     /* reps                 */ 5,
     /* verbose              */ false,
+    /* warmup               */ true,
     /* output_format        */ MARKDOWN,
     /* output_format_stderr */ NONE,
 };
@@ -295,6 +297,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -o, --output <csv|json|md|sql>            (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
     printf("  -oe, --output-err <csv|json|md|sql>       (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr));
     printf("  -v, --verbose                             (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
+    printf("  -w, --warmup <0|1>                        (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
 }
@@ -338,6 +341,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     params.output_format_stderr = cmd_params_defaults.output_format_stderr;
     params.reps = cmd_params_defaults.reps;
     params.numa = cmd_params_defaults.numa;
+    params.warmup = cmd_params_defaults.warmup;

     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -555,6 +559,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             invalid_param = !output_format_from_str(argv[i], params.output_format_stderr);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
+        } else if (arg == "-w" || arg == "--warmup") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.warmup = std::stoi(argv[i]);
         } else {
             invalid_param = true;
             break;
@@ -1429,12 +1439,14 @@ int main(int argc, char ** argv) {
         llama_kv_cache_clear(ctx);

         // warmup run
-        if (t.n_prompt > 0) {
-            //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
-        }
-        if (t.n_gen > 0) {
-            test_gen(ctx, 1, 0, t.n_threads);
+        if (params.warmup) {
+            if (t.n_prompt > 0) {
+                //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+            }
+            if (t.n_gen > 0) {
+                test_gen(ctx, 1, 0, t.n_threads);
+            }
         }

         for (int i = 0; i < params.reps; i++) {
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 026993db..17d3cb1a 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
+        GGML_OP_SOFTCAP,

         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
@@ -1223,6 +1224,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);

+    GGML_API struct ggml_tensor * ggml_softcap(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s_before,
+            float                 s_after);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_softcap_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s_before,
+            float                 s_after);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index f594cd26..73ab0b73 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -24,6 +24,7 @@
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
@@ -2261,6 +2262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;
+        case GGML_OP_SOFTCAP:
+            ggml_cuda_op_softcap(ctx, dst);
+            break;
         case GGML_OP_SQR:
             ggml_cuda_op_sqr(ctx, dst);
             break;
@@ -2865,6 +2869,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
+        case GGML_OP_SOFTCAP:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
new file mode 100644
index 00000000..499025d1
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -0,0 +1,32 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, float s_before, float s_after, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = s_before*x[i];
+    dst[i] = s_after * tanh(xi);
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, float s_before, float s_after, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, s_before, s_after, k);
+}
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scales[2];
+    memcpy(scales, dst->op_params, sizeof(scales));
+
+    softcap_f32_cuda(src0_d, dst_d, scales[0], scales[1], ggml_nelements(src0), stream);
+}
diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh
new file mode 100644
index 00000000..2b875bfb
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 292f9ac7..1e940c5b 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -51,6 +51,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_SOFTCAP,
+    GGML_METAL_KERNEL_TYPE_SOFTCAP_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
@@ -554,6 +556,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP, softcap, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP_4, softcap_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
@@ -867,6 +871,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_SQR:
         case GGML_OP_SUM_ROWS:
             return true;
+        case GGML_OP_SOFTCAP:
+            return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
@@ -1413,6 +1419,32 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_SOFTCAP:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                float scales[2];
+                memcpy(scales, dst->op_params, sizeof(scales));
+
+                int64_t n = ggml_nelements(dst);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (n % 4 == 0) {
+                    n /= 4;
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
+                [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_CLAMP:
             {
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 904639a5..2a0e84a6 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -289,6 +289,24 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }

+kernel void kernel_softcap(
+        device const float * src0,
+        device       float * dst,
+        constant     float & s_before,
+        constant     float & s_after,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before);
+}
+
+kernel void kernel_softcap_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     float  & s_before,
+        constant     float  & s_after,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before);
+}
+
 kernel void kernel_clamp(
         device const float * src0,
         device       float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e7f1ae61..9b877bab 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2436,44 +2436,14 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }

-static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
+static const float GELU_COEF_A     = 0.044715f;
 static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;

 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }

-inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-    const uint16_t * i16 = (const uint16_t *) x;
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_table_gelu_f16[i16[i]];
-    }
-}
-
-#ifdef GGML_GELU_FP16
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        if (x[i] <= -10.0f) {
-            y[i] = 0.0f;
-        } else if (x[i] >= 10.0f) {
-            y[i] = x[i];
-        } else {
-            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-            memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
-        }
-    }
-}
-#else
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_f32(x[i]);
-    }
-}
-#endif
-
 inline static float ggml_gelu_quick_f32(float x) {
     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
 }
@@ -2555,7 +2525,33 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
     const float32x4_t one = vdupq_n_f32(1.0f);
     const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
     const float32x4_t exp_two_x = ggml_v_expf(two_x);
-    return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+    const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+    const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+    return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
+    //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+}
+
+inline static float32x4_t ggml_v_softcap(float32x4_t x, float32x4_t s_before, float32x4_t s_after) {
+    return vmulq_f32(s_after, ggml_v_tanh(vmulq_f32(x, s_before)));
+    //const float32x4_t one = vdupq_n_f32(1.0f);
+    //const float32x4_t two_x = vmulq_f32(x, s_before);
+    //const float32x4_t exp_two_x = ggml_v_expf(two_x);
+    //const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+    //return vmulq_f32(th, s_after);
+}
+
+// Slower than lookup on my M2-Max
+inline static float32x4_t ggml_v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    //float32x4_t arg = vaddq_f32(one, vmulq_f32(vmulq_f32(x, x), c1));
+    float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
+    arg = vmulq_f32(arg, vmulq_f32(x, c2));
+    float32x4_t exp_arg = ggml_v_expf(arg);
+    float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
+    uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+    return vbslq_f32(mask, x, gelu);
+    //return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(x), mask), vbicq_u32(vreinterpretq_u32_f32(gelu), mask)));
 }

 #elif defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -2604,7 +2600,27 @@ inline static __m512 ggml_v_silu(__m512 x) {
 inline static __m512 ggml_v_tanh(__m512 x) {
     const __m512 one = _mm512_set1_ps(1.0f);
     const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
-    return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+    const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
+    const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+    return _mm512_mask_blend_ps(mask, res, one);
+}
+
+inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
+    const __m512 one = _mm512_set1_ps(1.0f);
+    const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before));
+    const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+    return _mm512_mul_ps(th, s_after);
+}
+
+inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) {
+    const __m512 one = _mm512_set1_ps(1.0f);
+    __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
+    //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
+    arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
+    const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
+    const __m512 exp_arg = ggml_v_expf(arg);
+    const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
+    return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
 }

 #elif defined(__AVX2__) && defined(__FMA__)
@@ -2665,7 +2681,27 @@ inline static __m256 ggml_v_silu(__m256 x) {
 inline static __m256 ggml_v_tanh(__m256 x) {
     const __m256 one = _mm256_set1_ps(1.0f);
     const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
-    return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+    const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+    return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
+}
+
+inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) {
+    return _mm256_mul_ps(_mm256_set1_ps(s_after), ggml_v_tanh(_mm256_mul_ps(x, _mm256_set1_ps(s_before))));
+    //const __m256 one = _mm256_set1_ps(1.0f);
+    //const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
+    //const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+    //return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
+}
+
+inline static __m256 ggml_v_gelu(__m256 x, __m256 c1, __m256 c2) {
+    const __m256 one = _mm256_set1_ps(1.0f);
+    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+    __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
+    arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
+    __m256 exp_arg = ggml_v_expf(arg);
+    __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
+    return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
 }

 #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
@@ -2728,6 +2764,13 @@ inline static __m128 ggml_v_tanh(__m128 x) {
     return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
 }

+inline static __m128 ggml_v_softcap(__m128 x, float s_before, float s_after) {
+    const __m128 one = _mm_set1_ps(1.0f);
+    const __m128 exp_two_x = ggml_v_expf(_mm_mul_ps(x, _mm_set1_ps(2.f*s_before)));
+    const __m128 th = _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
+    return _mm_mul_ps(th, _mm_set1_ps(s_after));
+}
+
 #endif // __ARM_NEON / __AVX2__ / __SSE2__

 static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@@ -2778,6 +2821,108 @@ static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
     }
 }

+static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+    __m512 vs_after  = _mm512_set1_ps(s_after);
+    //for (; i + 63 < n; i += 64) {
+    //    __m512 x1 = _mm512_loadu_ps(x + i);
+    //    __m512 x2 = _mm512_loadu_ps(x + i + 16);
+    //    __m512 x3 = _mm512_loadu_ps(x + i + 32);
+    //    __m512 x4 = _mm512_loadu_ps(x + i + 48);
+    //    _mm512_storeu_ps(x + i +  0, ggml_v_softcap(x1, vs_before, vs_after));
+    //    _mm512_storeu_ps(x + i + 16, ggml_v_softcap(x2, vs_before, vs_after));
+    //    _mm512_storeu_ps(x + i + 32, ggml_v_softcap(x3, vs_before, vs_after));
+    //    _mm512_storeu_ps(x + i + 48, ggml_v_softcap(x4, vs_before, vs_after));
+    //}
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(x + i, ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    float32x4_t vs_before = vdupq_n_f32(s_before);
+    float32x4_t vs_after  = vdupq_n_f32(s_after);
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
+    }
+#endif
+    for (; i < n; ++i) {
+        x[i] = s_after*tanhf(x[i]*s_before);
+    }
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_table_gelu_f16[i16[i]];
+    }
+}
+
+//
+// On my AVX512 (Ryzen-7950X) and AVX2 (Ryzen-5975WX) computing gelu directly
+// via SIMD instructions is faster than the fp16-based lookup table.
+// On my M2-Max CPU the lookup table is slightly faster than the SIMD version,
+// hence we use the SIMD version only if GGML_GELU_FP16 is not defined.
+// We do not run into numerical issues for large or small arguments because
+//     0.5f * (1 + tanhf(arg))
+// is computed as
+//     exp(2*arg)/(exp(2*arg) + 1)
+// The ggml_v_expf functions flushes to zero for large enough negative
+// arguments, so the above becomes zero. ggml_v_expf returns INFINITY
+// for large positive arguments, so we would get a NaN if we did nothing. But in the
+// ggml_v_gelu SIMD implementations we override the gelu result with the
+// input argument when the argument is greater than 10, so it is all good.
+//
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
+    __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2));
+    }
+#elif defined __AVX2__ && defined __FMA__
+    __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
+    __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2));
+    }
+#endif
+#ifdef GGML_GELU_FP16
+    uint16_t t;
+    for (; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        }
+    }
+#else
+#if defined __ARM_NEON
+    float32x4_t c1 = vdupq_n_f32(GELU_COEF_A);
+    float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_gelu(vld1q_f32(x + i), c1, c2));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+#endif
+}
+
 static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
@@ -2968,6 +3113,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
+    "SOFTCAP",

     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
@@ -2995,7 +3141,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3056,6 +3202,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
+    "k2*tanh(k1*x)",

     "flash_attn_ext(x)",
     "flash_attn_back(x)",
@@ -3083,7 +3230,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5742,6 +5889,50 @@ struct ggml_tensor * ggml_scale_inplace(
     return ggml_scale_impl(ctx, a, s, true);
 }

+// ggml_softcap
+
+static struct ggml_tensor * ggml_softcap_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s_before,
+        float                 s_after,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    float params[2] = {s_before, s_after};
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_SOFTCAP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_softcap(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s_before,
+        float                 s_after) {
+    return ggml_softcap_impl(ctx, a, s_before, s_after, false);
+}
+
+struct ggml_tensor * ggml_softcap_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s_before,
+        float                 s_after) {
+    return ggml_softcap_impl(ctx, a, s_before, s_after, true);
+}
+
 // ggml_set

 static struct ggml_tensor * ggml_set_impl(
@@ -13324,6 +13515,71 @@ static void ggml_compute_forward_scale(
     }
 }

+// ggml_compute_forward_softcap
+
+static void ggml_compute_forward_softcap_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // scale factor
+    float val[2];
+    memcpy(val, dst->op_params, sizeof(val));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        if (dst->data != src0->data) {
+            // src0 is same shape as dst => same indices
+            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+        }
+        // TODO: better implementation
+        float * row = (float *) ((char *) dst->data + i1*nb1);
+        ggml_vec_softcap_f32(nc, row, val[0], val[1]);
+        //ggml_vec_scale_f32(nc, row, val[0]);
+        //ggml_vec_tanh_f32(nc, row, row);
+        //ggml_vec_scale_f32(nc, row, val[1]);
+    }
+}
+
+static void ggml_compute_forward_softcap(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_softcap_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_set

 static void ggml_compute_forward_set_f32(
@@ -17175,6 +17431,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_scale(params, tensor);
             } break;
+        case GGML_OP_SOFTCAP:
+            {
+                ggml_compute_forward_softcap(params, tensor);
+            } break;
         case GGML_OP_SET:
             {
                 ggml_compute_forward_set(params, tensor);
@@ -17917,6 +18177,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
                 }
             } break;
+        case GGML_OP_SOFTCAP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SET:
             {
                 const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18928,6 +19192,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = 1; //TODO
             } break;
         case GGML_OP_SCALE:
+        case GGML_OP_SOFTCAP:
         case GGML_OP_SOFT_MAX:
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
diff --git a/src/llama.cpp b/src/llama.cpp
index 17253f7a..4aee41a4 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8317,14 +8317,17 @@ static struct ggml_tensor * llm_build_kqv(
         //try from phi2
         //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

-        kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
-        kq = ggml_scale(ctx, kq, 30);
+        //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+        //kq = ggml_scale(ctx, kq, 30);
+
+        kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
     }

     if (hparams.attn_soft_cap) {
-        kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
-        kq = ggml_tanh(ctx, kq);
-        kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+        //kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+        //kq = ggml_tanh(ctx, kq);
+        //kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
     }

     kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
@@ -11935,9 +11938,10 @@ struct llm_build_context {
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);

         // final logit soft-capping
-        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
-        cur = ggml_tanh(ctx0, cur);
-        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
+        //cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        //cur = ggml_tanh(ctx0, cur);
+        //cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);

         cb(cur, "result_output", -1);
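For reference, the comment block added above ggml_vec_gelu_f32 explains why the SIMD GeLU stays finite for large arguments: 0.5*x*(1 + tanh(arg)) is evaluated as x*exp(2*arg)/(exp(2*arg) + 1), and the result is overridden by x itself once the argument exceeds 10. A small stand-alone check of that algebra (illustrative only, not part of the commit):

#include <math.h>
#include <stdio.h>

int main(void) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const float xs[] = {-4.0f, -1.0f, 0.0f, 0.5f, 3.0f};
    for (int i = 0; i < 5; ++i) {
        const float x    = xs[i];
        const float arg  = SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x);
        const float ref  = 0.5f*x*(1.0f + tanhf(arg)); // scalar ggml_gelu_f32 form
        const float e2a  = expf(2.0f*arg);             // the kernels pass c2 = 2*SQRT_2_OVER_PI
        const float simd = x*e2a/(e2a + 1.0f);         // exp-based form used by ggml_v_gelu
        printf("x = %5.2f  reference = %9.6f  exp form = %9.6f\n", x, ref, simd);
    }
    return 0;
}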