summaryrefslogtreecommitdiff
path: root/tests/test-grad0.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test-grad0.cpp')
-rw-r--r--tests/test-grad0.cpp126
1 files changed, 43 insertions, 83 deletions
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp
index 8ff76c89..21ca43be 100644
--- a/tests/test-grad0.cpp
+++ b/tests/test-grad0.cpp
@@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) {
}
// flash_attn f32
- {
- srand(seed);
- const int nargs = 3;
-
- int64_t ne2[4];
-
- get_random_dims(ne2, 4);
- int64_t D = ne2[0];
- int64_t N = ne2[1];
- int64_t M = ne2[2] + N;
- int64_t B = ne2[3];
-
- for (int masked = 0; masked <= 1; ++masked) {
- for (int ndims = 2; ndims <= 4; ++ndims) {
- int max_nrep = (ndims >= 3) ? 2 : 1;
- for (int nrep = 1; nrep < max_nrep; ++nrep) {
- int64_t neq[4] = { D, N, B*nrep, ne[3] };
- int64_t nek[4] = { D, M, B, ne[3] };
- int64_t nev[4] = { M, D, B, ne[3] };
- if (ndims == 2) {
- neq[2] = 1; neq[3] = 1;
- nek[2] = 1; nek[3] = 1;
- nev[2] = 1; nev[3] = 1;
- } else if (ndims == 3) {
- neq[3] = 1;
- nek[3] = 1;
- nev[3] = 1;
- }
- x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
- x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
- x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
- ggml_set_param(ctx0, x[0]);
- ggml_set_param(ctx0, x[1]);
- ggml_set_param(ctx0, x[2]);
-
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
- check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
- }
- }
- }
- }
-
- // flash_attn f16, not yet fully implemented
- if(0)
- {
- srand(seed);
- const int nargs = 3;
-
- int64_t ne2[4];
-
- get_random_dims(ne2, 4);
- int64_t D = ne2[0];
- int64_t N = ne2[1];
- int64_t M = ne2[2] + N;
- int64_t B = ne2[3];
-
- for (int masked = 0; masked <= 1; ++masked) {
- for (int ndims = 2; ndims <= 4; ++ndims) {
- int64_t neq[4] = { D, N, B, ne[3] };
- int64_t nek[4] = { D, M, B, ne[3] };
- int64_t nev[4] = { M, D, B, ne[3] };
- if (ndims == 2) {
- neq[2] = 1; neq[3] = 1;
- nek[2] = 1; nek[3] = 1;
- nev[2] = 1; nev[3] = 1;
- } else if (ndims == 3) {
- neq[3] = 1;
- nek[3] = 1;
- nev[3] = 1;
- }
- x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
- x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
- x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
- ggml_set_param(ctx0, x[0]);
- ggml_set_param(ctx0, x[1]);
- ggml_set_param(ctx0, x[2]);
-
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+ // TODO: adapt to ggml_flash_attn_ext() changes
+ //{
+ // srand(seed);
+ // const int nargs = 3;
+
+ // int64_t ne2[4];
+
+ // get_random_dims(ne2, 4);
+ // int64_t D = ne2[0];
+ // int64_t N = ne2[1];
+ // int64_t M = ne2[2] + N;
+ // int64_t B = ne2[3];
+
+ // for (int masked = 0; masked <= 1; ++masked) {
+ // for (int ndims = 2; ndims <= 4; ++ndims) {
+ // int max_nrep = (ndims >= 3) ? 2 : 1;
+ // for (int nrep = 1; nrep < max_nrep; ++nrep) {
+ // int64_t neq[4] = { D, N, B*nrep, ne[3] };
+ // int64_t nek[4] = { D, M, B, ne[3] };
+ // int64_t nev[4] = { M, D, B, ne[3] };
+ // if (ndims == 2) {
+ // neq[2] = 1; neq[3] = 1;
+ // nek[2] = 1; nek[3] = 1;
+ // nev[2] = 1; nev[3] = 1;
+ // } else if (ndims == 3) {
+ // neq[3] = 1;
+ // nek[3] = 1;
+ // nev[3] = 1;
+ // }
+ // x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+ // x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+ // x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+ // ggml_set_param(ctx0, x[0]);
+ // ggml_set_param(ctx0, x[1]);
+ // ggml_set_param(ctx0, x[2]);
+
+ // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+ // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+ // }
+ // }
+ // }
+ //}
- check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
- }
- }
- }
ggml_free(ctx0);
}