summaryrefslogtreecommitdiff
path: root/tests/test-backend-ops.cpp
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-08-27 17:40:59 +0300
committerGitHub <noreply@github.com>2024-08-27 17:40:59 +0300
commitc7e99c88a2de7489ba2a1539b1a9025912010b70 (patch)
tree9976409b1e8fac1fc7486f2c5da05a33b8e229b5 /tests/test-backend-ops.cpp
parentbd99ed7d0afd2b12c0f5ff5c17b58486396dfe7e (diff)
Faster Gemma2 (#27)
* soft_cap_max: initial CPU version of fused softcap + soft_max With this vanilla CPU implementation I'm already getting a ~3% speedup for Gemma-2-9b and a prompt of 8192 tokens. * soft_cap_max: WIP - something is wrong with CUDA * soft_cap_max: looks good on CPU and CUDA * Add softcap to flash attention Just CPU and CUDA for now (but, as we know, flash attention on the CPU is useless in llama.cpp). On CUDA this improves PP performance quite a bit, especially for long contexts. E.g., for PP-16384, I now get 3777 t/s. Without this change, one cannot use FA, and one gets 2300 t/s (after fusing softcap and softmax), or 2000 t/s without the fused softcap+softmax. In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before PR-8542 (where Johannes Gaessler has also added softcap to FA), and PP-16384 = 3097 t/s after this PR. * soft_cap_max: Metal * Flash attention with softcap: Metal --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r--tests/test-backend-ops.cpp22
1 files changed, 13 insertions, 9 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a2182c1b..f51ec5b8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case {
const bool mask; // use mask
const float max_bias; // ALiBi
+ const float softcap; // Gemma-2
const ggml_type type_KV;
std::string vars() override {
- return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
+ return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, softcap, type_KV);
}
double max_nmse_err() override {
return 5e-4;
}
- test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
- : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
+ test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+ : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), softcap(softcap), type_KV(type_KV) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case {
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
- ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
+ ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, softcap);
return out;
}
};
@@ -2434,11 +2435,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (bool mask : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
if (!mask && max_bias > 0.0f) continue;
- for (int nh : { 32, }) {
- for (int kv : { 512, 1024, }) {
- for (int nb : { 1, 2, 4, 8, }) {
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+ for (float softcap : {0.0f, 10.0f}) {
+ if (hs != 128 && softcap != 0.0f) continue;
+ for (int nh : { 32, }) {
+ for (int kv : { 512, 1024, }) {
+ for (int nb : { 1, 2, 4, 8, }) {
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, softcap, type_KV));
+ }
}
}
}