summaryrefslogtreecommitdiff
path: root/examples/train-text-from-scratch
diff options
context:
space:
mode:
Diffstat (limited to 'examples/train-text-from-scratch')
-rw-r--r--examples/train-text-from-scratch/train-text-from-scratch.cpp3
1 files changed, 2 insertions, 1 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 45bdfa8f..e2f85c68 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -341,7 +341,8 @@ static struct ggml_tensor * llama_build_train_graphs(
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
struct ggml_tensor * t16;
if (enable_flash_attn) {
- t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+ GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported");
+ //t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
} else {
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);