summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/ggml.c63
1 files changed, 57 insertions, 6 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 95a1fc7d..c3cda4c4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2265,7 +2265,7 @@ inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
-inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
+//inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
@@ -2389,6 +2389,13 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
return vdivq_f32(x, one_plus_exp_neg_x);
}
+inline static float32x4_t ggml_v_tanh(float32x4_t x) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
+ const float32x4_t exp_two_x = ggml_v_expf(two_x);
+ return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+}
+
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
// adapted from arm limited optimized routine
@@ -2432,6 +2439,12 @@ inline static __m512 ggml_v_silu(__m512 x) {
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
+inline static __m512 ggml_v_tanh(__m512 x) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
+ return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+}
+
#elif defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
@@ -2487,6 +2500,12 @@ inline static __m256 ggml_v_silu(__m256 x) {
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
+inline static __m256 ggml_v_tanh(__m256 x) {
+ const __m256 one = _mm256_set1_ps(1.0f);
+ const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
+ return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+}
+
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
#if defined(__FMA__)
@@ -2541,6 +2560,12 @@ inline static __m128 ggml_v_silu(__m128 x) {
return _mm_div_ps(x, one_plus_exp_neg_x);
}
+inline static __m128 ggml_v_tanh(__m128 x) {
+ const __m128 one = _mm_set1_ps(1.0f);
+ const __m128 exp_two_x = ggml_v_expf(_mm_mul_ps(x, _mm_set1_ps(2.f)));
+ return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
+}
+
#endif // __ARM_NEON / __AVX2__ / __SSE2__
static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@@ -2567,6 +2592,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}
+static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(y + i, ggml_v_tanh(_mm512_loadu_ps(x + i)));
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(y + i, ggml_v_tanh(_mm256_loadu_ps(x + i)));
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ _mm_storeu_ps(y + i, ggml_v_tanh(_mm_loadu_ps(x + i)));
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(y + i, ggml_v_tanh(vld1q_f32(x + i)));
+ }
+#endif
+ for (; i < n; ++i) {
+ y[i] = tanhf(x[i]);
+ }
+}
+
static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
ggml_float sum = 0;
@@ -11204,9 +11253,8 @@ static void ggml_compute_forward_tanh_f32(
const struct ggml_tensor * src0 = dst->src[0];
- if (params->ith != 0) {
- return;
- }
+ const int ith = params->ith;
+ const int nth = params->nth;
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
@@ -11215,7 +11263,7 @@ static void ggml_compute_forward_tanh_f32(
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
- for (int i = 0; i < n; i++) {
+ for (int i = ith; i < n; i += nth) {
ggml_vec_tanh_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
@@ -18590,7 +18638,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
- case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
@@ -18606,6 +18653,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
+ case GGML_UNARY_OP_TANH:
+ {
+ n_tasks = MIN(ggml_nrows(node), n_threads);
+ } break;
default:
GGML_ASSERT(false);
}