diff options
Diffstat (limited to 'tests/test-grad0.cpp')
-rw-r--r-- | tests/test-grad0.cpp | 126 |
1 files changed, 43 insertions, 83 deletions
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 8ff76c89..21ca43be 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) { } // flash_attn f32 - { - srand(seed); - const int nargs = 3; - - int64_t ne2[4]; - - get_random_dims(ne2, 4); - int64_t D = ne2[0]; - int64_t N = ne2[1]; - int64_t M = ne2[2] + N; - int64_t B = ne2[3]; - - for (int masked = 0; masked <= 1; ++masked) { - for (int ndims = 2; ndims <= 4; ++ndims) { - int max_nrep = (ndims >= 3) ? 2 : 1; - for (int nrep = 1; nrep < max_nrep; ++nrep) { - int64_t neq[4] = { D, N, B*nrep, ne[3] }; - int64_t nek[4] = { D, M, B, ne[3] }; - int64_t nev[4] = { M, D, B, ne[3] }; - if (ndims == 2) { - neq[2] = 1; neq[3] = 1; - nek[2] = 1; nek[3] = 1; - nev[2] = 1; nev[3] = 1; - } else if (ndims == 3) { - neq[3] = 1; - nek[3] = 1; - nev[3] = 1; - } - x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f); - x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f); - x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f); - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); - ggml_set_param(ctx0, x[2]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - - check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); - } - } - } - } - - // flash_attn f16, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 3; - - int64_t ne2[4]; - - get_random_dims(ne2, 4); - int64_t D = ne2[0]; - int64_t N = ne2[1]; - int64_t M = ne2[2] + N; - int64_t B = ne2[3]; - - for (int masked = 0; masked <= 1; ++masked) { - for (int ndims = 2; ndims <= 4; ++ndims) { - int64_t neq[4] = { D, N, B, ne[3] }; - int64_t nek[4] = { D, M, B, ne[3] }; - int64_t nev[4] = { M, D, B, ne[3] }; - if (ndims == 2) { - neq[2] = 1; neq[3] = 1; - nek[2] = 1; nek[3] = 1; - nev[2] = 1; nev[3] = 1; - } else if (ndims == 3) { - neq[3] = 1; - nek[3] = 1; - nev[3] = 1; - } - x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f); - x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f); - x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f); - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); - ggml_set_param(ctx0, x[2]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); + // TODO: adapt to ggml_flash_attn_ext() changes + //{ + // srand(seed); + // const int nargs = 3; + + // int64_t ne2[4]; + + // get_random_dims(ne2, 4); + // int64_t D = ne2[0]; + // int64_t N = ne2[1]; + // int64_t M = ne2[2] + N; + // int64_t B = ne2[3]; + + // for (int masked = 0; masked <= 1; ++masked) { + // for (int ndims = 2; ndims <= 4; ++ndims) { + // int max_nrep = (ndims >= 3) ? 2 : 1; + // for (int nrep = 1; nrep < max_nrep; ++nrep) { + // int64_t neq[4] = { D, N, B*nrep, ne[3] }; + // int64_t nek[4] = { D, M, B, ne[3] }; + // int64_t nev[4] = { M, D, B, ne[3] }; + // if (ndims == 2) { + // neq[2] = 1; neq[3] = 1; + // nek[2] = 1; nek[3] = 1; + // nev[2] = 1; nev[3] = 1; + // } else if (ndims == 3) { + // neq[3] = 1; + // nek[3] = 1; + // nev[3] = 1; + // } + // x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f); + // x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f); + // x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f); + // ggml_set_param(ctx0, x[0]); + // ggml_set_param(ctx0, x[1]); + // ggml_set_param(ctx0, x[2]); + + // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); + + // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); + // } + // } + // } + //} - check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); - } - } - } ggml_free(ctx0); } |