Diffstat (limited to 'tests/test-grad0.cpp')
-rw-r--r--  tests/test-grad0.cpp | 54
1 file changed, 37 insertions, 17 deletions
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp
index 75a698d7..468cde66 100644
--- a/tests/test-grad0.cpp
+++ b/tests/test-grad0.cpp
@@ -275,14 +275,14 @@ static bool check_gradient(
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
- const float f0 = ggml_get_f32_1d(f, 0);
+ const double f0 = ggml_get_f32_1d(f, 0);
ggml_set_f32_1d(x[i], k, xm);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
- const float f1 = ggml_get_f32_1d(f, 0);
- const float g0 = (f0 - f1)/(2.0f*eps);
+ const double f1 = ggml_get_f32_1d(f, 0);
+ const double g0 = (f0 - f1)/(2.0*(double) eps);
ggml_set_f32_1d(x[i], k, x0);
@@ -292,10 +292,10 @@ static bool check_gradient(
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
- const float g1 = ggml_get_f32_1d(x[i]->grad, k);
+ const double g1 = ggml_get_f32_1d(x[i]->grad, k);
- const float error_abs = fabsf(g0 - g1);
- const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
+ const double error_abs = fabs(g0 - g1);
+ const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0;
if (error_abs > max_error_abs || error_rel > max_error_rel) {
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
@@ -531,7 +531,7 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
- check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+ check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f);
}
}
@@ -1345,9 +1345,18 @@ int main(int argc, const char ** argv) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
-
- check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+ float eps = 1e-6f;
+ // don't use only sum as aggregation, because the sum of softmax is always 1 -> finite differences would not work
+ // instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0)
+ struct ggml_tensor * f = ggml_sum(ctx0,
+ ggml_log(ctx0,
+ ggml_add1(ctx0,
+ ggml_scale(ctx0,
+ ggml_soft_max(ctx0, x[0]),
+ ggml_new_f32(ctx0, 1.0f - eps)),
+ ggml_new_f32(ctx0, eps))));
+
+ check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
}
}
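This hunk changes the aggregation used for the softmax gradient check: because each softmax row sums to 1 regardless of the inputs, sum(soft_max(x)) is constant and its finite-difference gradient is pure rounding noise, so sum(log(soft_max(x)*(1-eps)+eps)) is used instead, with eps keeping the argument of log away from 0. A small self-contained demonstration of the constant-sum property (the input values are arbitrary examples):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Softmax with the usual max-subtraction for numerical stability.
    static std::vector<double> softmax(std::vector<double> v) {
        double vmax = v[0];
        for (double q : v) vmax = std::max(vmax, q);
        double sum = 0.0;
        for (double & q : v) { q = std::exp(q - vmax); sum += q; }
        for (double & q : v) q /= sum;
        return v;
    }

    int main() {
        const std::vector<double> x = { 0.3, -1.2, 0.7 };   // arbitrary inputs
        const double eps = 1e-3;

        auto sum_of = [](const std::vector<double> & v) {
            double s = 0.0;
            for (double q : v) s += q;
            return s;
        };

        // Perturbing any input leaves sum(softmax(x)) at exactly 1 (up to rounding),
        // so the central difference of that aggregation is ~0 for every component.
        std::vector<double> xp = x; xp[0] += eps;
        std::vector<double> xm = x; xm[0] -= eps;
        const double f0 = sum_of(softmax(xp));
        const double f1 = sum_of(softmax(xm));
        printf("f0 = %.12f, f1 = %.12f, numeric grad = %.3e\n", f0, f1, (f0 - f1)/(2.0*eps));
        return 0;
    }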
@@ -1358,15 +1367,26 @@ int main(int argc, const char ** argv) {
int64_t ne2[4];
get_random_dims(ne2, 4);
- for (int ndims = 1; ndims <= 3; ++ndims) {
- x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f);
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
+ // the second argument to cross_entropy_loss must sum up to 1 for each row
+ int nr = ggml_nrows(x[1]);
+ int nc = ggml_nelements(x[1]) / nr;
+ for (int ir = 0; ir < nr; ++ir) {
+ float sum = 0;
+ for (int ic = 0; ic < nc; ++ic) {
+ sum += ((float *) x[1]->data)[ic + ir*nc];
+ }
+ for (int ic = 0; ic < nc; ++ic) {
+ ((float *) x[1]->data)[ic + ir*nc] /= sum;
+ }
+ }
ggml_set_param(ctx0, x[0]);
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
+ struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
- check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY);
- // finite differences regularly fails!
+ check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY);
}
}
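The cross-entropy hunk does two things: it normalizes each row of the target tensor x[1] in place so the rows sum to 1, as ggml_cross_entropy_loss expects of a probability distribution, and it checks the scalar loss directly instead of wrapping it in ggml_sum. The added normalization loop, lifted into a standalone helper over a flat row-major buffer (a sketch mirroring the test code, not part of ggml):

    // Normalize each row of a row-major nr x nc float matrix so it sums to 1.
    // In the test above: data = (float *) x[1]->data,
    // nr = ggml_nrows(x[1]), nc = ggml_nelements(x[1]) / nr.
    static void normalize_rows(float * data, int nr, int nc) {
        for (int ir = 0; ir < nr; ++ir) {
            float sum = 0.0f;
            for (int ic = 0; ic < nc; ++ic) {
                sum += data[ic + ir*nc];
            }
            for (int ic = 0; ic < nc; ++ic) {
                data[ic + ir*nc] /= sum;
            }
        }
    }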
@@ -1473,7 +1493,7 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
- check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+ check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
}
}
}
@@ -1514,7 +1534,7 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
- check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+ check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
}
}
}
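Both flash_attn checks move from a loose relative bound (max_error_rel = 3.5 with no absolute bound) to a tight absolute bound (max_error_abs = 1e-3 with no relative bound). A plausible motivation, sketched below, is that error_rel divides by |g0|, so it blows up whenever the reference gradient is near zero even though the mismatch is negligible in absolute terms (the numbers are made-up illustrations):

    #include <cmath>
    #include <cstdio>

    int main() {
        // Made-up example: a near-zero reference gradient and a tiny mismatch.
        const double g0 = 1e-7;   // finite-difference gradient
        const double g1 = 6e-7;   // backpropagated gradient

        const double error_abs = fabs(g0 - g1);            // 5e-7: negligible
        const double error_rel = fabs(g0 - g1)/fabs(g0);   // 5.0: huge, because g0 ~ 0

        // A relative-only check rejects this harmless case;
        // an absolute bound of 1e-3 accepts it.
        printf("error_abs = %.1e, error_rel = %.1f\n", error_abs, error_rel);
        return 0;
    }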