summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-08-20 17:15:47 +0300
committerGitHub <noreply@github.com>2024-08-20 17:15:47 +0300
commitd259a50ca6fd3a0821abe6a16b73c0b19c5b4651 (patch)
tree4f83bbbbbbd9323192d8c0bceb51de5b0fb620c2
parenta325745000114a43c1546323f91720db503ed0a9 (diff)
Fused soft cap and SIMD-ified GeLU (#9)
* Softcap: WIP Fuses scale + tanh + scale as used for softcaping in some models. Just CPU for now. ~1.4% for PP-512 on Gemma2-9b, no effect on TG. Somewhat surprisingly the improvement does not increase as I go to longer contexts. Gemma2 does softcap on K*Q, which grows quadratically with context length, so I would have thought the benefit from fusing scale, tanh, scale would increase. But no, no luck. * softcap: CUDA * softcap: CUDA ~1% speedup for Gemma2-9b * softcap: Metal and NEON About 1% speedup. * Simdified gelu Gives ~1% speedup for Gemma2-9b prompt processing on AVX512/AVX2. It looks like the gelu operation is memory bound on my CPU's after SIMD-ifying it. By not using the 128 kb gelu lookup table we gain a small advantage. On the M2-Max the lookup table is slightly faster than the SIMD version, so left the lookup table for ARM_NEON. * softcap, tanh: avoid NaNs for large arguments (AVX2, AVX512) Not that I have encountered this in practice, but just to be sure. This does it for AVX512 and AVX2, still need a guard for ARM_NEON. * llama-bench: add ability to turn off warmup runs So we don't need to wait forever on, e.g., benchmarks involving long contexts. * softcap, tanh: avoid NaNs for large arguments (NEON) --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--examples/llama-bench/llama-bench.cpp24
-rw-r--r--ggml/include/ggml.h14
-rw-r--r--ggml/src/ggml-cuda.cu5
-rw-r--r--ggml/src/ggml-cuda/softcap.cu32
-rw-r--r--ggml/src/ggml-cuda/softcap.cuh5
-rw-r--r--ggml/src/ggml-metal.m32
-rw-r--r--ggml/src/ggml-metal.metal18
-rw-r--r--ggml/src/ggml.c337
-rw-r--r--src/llama.cpp20
9 files changed, 437 insertions, 50 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 42918bfc..813d7bae 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -237,6 +237,7 @@ struct cmd_params {
ggml_numa_strategy numa;
int reps;
bool verbose;
+ bool warmup;
output_formats output_format;
output_formats output_format_stderr;
};
@@ -263,6 +264,7 @@ static const cmd_params cmd_params_defaults = {
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
/* reps */ 5,
/* verbose */ false,
+ /* warmup */ true,
/* output_format */ MARKDOWN,
/* output_format_stderr */ NONE,
};
@@ -295,6 +297,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
printf(" -oe, --output-err <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr));
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
+ printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
}
@@ -338,6 +341,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
params.output_format_stderr = cmd_params_defaults.output_format_stderr;
params.reps = cmd_params_defaults.reps;
params.numa = cmd_params_defaults.numa;
+ params.warmup = cmd_params_defaults.warmup;
for (int i = 1; i < argc; i++) {
arg = argv[i];
@@ -555,6 +559,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
invalid_param = !output_format_from_str(argv[i], params.output_format_stderr);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
+ } else if (arg == "-w" || arg == "--warmup") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.warmup = std::stoi(argv[i]);
} else {
invalid_param = true;
break;
@@ -1429,12 +1439,14 @@ int main(int argc, char ** argv) {
llama_kv_cache_clear(ctx);
// warmup run
- if (t.n_prompt > 0) {
- //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
- test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
- }
- if (t.n_gen > 0) {
- test_gen(ctx, 1, 0, t.n_threads);
+ if (params.warmup) {
+ if (t.n_prompt > 0) {
+ //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+ test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+ }
+ if (t.n_gen > 0) {
+ test_gen(ctx, 1, 0, t.n_threads);
+ }
}
for (int i = 0; i < params.reps; i++) {
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 026993db..17d3cb1a 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
+ GGML_OP_SOFTCAP,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
@@ -1223,6 +1224,19 @@ extern "C" {
struct ggml_tensor * a,
float s);
+ GGML_API struct ggml_tensor * ggml_softcap(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s_before,
+ float s_after);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_softcap_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s_before,
+ float s_after);
+
// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index f594cd26..73ab0b73 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -24,6 +24,7 @@
#include "ggml-cuda/quantize.cuh"
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/tsembd.cuh"
@@ -2261,6 +2262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SCALE:
ggml_cuda_op_scale(ctx, dst);
break;
+ case GGML_OP_SOFTCAP:
+ ggml_cuda_op_softcap(ctx, dst);
+ break;
case GGML_OP_SQR:
ggml_cuda_op_sqr(ctx, dst);
break;
@@ -2865,6 +2869,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
+ case GGML_OP_SOFTCAP:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_CLAMP:
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
new file mode 100644
index 00000000..499025d1
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -0,0 +1,32 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, float s_before, float s_after, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ float xi = s_before*x[i];
+ dst[i] = s_after * tanh(xi);
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, float s_before, float s_after, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+ softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, s_before, s_after, k);
+}
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float scales[2];
+ memcpy(scales, dst->op_params, sizeof(scales));
+
+ softcap_f32_cuda(src0_d, dst_d, scales[0], scales[1], ggml_nelements(src0), stream);
+}
diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh
new file mode 100644
index 00000000..2b875bfb
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 292f9ac7..1e940c5b 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -51,6 +51,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
GGML_METAL_KERNEL_TYPE_SCALE,
GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_SOFTCAP,
+ GGML_METAL_KERNEL_TYPE_SOFTCAP_4,
GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
@@ -554,6 +556,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP, softcap, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP_4, softcap_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
@@ -867,6 +871,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
return true;
+ case GGML_OP_SOFTCAP:
+ return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
@@ -1413,6 +1419,32 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_OP_SOFTCAP:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ float scales[2];
+ memcpy(scales, dst->op_params, sizeof(scales));
+
+ int64_t n = ggml_nelements(dst);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
+ [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_OP_CLAMP:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 904639a5..2a0e84a6 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -289,6 +289,24 @@ kernel void kernel_scale_4(
dst[tpig] = src0[tpig] * scale;
}
+kernel void kernel_softcap(
+ device const float * src0,
+ device float * dst,
+ constant float & s_before,
+ constant float & s_after,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before);
+}
+
+kernel void kernel_softcap_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant float & s_before,
+ constant float & s_after,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before);
+}
+
kernel void kernel_clamp(
device const float * src0,
device float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e7f1ae61..9b877bab 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2436,44 +2436,14 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
+static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
inline static float ggml_gelu_f32(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
-inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- const uint16_t * i16 = (const uint16_t *) x;
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_table_gelu_f16[i16[i]];
- }
-}
-
-#ifdef GGML_GELU_FP16
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
- uint16_t t;
- for (int i = 0; i < n; ++i) {
- if (x[i] <= -10.0f) {
- y[i] = 0.0f;
- } else if (x[i] >= 10.0f) {
- y[i] = x[i];
- } else {
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
- memcpy(&t, &fp16, sizeof(uint16_t));
- y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
- }
- }
-}
-#else
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_gelu_f32(x[i]);
- }
-}
-#endif
-
inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}
@@ -2555,7 +2525,33 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
const float32x4_t exp_two_x = ggml_v_expf(two_x);
- return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+ const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+ const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+ return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
+ //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+}
+
+inline static float32x4_t ggml_v_softcap(float32x4_t x, float32x4_t s_before, float32x4_t s_after) {
+ return vmulq_f32(s_after, ggml_v_tanh(vmulq_f32(x, s_before)));
+ //const float32x4_t one = vdupq_n_f32(1.0f);
+ //const float32x4_t two_x = vmulq_f32(x, s_before);
+ //const float32x4_t exp_two_x = ggml_v_expf(two_x);
+ //const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+ //return vmulq_f32(th, s_after);
+}
+
+
+// Slower than lookup on my M2-Max
+inline static float32x4_t ggml_v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ //float32x4_t arg = vaddq_f32(one, vmulq_f32(vmulq_f32(x, x), c1));
+ float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
+ arg = vmulq_f32(arg, vmulq_f32(x, c2));
+ float32x4_t exp_arg = ggml_v_expf(arg);
+ float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
+ uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+ return vbslq_f32(mask, x, gelu);
+ //return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(x), mask), vbicq_u32(vreinterpretq_u32_f32(gelu), mask)));
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -2604,7 +2600,27 @@ inline static __m512 ggml_v_silu(__m512 x) {
inline static __m512 ggml_v_tanh(__m512 x) {
const __m512 one = _mm512_set1_ps(1.0f);
const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
- return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+ const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
+ const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+ return _mm512_mask_blend_ps(mask, res, one);
+}
+
+inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before));
+ const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+ return _mm512_mul_ps(th, s_after);
+}
+
+inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
+ //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
+ arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
+ const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
+ const __m512 exp_arg = ggml_v_expf(arg);
+ const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
+ return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
}
#elif defined(__AVX2__) && defined(__FMA__)
@@ -2665,7 +2681,27 @@ inline static __m256 ggml_v_silu(__m256 x) {
inline static __m256 ggml_v_tanh(__m256 x) {
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
- return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+ const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+ const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+ return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
+}
+
+inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) {
+ return _mm256_mul_ps(_mm256_set1_ps(s_after), ggml_v_tanh(_mm256_mul_ps(x, _mm256_set1_ps(s_before))));
+ //const __m256 one = _mm256_set1_ps(1.0f);
+ //const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
+ //const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+ //return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
+}
+
+inline static __m256 ggml_v_gelu(__m256 x, __m256 c1, __m256 c2) {
+ const __m256 one = _mm256_set1_ps(1.0f);
+ const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+ __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
+ arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
+ __m256 exp_arg = ggml_v_expf(arg);
+ __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
+ return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
}
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
@@ -2728,6 +2764,13 @@ inline static __m128 ggml_v_tanh(__m128 x) {
return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
}
+inline static __m128 ggml_v_softcap(__m128 x, float s_before, float s_after) {
+ const __m128 one = _mm_set1_ps(1.0f);
+ const __m128 exp_two_x = ggml_v_expf(_mm_mul_ps(x, _mm_set1_ps(2.f*s_before)));
+ const __m128 th = _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
+ return _mm_mul_ps(th, _mm_set1_ps(s_after));
+}
+
#endif // __ARM_NEON / __AVX2__ / __SSE2__
static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@@ -2778,6 +2821,108 @@ static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
}
}
+static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+ __m512 vs_after = _mm512_set1_ps(s_after);
+ //for (; i + 63 < n; i += 64) {
+ // __m512 x1 = _mm512_loadu_ps(x + i);
+ // __m512 x2 = _mm512_loadu_ps(x + i + 16);
+ // __m512 x3 = _mm512_loadu_ps(x + i + 32);
+ // __m512 x4 = _mm512_loadu_ps(x + i + 48);
+ // _mm512_storeu_ps(x + i + 0, ggml_v_softcap(x1, vs_before, vs_after));
+ // _mm512_storeu_ps(x + i + 16, ggml_v_softcap(x2, vs_before, vs_after));
+ // _mm512_storeu_ps(x + i + 32, ggml_v_softcap(x3, vs_before, vs_after));
+ // _mm512_storeu_ps(x + i + 48, ggml_v_softcap(x4, vs_before, vs_after));
+ //}
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(x + i, ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after));
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ float32x4_t vs_before = vdupq_n_f32(s_before);
+ float32x4_t vs_after = vdupq_n_f32(s_after);
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
+ }
+#endif
+ for (; i < n; ++i) {
+ x[i] = s_after*tanhf(x[i]*s_before);
+ }
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+ const uint16_t * i16 = (const uint16_t *) x;
+ for (int i = 0; i < n; ++i) {
+ y[i] = ggml_table_gelu_f16[i16[i]];
+ }
+}
+
+//
+// On my AVX512 (Ryzen-7950X) and AVX2 (Ryzen-5975WX) computing gelu directly
+// via SIMD instructions is faster than the fp16-based lookup table.
+// On my M2-Max CPU the lookup table is slightly faster than the SIMD version,
+// hence we use the SIMD version only if GGML_GELU_FP16 is not defined.
+// We do not run into numerical issues for large or small arguments because
+// 0.5f * (1 + tanhf(arg))
+// is computed as
+// exp(2*arg)/(exp(2*arg) + 1)
+// The ggml_v_expf functions flushes to zero for large enough negative
+// arguments, so the above becomes zero. ggml_v_expf returns INFINITY
+// for large positive arguments, so we would get a NaN if we did nothing. But in the
+// ggml_v_gelu SIMD implementations we override the gelu result with the
+// input argument when the argument is greater than 10, so it is all good.
+//
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
+ __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(y + i, ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2));
+ }
+#elif defined __AVX2__ && defined __FMA__
+ __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
+ __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(y + i, ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2));
+ }
+#endif
+#ifdef GGML_GELU_FP16
+ uint16_t t;
+ for (; i < n; ++i) {
+ if (x[i] <= -10.0f) {
+ y[i] = 0.0f;
+ } else if (x[i] >= 10.0f) {
+ y[i] = x[i];
+ } else {
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+ memcpy(&t, &fp16, sizeof(uint16_t));
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+ }
+ }
+#else
+#if defined __ARM_NEON
+ float32x4_t c1 = vdupq_n_f32(GELU_COEF_A);
+ float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(y + i, ggml_v_gelu(vld1q_f32(x + i), c1, c2));
+ }
+#endif
+ for (; i < n; ++i) {
+ y[i] = ggml_gelu_f32(x[i]);
+ }
+#endif
+}
+
static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
ggml_float sum = 0;
@@ -2968,6 +3113,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",
+ "SOFTCAP",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
@@ -2995,7 +3141,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3056,6 +3202,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",
+ "k2*tanh(k1*x)",
"flash_attn_ext(x)",
"flash_attn_back(x)",
@@ -3083,7 +3230,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5742,6 +5889,50 @@ struct ggml_tensor * ggml_scale_inplace(
return ggml_scale_impl(ctx, a, s, true);
}
+// ggml_softcap
+
+static struct ggml_tensor * ggml_softcap_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s_before,
+ float s_after,
+ bool inplace) {
+ GGML_ASSERT(ggml_is_padded_1d(a));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ float params[2] = {s_before, s_after};
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_SOFTCAP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_softcap(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s_before,
+ float s_after) {
+ return ggml_softcap_impl(ctx, a, s_before, s_after, false);
+}
+
+struct ggml_tensor * ggml_softcap_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s_before,
+ float s_after) {
+ return ggml_softcap_impl(ctx, a, s_before, s_after, true);
+}
+
// ggml_set
static struct ggml_tensor * ggml_set_impl(
@@ -13324,6 +13515,71 @@ static void ggml_compute_forward_scale(
}
}
+// ggml_compute_forward_softcap
+
+static void ggml_compute_forward_softcap_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ // scale factor
+ float val[2];
+ memcpy(val, dst->op_params, sizeof(val));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ const size_t nb01 = src0->nb[1];
+
+ const size_t nb1 = dst->nb[1];
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ if (dst->data != src0->data) {
+ // src0 is same shape as dst => same indices
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+ }
+ // TODO: better implementation
+ float * row = (float *) ((char *) dst->data + i1*nb1);
+ ggml_vec_softcap_f32(nc, row, val[0], val[1]);
+ //ggml_vec_scale_f32(nc, row, val[0]);
+ //ggml_vec_tanh_f32(nc, row, row);
+ //ggml_vec_scale_f32(nc, row, val[1]);
+ }
+}
+
+static void ggml_compute_forward_softcap(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_softcap_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_set
static void ggml_compute_forward_set_f32(
@@ -17175,6 +17431,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_scale(params, tensor);
} break;
+ case GGML_OP_SOFTCAP:
+ {
+ ggml_compute_forward_softcap(params, tensor);
+ } break;
case GGML_OP_SET:
{
ggml_compute_forward_set(params, tensor);
@@ -17917,6 +18177,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
+ case GGML_OP_SOFTCAP:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18928,6 +19192,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
n_tasks = 1; //TODO
} break;
case GGML_OP_SCALE:
+ case GGML_OP_SOFTCAP:
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
diff --git a/src/llama.cpp b/src/llama.cpp
index 17253f7a..4aee41a4 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8317,14 +8317,17 @@ static struct ggml_tensor * llm_build_kqv(
//try from phi2
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
- kq = ggml_scale(ctx, kq, 30);
+ //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+ //kq = ggml_scale(ctx, kq, 30);
+
+ kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
}
if (hparams.attn_soft_cap) {
- kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
- kq = ggml_tanh(ctx, kq);
- kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+ kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ //kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+ //kq = ggml_tanh(ctx, kq);
+ //kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
}
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
@@ -11935,9 +11938,10 @@ struct llm_build_context {
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
// final logit soft-capping
- cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
- cur = ggml_tanh(ctx0, cur);
- cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+ cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
+ //cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ //cur = ggml_tanh(ctx0, cur);
+ //cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
cb(cur, "result_output", -1);