author     Georgi Gerganov <ggerganov@gmail.com>    2024-05-11 10:32:41 +0300
committer  GitHub <noreply@github.com>              2024-05-11 10:32:41 +0300
commit     9cb317f77e53067f7a138cc89ef7657148eae8e6 (patch)
tree       3ba1d2d80d1d7c8b4ab01f6396a3febaae26e91b /tests
parent     e849648888a11de13aaaa4cb2eda3f5a9c7b444d (diff)
ggml : full ALiBi support (#7192)
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)
ggml-ci
* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi
ggml-ci
* convert : fix convert for refact models
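For readers updating call sites: below is a minimal sketch of the new call shapes, assuming only the signatures visible in the diff further down. The helper names and tensor meanings are hypothetical, added here for illustration; they are not part of the commit.

#include <math.h>
#include "ggml.h"

// Hypothetical helpers showing the post-change signatures: the separate `pos`
// tensor argument of ggml_soft_max_ext() is gone, and both ops now take
// max_bias directly. Per the commit message, a mask is required whenever
// max_bias > 0.0f (ALiBi enabled); max_bias == 0.0f keeps the non-ALiBi path.
static ggml_tensor * build_attn_scores(ggml_context * ctx,
                                       ggml_tensor * kq,       // raw Q*K^T scores
                                       ggml_tensor * kq_mask,  // must be non-null if max_bias > 0.0f
                                       float kq_scale,
                                       float max_bias) {
    return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
}

static ggml_tensor * build_flash_attn(ggml_context * ctx,
                                      ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
                                      ggml_tensor * mask,
                                      int64_t head_size,
                                      float max_bias) {
    // ggml_flash_attn_ext() gained the same trailing max_bias parameter
    return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) head_size), max_bias);
}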
Diffstat (limited to 'tests')
-rw-r--r--  tests/test-backend-ops.cpp  30
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 0d66de5d..731788b9 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1111,11 +1111,7 @@ struct test_soft_max : public test_case {
         if (this->mask) {
             mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
         }
-        ggml_tensor * pos = nullptr;
-        if (max_bias > 0.0f) {
-            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
-        }
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
         return out;
     }
 };
@@ -1490,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
+    const float max_bias; // ALiBi
+
     std::string vars() override {
-        return VARS_TO_STR4(hs, nh, kv, nb);
+        return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
+        : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
         return out;
     }
 };
@@ -1611,7 +1609,7 @@ public:
 
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
 
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
 
         // split cached v into n_head heads
         struct ggml_tensor * v =
@@ -2128,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2141,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -2180,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #else
     for (int hs : { 64, 80, 128, 256, }) {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        for (int nh : { 32, }) {
-            for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, }) {
-                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+        for (float max_bias : {0.0f, 8.0f}) {
+            for (int nh : { 32, }) {
+                for (int kv : { 512, 1024, }) {
+                    for (int nb : { 1, 2, 4, 8, }) {
+                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
+                    }
                 }
             }
         }
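Background on the max_bias value exercised by these tests: it parameterizes the per-head ALiBi slopes via the standard recipe from the ALiBi paper (Press et al.). The sketch below is a plain restatement of that recipe for reference only; it is not a copy of the ggml CPU/CUDA/Metal/SYCL kernels touched by this commit.

#include <cmath>
#include <vector>

// Standard ALiBi slope schedule, parameterized by max_bias (e.g. 8.0f in the tests above).
// Heads up to the largest power of two <= n_head decay geometrically from m0; any remaining
// heads are filled in with odd powers of m1, so non-power-of-two head counts are covered.
static std::vector<float> alibi_slopes(int n_head, float max_bias) {
    const int   n_head_log2 = 1 << (int) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -max_bias        / n_head_log2);
    const float m1 = std::pow(2.0f, -max_bias / 2.0f / n_head_log2);

    std::vector<float> slopes(n_head);
    for (int h = 0; h < n_head; ++h) {
        slopes[h] = h < n_head_log2 ? std::pow(m0, h + 1)
                                    : std::pow(m1, 2*(h - n_head_log2) + 1);
    }
    return slopes;
}

With max_bias == 0.0f the bias term is effectively disabled (the slope degenerates to 1), which is why the test matrix pairs {0.0f, 8.0f}: one value covers the plain masked-softmax path and the other the ALiBi-biased path.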