Diffstat (limited to 'ggml.c')
 ggml.c | 1964 ++++++++++++++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 1334 insertions(+), 630 deletions(-)
diff --git a/ggml.c b/ggml.c
index 3fcc44bd..ea964bab 100644
--- a/ggml.c
+++ b/ggml.c
@@ -134,6 +134,7 @@ typedef void * thread_ret_t;
#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 2
+#define GGML_VEC_MAD_UNROLL 32
//
// logging
@@ -3707,6 +3708,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
#endif
}
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+ const float * restrict x[GGML_VEC_MAD_UNROLL];
+ const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+ for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+ x[i] = (const float *) ((const char *) xv + i*xs);
+ v[i] = (const float *) ((const char *) vv + i*vs);
+ }
+
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+ }
+
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+ }
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = np; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+#else
+ // scalar
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+#endif
+}
+
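In plain terms, the new helper folds GGML_VEC_MAD_UNROLL consecutive axpy updates into a single pass over y. A minimal usage sketch (pointer names and strides are illustrative only):

    // y[:n] += sum over k of x_k[:n] * v_k[0], where row k of x starts k*nb01
    // bytes past s0 and the k-th scale starts k*nb11 bytes past s1
    ggml_vec_mad_f32_unroll(n, nb01, nb11, y, s0, s1);

    // ... which is equivalent to the plain loop it replaces:
    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
        const float * xk = (const float *) ((const char *) s0 + k*nb01);
        const float * vk = (const float *) ((const char *) s1 + k*nb11);
        ggml_vec_mad_f32(n, y, xk, *vk);
    }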
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)
@@ -4392,10 +4445,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
- return
- (t0->ne[1] == t1->ne[1]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ return (t0->ne[1] == t1->ne[1]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
}
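A quick worked example of the relaxed check (shapes are illustrative): with t0->ne = {64, 10, 1, 1} and t1->ne = {32, 10, 8, 2} the function now returns true, since 10 == 10, 8 % 1 == 0 and 2 % 1 == 0; ggml_out_prod below then broadcasts t0 over the 8*2 matrices of t1 and produces a result with ne = {64, 32, 8, 2}.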
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5065,7 +5117,36 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
return tensor;
}
+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne0 = tensor->ne[0];
+
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+ if (i0) {
+ * i0 = i0_;
+ }
+ if (i1) {
+ * i1 = i1_;
+ }
+ if (i2) {
+ * i2 = i2_;
+ }
+ if (i3) {
+ * i3 = i3_;
+ }
+}
+
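A short usage sketch of the new helper (the tensor t is hypothetical): for a contiguous tensor with ne = {4, 3, 2, 5}, the flat index 61 unravels to (i0, i1, i2, i3) = (1, 0, 1, 2), since 1 + 0*4 + 1*(4*3) + 2*(4*3*2) == 61.

    int64_t i0, i1, i2, i3;
    ggml_unravel_index(t, 61, &i0, &i1, &i2, &i3);
    // i0 == 1, i1 == 0, i2 == 1, i3 == 2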
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
@@ -5102,6 +5183,12 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
}
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
@@ -5135,7 +5222,74 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
}
}
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ return ((int8_t *) data)[0];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ return ((int16_t *) data)[0];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ return ((int32_t *) data)[0];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ return ((float *) data)[0];
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return 0;
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
@@ -5172,6 +5326,12 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
}
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
@@ -5205,6 +5365,68 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
}
}
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ return ((int8_t *) data)[0];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ return ((int16_t *) data)[0];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ return ((int32_t *) data)[0];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ return ((float *) data)[0];
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
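With the _nd accessors in place, the 1d getters and setters above also work on non-contiguous tensors. A minimal sketch, assuming an already-initialized context ctx:

    struct ggml_tensor * t  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * tt = ggml_transpose(ctx, t);  // non-contiguous view, ne = {3, 4}

    ggml_set_f32_nd(tt, 1, 2, 0, 0, 5.0f);              // same element as t[2,1]
    float v = ggml_get_f32_1d(tt, 1 + 2*tt->ne[0]);     // unravels to (1,2,0,0) -> 5.0f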
void * ggml_get_data(const struct ggml_tensor * tensor) {
return tensor->data;
}
@@ -5347,6 +5569,44 @@ struct ggml_tensor * ggml_add_inplace(
return ggml_add_impl(ctx, a, b, true);
}
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
+ GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ return ggml_add_cast_impl(ctx, a, b, type);
+}
+
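A usage sketch for the new op (tensor names are hypothetical): add an F32 tensor b, row-broadcast, onto a quantized tensor a, and store the sum as F16 instead of re-quantizing it back to a->type.

    // a: quantized weight, e.g. GGML_TYPE_Q4_0, shape [n, m]
    // b: F32 tensor that is row-repeatable onto a (ggml_can_repeat_rows)
    struct ggml_tensor * sum = ggml_add_cast(ctx, a, b, GGML_TYPE_F16);
    // sum has the shape of a, but type GGML_TYPE_F16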
// ggml_add1
static struct ggml_tensor * ggml_add1_impl(
@@ -5783,7 +6043,6 @@ struct ggml_tensor * ggml_repeat(
result->op = GGML_OP_REPEAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
- result->src[1] = b;
return result;
}
@@ -5811,7 +6070,6 @@ struct ggml_tensor * ggml_repeat_back(
result->op = GGML_OP_REPEAT_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
- result->src[1] = b;
return result;
}
@@ -6186,8 +6444,9 @@ struct ggml_tensor * ggml_out_prod(
is_node = true;
}
- const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+ // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+ const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
result->op = GGML_OP_OUT_PROD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6461,7 +6720,7 @@ struct ggml_tensor * ggml_reshape(
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_is_contiguous(a));
- GGML_ASSERT(ggml_is_contiguous(b));
+ // as only the shape of b is relevant, and not its memory layout, b is allowed to be non-contiguous.
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
bool is_node = false;
@@ -6834,7 +7093,6 @@ struct ggml_tensor * ggml_get_rows_back(
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
- result->src[2] = c;
return result;
}
@@ -7540,27 +7798,30 @@ struct ggml_tensor * ggml_flash_attn_back(
// d shape [D,N,ne2,ne3]
// q shape [D,N,ne2,ne3]
- // k shape [D,M,ne2,ne3]
- // v shape [M,D,ne2,ne3]
+ // k shape [D,M,kvne2,ne3]
+ // v shape [M,D,kvne2,ne3]
- const int64_t D = q->ne[0];
- const int64_t N = q->ne[1];
- const int64_t M = k->ne[1];
- const int64_t ne2 = q->ne[2];
- const int64_t ne3 = q->ne[3];
+ const int64_t D = q->ne[0];
+ const int64_t N = q->ne[1];
+ const int64_t M = k->ne[1];
+ const int64_t ne2 = q->ne[2];
+ const int64_t ne3 = q->ne[3];
+ const int64_t kvne2 = k->ne[2];
GGML_ASSERT(k->ne[0] == D);
GGML_ASSERT(v->ne[0] == M);
GGML_ASSERT(v->ne[1] == D);
GGML_ASSERT(d->ne[0] == D);
GGML_ASSERT(d->ne[1] == N);
- GGML_ASSERT(k->ne[2] == ne2);
+ GGML_ASSERT(k->ne[2] == kvne2);
GGML_ASSERT(k->ne[3] == ne3);
- GGML_ASSERT(v->ne[2] == ne2);
+ GGML_ASSERT(v->ne[2] == kvne2);
GGML_ASSERT(v->ne[3] == ne3);
GGML_ASSERT(d->ne[2] == ne2);
GGML_ASSERT(d->ne[3] == ne3);
+ GGML_ASSERT(ne2 % kvne2 == 0);
+
bool is_node = false;
if (q->grad || k->grad || v->grad) {
@@ -7570,14 +7831,23 @@ struct ggml_tensor * ggml_flash_attn_back(
}
// store gradients of q, k and v as contiguous tensors concatenated in result.
- // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
- // gradq->data = result->data
- // gradk->data = result->data + nb0*D*N*ne2*ne3
- // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
- int64_t ne[4] = {D,M+N+M,ne2,ne3};
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
+ const int64_t elem_v = ggml_nelements(v);
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+ enum ggml_type result_type = GGML_TYPE_F32;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+ const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+ const size_t nelements = (end + tsize - 1)/tsize;
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
int32_t masked_i = masked ? 1 : 0;
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
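A sizing example for the new flat result layout (numbers are illustrative, GGML_MEM_ALIGN assumed to be 16): with D = 64, N = 16, M = 32, ne2 = 8, kvne2 = 2, ne3 = 1 and tsize = 4, we get elem_q = 64*16*8 = 8192 and elem_k = elem_v = 64*32*2 = 4096, hence offs_k = 32768, offs_v = 32768 + 16384 = 49152 and end = 65536, so the result becomes a 1d F32 tensor of 16384 elements instead of the old {D, M+N+M, ne2, ne3} tensor.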
@@ -9006,8 +9276,9 @@ static void ggml_compute_forward_add_q_f32(
const int nth = params->nth;
const enum ggml_type type = src0->type;
+ const enum ggml_type dtype = dst->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -9019,7 +9290,6 @@ static void ggml_compute_forward_add_q_f32(
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ggml_is_quantized(src0->type));
- GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// rows per thread
@@ -9057,7 +9327,11 @@ static void ggml_compute_forward_add_q_f32(
// add src1
ggml_vec_acc_f32(ne00, wdata, src1_row);
// quantize row to dst
- quantize_row_q(wdata, dst_row, ne00);
+ if (quantize_row_q != NULL) {
+ quantize_row_q(wdata, dst_row, ne00);
+ } else {
+ memcpy(dst_row, wdata, ne0*nb0);
+ }
}
}
@@ -10153,11 +10427,61 @@ static void ggml_compute_forward_repeat_f32(
}
}
+static void ggml_compute_forward_repeat_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(params->ith == 0);
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nr0 = (int)(ne0/ne00);
+ const int nr1 = (int)(ne1/ne01);
+ const int nr2 = (int)(ne2/ne02);
+ const int nr3 = (int)(ne3/ne03);
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // TODO: maybe this is not optimal?
+ for (int i3 = 0; i3 < nr3; i3++) {
+ for (int k3 = 0; k3 < ne03; k3++) {
+ for (int i2 = 0; i2 < nr2; i2++) {
+ for (int k2 = 0; k2 < ne02; k2++) {
+ for (int i1 = 0; i1 < nr1; i1++) {
+ for (int k1 = 0; k1 < ne01; k1++) {
+ for (int i0 = 0; i0 < nr0; i0++) {
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
+ // ggml_vec_cpy_f16(ne00, y, x)
+ for (int i = 0; i < ne00; ++i) {
+ y[i] = x[i];
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
static void ggml_compute_forward_repeat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_repeat_f16(params, src0, dst);
+ } break;
case GGML_TYPE_F32:
{
ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -11497,8 +11821,8 @@ static void ggml_compute_forward_out_prod_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- int64_t t0 = ggml_perf_time_us();
- UNUSED(t0);
+ // int64_t t0 = ggml_perf_time_us();
+ // UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS;
@@ -11539,6 +11863,146 @@ static void ggml_compute_forward_out_prod_f32(
return;
}
+ // dst[:,:,:,:] = 0
+ // for i2,i3:
+ // for i1:
+ // for i01:
+ // for i0:
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+ // parallelize by last three dimensions
+
+ // total rows in dst
+ const int64_t nr = ne1*ne2*ne3;
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ // block-tiling attempt
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+ const int64_t blck_1 = 16;
+
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+ for (int64_t ir = bir; ir < bir1; ++ir) {
+ // dst indices
+ const int64_t i3 = ir/(ne2*ne1);
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ //const int64_t i10 = i1;
+ const int64_t i12 = i2;
+ const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+ }
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
+ }
+#else
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
+ }
+#endif
+ }
+ }
+ }
+
+
+ //int64_t t1 = ggml_perf_time_us();
+ //static int64_t acc = 0;
+ //acc += t1 - t0;
+ //if (t1 - t0 > 10) {
+ // printf("\n");
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+ //}
+}
+
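For reference, the computation being block-tiled above reduces, for a single contiguous F32 (i2,i3) slice and without threading or broadcasting, to the following naive form (illustration only):

    // dst[i0,i1] += sum over i01 of s0[i0,i01] * s1[i1,i01]
    static void out_prod_slice_ref(int64_t ne0, int64_t ne1, int64_t ne01,
                                   float * dst, const float * s0, const float * s1) {
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    dst[i1*ne0 + i0] += s0[i01*ne0 + i0] * s1[i01*ne1 + i1];
                }
            }
        }
    }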
+static void ggml_compute_forward_out_prod_q_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ // int64_t t0 = ggml_perf_time_us();
+ // UNUSED(t0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne03 == ne13);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ // we don't support permuted src0 dim0
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+
+ // dst dim0 cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ // GGML_ASSERT(nb0 <= nb1);
+ // GGML_ASSERT(nb1 <= nb2);
+ // GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+ if (params->type == GGML_TASK_INIT) {
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
// parallelize by last three dimensions
// total rows in dst
@@ -11558,6 +12022,8 @@ static void ggml_compute_forward_out_prod_f32(
// for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
@@ -11578,10 +12044,8 @@ static void ggml_compute_forward_out_prod_f32(
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
- ggml_vec_mad_f32(ne0, d, s0, *s1);
- // for (int64_t i0 = 0; i0 < ne0; ++i0) {
- // d[i0] += s0[i0] * s1[i1];
- // }
+ dequantize_row_q(s0, wdata, ne0);
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
}
}
@@ -11610,10 +12074,13 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
{
- GGML_ASSERT(false); // todo
- // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+ ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
@@ -12001,14 +12468,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
- GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_is_contiguous(dst));
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+ if (params->type == GGML_TASK_INIT) {
+ memset(dst->data, 0, ggml_nbytes(dst));
+ }
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@@ -12034,11 +12502,8 @@ static void ggml_compute_forward_get_rows_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
- GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_is_contiguous(dst));
// ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12072,16 +12537,15 @@ static void ggml_compute_forward_get_rows_back(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F16:
{
- ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+ ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
- ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+ ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
} break;
default:
{
@@ -14143,10 +14607,11 @@ static void ggml_compute_forward_flash_attn_f32(
S[i] = -INFINITY;
}
- for (int64_t ic = 0; ic < nek1; ++ic) {
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
// k indices
const int ik3 = iq3;
- const int ik2 = iq2;
+ const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
@@ -14159,20 +14624,18 @@ static void ggml_compute_forward_flash_attn_f32(
}
// scale
- ggml_vec_scale_f32(nek1, S, scale);
+ ggml_vec_scale_f32(masked_begin, S, scale);
- if (masked) {
- for (int64_t i = P; i < M; i++) {
- if (i > P + iq1) {
- S[i] = -INFINITY;
- }
- }
+ for (int64_t i = masked_begin; i < M; i++) {
+ S[i] = -INFINITY;
}
// softmax
+ // exclude known -INF S[..] values from max and loop
+ // don't forget to set their SW values to zero
{
float max = -INFINITY;
- ggml_vec_max_f32(M, &max, S);
+ ggml_vec_max_f32(masked_begin, &max, S);
ggml_float sum = 0.0;
{
@@ -14186,10 +14649,15 @@ static void ggml_compute_forward_flash_attn_f32(
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+ if (i >= masked_begin) {
+ break;
+ }
float * SS = S + i;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
- if (SS[j] == -INFINITY) {
+ if (i + j >= masked_begin) {
+ break;
+ } else if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
#ifndef GGML_FLASH_ATTN_EXP_FP16
@@ -14214,10 +14682,10 @@ static void ggml_compute_forward_flash_attn_f32(
assert(sum > 0.0);
sum = 1.0/sum;
- ggml_vec_scale_f32(M, S, sum);
+ ggml_vec_scale_f32(masked_begin, S, sum);
#ifndef NDEBUG
- for (int i = 0; i < M; ++i) {
+ for (int i = 0; i < masked_begin; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
@@ -14230,9 +14698,13 @@ static void ggml_compute_forward_flash_attn_f32(
const int i2 = iq2;
const int i3 = iq3;
- ggml_vec_dot_f32(nek1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ // v indices
+ const int iv2 = iq2 % nev2;
+ const int iv3 = iq3;
+
+ ggml_vec_dot_f32(masked_begin,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S);
}
}
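A worked example for masked_begin (numbers are illustrative): with P = 32 past tokens and N = 4 queries, M = P + N = 36; query row iq1 = 1 may attend to positions 0..P+iq1 = 0..33, so masked_begin = P + iq1 + 1 = 34 and S[34..35] stay at -INFINITY, while the scale, the softmax and the final dot products only touch the first 34 entries.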
@@ -14329,7 +14801,7 @@ static void ggml_compute_forward_flash_attn_f16(
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
- const int ik2 = iq2;
+ const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
@@ -14344,7 +14816,7 @@ static void ggml_compute_forward_flash_attn_f16(
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
// k indices
const int ik3 = iq3;
- const int ik2 = iq2;
+ const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
@@ -14369,6 +14841,8 @@ static void ggml_compute_forward_flash_attn_f16(
}
// softmax
+ // todo: exclude known -INF S[..] values from max and loop, assuming their results are zero.
+ // don't forget to set their S values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
@@ -14425,6 +14899,7 @@ static void ggml_compute_forward_flash_attn_f16(
S16[i] = GGML_FP32_TO_FP16(S[i]);
}
+ // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
for (int64_t ic = 0; ic < nev1; ++ic) {
// dst indices
@@ -14432,9 +14907,13 @@ static void ggml_compute_forward_flash_attn_f16(
const int i2 = iq2;
const int i3 = iq3;
- ggml_vec_dot_f16(nek1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ // v indices
+ const int iv2 = iq2 % nev2;
+ const int iv3 = iq3;
+
+ ggml_vec_dot_f16(nev0,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
} else {
@@ -14444,9 +14923,13 @@ static void ggml_compute_forward_flash_attn_f16(
const int i2 = iq2;
const int i3 = iq3;
- ggml_vec_dot_f16_unroll(nek1, nbv1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ // v indices
+ const int iv2 = iq2 % nev2;
+ const int iv3 = iq3;
+
+ ggml_vec_dot_f16_unroll(nev0, nbv1,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
}
@@ -14705,10 +15188,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
return;
}
- // parallelize by q rows using ggml_vec_dot_f32
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
- // total rows in q
- const int nr = neq2*neq3;
+ enum ggml_type result_type = dst->type;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+ void * grad_q = (char *) dst->data;
+ void * grad_k = (char *) dst->data + offs_k;
+ void * grad_v = (char *) dst->data + offs_v;
+
+ const size_t nbgq1 = nb0*neq0;
+ const size_t nbgq2 = nb0*neq0*neq1;
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+ const size_t nbgk1 = nb0*nek0;
+ const size_t nbgk2 = nb0*nek0*nek1;
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+ const size_t nbgv1 = nb0*nev0;
+ const size_t nbgv2 = nb0*nev0*nev1;
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+ // parallelize by k rows using ggml_vec_dot_f32
+
+ // total rows in k
+ const int nr = nek2*nek3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
@@ -14721,268 +15231,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+ // how often k2 (and v2) is repeated in q2
+ int nrep = neq2/nek2;
+
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
- const int iq3 = ir/(neq2);
- const int iq2 = ir - iq3*neq2;
- for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+ const int ik3 = ir/(nek2);
+ const int ik2 = ir - ik3*nek2;
+ const int iq3 = ik3;
+ const int id3 = ik3;
+ const int iv3 = ik3;
+ const int iv2 = ik2;
- // not sure about CACHE_LINE_SIZE_F32..
- // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+ for (int irep = 0; irep < nrep; ++irep) {
+ const int iq2 = ik2 + irep*nek2;
+ const int id2 = iq2;
- for (int i = M; i < Mup; ++i) {
- S[i] = -INFINITY;
- }
+ // (ik2 + irep*nek2) % nek2 == ik2
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
+ const int id1 = iq1;
- for (int64_t ic = 0; ic < nek1; ++ic) {
- // k indices
- const int ik3 = iq3;
- const int ik2 = iq2;
- const int ik1 = ic;
+ // not sure about CACHE_LINE_SIZE_F32..
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
- // S indices
- const int i1 = ik1;
+ for (int i = M; i < Mup; ++i) {
+ S[i] = -INFINITY;
+ }
- ggml_vec_dot_f32(neq0,
- S + i1,
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
- }
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ // k indices
+ const int ik1 = ic;
- // scale
- ggml_vec_scale_f32(nek1, S, scale);
+ // S indices
+ const int i1 = ik1;
- if (masked) {
- for (int64_t i = P; i < M; i++) {
- if (i > P + iq1) {
- S[i] = -INFINITY;
- }
+ ggml_vec_dot_f32(neq0,
+ S + i1,
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
- }
- // softmax
- {
- float max = -INFINITY;
- ggml_vec_max_f32(M, &max, S);
+ // scale
+ ggml_vec_scale_f32(masked_begin, S, scale);
- ggml_float sum = 0.0;
+ for (int64_t i = masked_begin; i < M; i++) {
+ S[i] = -INFINITY;
+ }
+
+ // softmax
+ // exclude known -INF S[..] values from max and loop
+ // don't forget to set their SM values to zero
{
+ float max = -INFINITY;
+ ggml_vec_max_f32(masked_begin, &max, S);
+
+ ggml_float sum = 0.0;
+ {
#ifdef GGML_SOFT_MAX_ACCELERATE
- max = -max;
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
- vvexpf(SM, SM, &Mup);
- ggml_vec_sum_f32(Mup, &sum, SM);
+ max = -max;
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+ vvexpf(SM, SM, &Mup);
+ ggml_vec_sum_f32(Mup, &sum, SM);
#else
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
- float * SR = S + i;
- float * SW = SM + i;
-
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
- if (SR[j] == -INFINITY) {
- SW[j] = 0.0f;
- } else {
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+ if (i >= masked_begin) {
+ break;
+ }
+ float * SR = S + i;
+ float * SW = SM + i;
+
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+ if (i + j >= masked_begin) {
+ break;
+ } else if (SR[j] == -INFINITY) {
+ SW[j] = 0.0f;
+ } else {
#ifndef GGML_FLASH_ATTN_EXP_FP16
- const float val = expf(SR[j] - max);
+ const float val = expf(SR[j] - max);
#else
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
- memcpy(&scvt[j], &s, sizeof(uint16_t));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
#endif
- sump[j] += (ggml_float)val;
- SW[j] = val;
+ sump[j] += (ggml_float)val;
+ SW[j] = val;
+ }
}
}
- }
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
- sum += sump[i];
- }
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+ sum += sump[i];
+ }
#endif
- }
-
- assert(sum > 0.0);
-
- sum = 1.0/sum;
- ggml_vec_scale_f32(M, SM, sum);
-
- }
-
- // step-by-step explanation
- {
- // forward-process shape grads from backward process
- // parallel_for iq2,iq3:
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
- // for iq1:
- // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
- // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
- // S0 = -Inf [D,1,1,1]
- // ~S1[i] = dot(kcur[:D,i], qcur)
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
- // ~S5[i] = dot(vcur[:,i], S4)
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
- // dst backward-/ grad[dst] = d
- //
- // output gradients with their dependencies:
- //
- // grad[kcur] = grad[S1].T @ qcur
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // grad[S4] = grad[S5] @ vcur
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
- // grad[qcur] = grad[S1] @ kcur
- // grad[vcur] = grad[S5].T @ S4
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
- //
- // in post-order:
- //
- // S1 = qcur @ kcur.T
- // S2 = S1 * scale
- // S3 = diag_mask_inf(S2, P)
- // S4 = softmax(S3)
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
- // grad[qcur] = grad[S1] @ kcur
- // grad[kcur] = grad[S1].T @ qcur
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
- //
- // using less variables (SM=S4):
- //
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
- // SM = softmax(S)
- // S = d[:D,iq1,iq2,iq3] @ vcur
- // dot_SM_gradSM = dot(SM, S)
- // S = SM * (S - dot(SM, S))
- // S = diag_mask_zero(S, P) * scale
- //
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
- }
+ }
- // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
- // S = d[:D,iq1,iq2,iq3] @ vcur
- // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
- ggml_vec_set_f32(M, S, 0);
- for (int64_t ic = 0; ic < D; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ assert(sum > 0.0);
- ggml_vec_mad_f32(M,
- S,
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
- }
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(masked_begin, SM, sum);
- // S = SM * (S - dot(SM, S))
- float dot_SM_gradSM = 0;
- ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
- ggml_vec_mul_f32 (M, S, S, SM);
-
- // S = diag_mask_zero(S, P) * scale
- if (masked) {
- // for (int64_t i = P + iq1 + 1; i < M; i++) {
- // S[i] = 0;
- // }
- for (int64_t i = P; i < M; i++) {
- if (i > P + iq1) {
- S[i] = 0;
- }
}
- }
- ggml_vec_scale_f32(M, S, scale);
-
- void * grad_q = (char *) dst->data;
- void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
- void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
-
- const size_t nbgq1 = nb0*neq0;
- const size_t nbgq2 = nb0*neq0*neq1;
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
- const size_t nbgk1 = nb0*nek0;
- const size_t nbgk2 = nb0*nek0*nek1;
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
- const size_t nbgv1 = nb0*nev0;
- const size_t nbgv2 = nb0*nev0*nev1;
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
- // S shape [M,1]
- // SM shape [M,1]
- // kcur shape [D,M]
- // qcur shape [D,1]
- // vcur shape [M,D]
- //
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
- //
- //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
- //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
- for (int64_t ic = 0; ic < M; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
- ggml_vec_mad_f32(D,
- (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
- (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
- S[ic]);
- }
+ // step-by-step explanation
+ {
+ // forward-process shape grads from backward process
+ // parallel_for ik2,ik3:
+ // for irep:
+ // iq2 = ik2 + irep*nek2
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
+ // for iq1:
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
+ // S0 = -Inf [D,1,1,1]
+ // ~S1[i] = dot(kcur[:D,i], qcur)
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
+ // ~S5[i] = dot(vcur[:,i], S4)
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+ // dst backward-/ grad[dst] = d
+ //
+ // output gradients with their dependencies:
+ //
+ // grad[kcur] = grad[S1].T @ qcur
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // grad[S4] = grad[S5] @ vcur
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
+ // grad[qcur] = grad[S1] @ kcur
+ // grad[vcur] = grad[S5].T @ S4
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+ //
+ // in post-order:
+ //
+ // S1 = qcur @ kcur.T
+ // S2 = S1 * scale
+ // S3 = diag_mask_inf(S2, P)
+ // S4 = softmax(S3)
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+ // grad[qcur] = grad[S1] @ kcur
+ // grad[kcur] = grad[S1].T @ qcur
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+ //
+ // using less variables (SM=S4):
+ //
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
+ // SM = softmax(S)
+ // S = d[:D,iq1,iq2,iq3] @ vcur
+ // dot_SM_gradSM = dot(SM, S)
+ // S = SM * (S - dot(SM, S))
+ // S = diag_mask_zero(S, P) * scale
+ //
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
+ }
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
- for (int64_t ic = 0; ic < M; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+ // for ic:
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+ // exclude known future zero S[..] values from operation
+ ggml_vec_set_f32(masked_begin, S, 0);
+ for (int64_t ic = 0; ic < D; ++ic) {
+ ggml_vec_mad_f32(masked_begin,
+ S,
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+ }
- // ggml_vec_set_f32(D,
- // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
- // 0);
- ggml_vec_mad_f32(D,
- (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
- (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
- S[ic]);
- }
+ // S = SM * (S - dot(SM, S))
+ float dot_SM_gradSM = 0;
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+ // S = diag_mask_zero(S, P) * scale
+ // already done by above ggml_vec_set_f32
+
+ // exclude known zero S[..] values from operation
+ ggml_vec_scale_f32(masked_begin, S, scale);
+
+ // S shape [M,1]
+ // SM shape [M,1]
+ // kcur shape [D,M]
+ // qcur shape [D,1]
+ // vcur shape [M,D]
+
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+ // for ic:
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+ // exclude known zero S[..] values from loop
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ ggml_vec_mad_f32(D,
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ S[ic]);
+ }
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
- // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
- // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
- for (int64_t ic = 0; ic < D; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+ // for ic:
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
+ // exclude known zero S[..] values from loop
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ ggml_vec_mad_f32(D,
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
+ S[ic]);
+ }
- // ggml_vec_set_f32(M,
- // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
- // 0);
- ggml_vec_mad_f32(M,
- (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
- SM,
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
+ // for ic:
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
+ // exclude known zero SM[..] values from mad
+ for (int64_t ic = 0; ic < D; ++ic) {
+ ggml_vec_mad_f32(masked_begin,
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+ SM,
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+ }
}
}
}
@@ -15896,7 +16381,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_GET_ROWS_BACK:
{
- ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_DIAG:
{
@@ -16069,7 +16554,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
////////////////////////////////////////////////////////////////////////////////
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static size_t hash_find(void * hash_table[], void * p) {
+ size_t h = hash(p);
+
+ // linear probing
+ size_t i = h;
+ while (hash_table[i] != NULL && hash_table[i] != p) {
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+ if (i == h) {
+ // visited all hash table entries -> not found
+ return GGML_GRAPH_HASHTABLE_SIZE;
+ }
+ }
+ return i;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+ size_t i = hash_find(hash_table, p);
+
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+
+ if (hash_table[i] == p) {
+ return true;
+ }
+
+ // insert
+ GGML_ASSERT(hash_table[i] == NULL);
+ hash_table[i] = p;
+ return false;
+}
+
+static bool hash_contains(void * hash_table[], void * p) {
+ size_t i = hash_find(hash_table, p);
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+}
+
+struct hash_map {
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+};
+
+static struct hash_map * new_hash_map(void) {
+ struct hash_map * result = malloc(sizeof(struct hash_map));
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
+ result->keys[i] = NULL;
+ result->vals[i] = NULL;
+ }
+ return result;
+}
+
+static void free_hash_map(struct hash_map * map) {
+ free(map);
+}
+
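A minimal sketch of how the open-addressing map is used (key and val are hypothetical tensor pointers):

    struct hash_map * map = new_hash_map();

    size_t i = hash_find(map->keys, key);       // slot of key, or first free slot
    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // table is not full
    if (map->keys[i] == NULL) {                 // not present yet -> insert
        map->keys[i] = key;
        map->vals[i] = val;
    }

    free_hash_map(map);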
+// gradient checkpointing
+
+static struct ggml_tensor * ggml_recompute_graph_node(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * graph,
+ struct hash_map * replacements,
+ struct ggml_tensor * node) {
+
+ if (node == NULL) {
+ return NULL;
+ }
+
+ if (node->is_param) {
+ return node;
+ }
+
+ if (!hash_contains(graph->visited_hash_table, node)) {
+ return node;
+ }
+
+ int count_children = 0;
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ if (node->src[k]) {
+ ++count_children;
+ }
+ }
+
+ if (count_children == 0) {
+ return node;
+ }
+
+ size_t i = hash_find(replacements->keys, node);
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+ if (replacements->keys[i] == node) {
+ return (struct ggml_tensor *) replacements->vals[i];
+ }
+
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+
+ // insert clone into replacements
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
+ replacements->keys[i] = node;
+ replacements->vals[i] = clone;
+
+ clone->op = node->op;
+ clone->grad = node->grad;
+ clone->is_param = node->is_param;
+ clone->extra = node->extra;
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
+ clone->nb[k] = node->nb[k];
+ }
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
+ }
+ if (node->view_src != NULL) {
+ clone->data = (node->view_src->data == NULL)
+ ? NULL // view_src not yet allocated
+ : (char *) node->view_src->data // view_src already allocated
+ + node->view_offs;
+ clone->view_src = node->view_src;
+ clone->view_offs = node->view_offs;
+ }
+
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
+
+ return clone;
+}
+
+void ggml_build_backward_gradient_checkpointing(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * checkpoints,
+ int n_checkpoints) {
+ *gb_tmp = *gf;
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+ if (n_checkpoints <= 0) {
+ *gb = *gb_tmp;
+ return;
+ }
+
+ struct hash_map * replacements = new_hash_map();
+
+ // insert checkpoints in replacements
+ for (int i = 0; i < n_checkpoints; ++i) {
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
+ replacements->keys[k] = checkpoints[i];
+ replacements->vals[k] = checkpoints[i];
+ }
+
+ *gb = *gf;
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+ // by recomputing them from checkpoints
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
+ struct ggml_tensor * node = gb_tmp->nodes[i];
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ // insert new tensors recomputing src, reusing already made replacements,
+ // remember replacements: map the new tensors from their corresponding gf nodes
+ // recurse for input tensors,
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
+ }
+ // insert rewritten backward node with replacements made into resulting backward graph gb
+ ggml_build_forward_expand(gb, node);
+ }
+
+ free_hash_map(replacements);
+}
+
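A sketch of the intended call pattern (gf, gb, gb_tmp, the loss tensor and the checkpoints array are assumed to already exist):

    ggml_build_forward_expand(gf, loss);
    ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp,
                                               checkpoints, n_checkpoints);
    // gb now recomputes non-checkpointed intermediate tensors from the
    // checkpoints instead of keeping every forward activation alive for
    // the backward pass.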
+// functions to change gradients, taking into account that input a might be an initial gradient with zero value
+
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ return b;
+ } else {
+ return ggml_add_impl(ctx, a, b, false);
+ }
+}
+
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+ } else {
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+ }
+}
+
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ return ggml_repeat(ctx, b, a);
+ } else {
+ return ggml_add1_impl(ctx, a, b, false);
+ }
+}
+
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ return ggml_neg(ctx, b);
+ } else {
+ return ggml_sub_impl(ctx, a, b, false);
+ }
+}
+
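These *_or_set helpers exist because the initial gradients are zero-valued tensors recorded in zero_table: the first accumulation into such a gradient can simply replace it instead of emitting an add. A small illustration:

    // assume hash_contains(zero_table, src0->grad) is true (still the initial zero tensor)
    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
    // src0->grad is now tensor->grad itself; no ggml_add node was created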
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
@@ -16077,34 +16773,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_DUP:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_ADD:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
- src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_ADD1:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
- src1->grad = ggml_add_impl(ctx,
+ src1->grad = ggml_add_or_set(ctx,
src1->grad,
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
- inplace);
+ zero_table);
}
} break;
case GGML_OP_ACC:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -16121,117 +16817,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
nb1, nb2, nb3, offset);
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_reshape(ctx,
ggml_cont(ctx, tensor_grad_view),
src1->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SUB:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
- src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_MUL:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx, src1, tensor->grad),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_mul(ctx, src0, tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_DIV:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_div(ctx, tensor->grad, src1),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_sub_impl(ctx,
+ ggml_sub_or_set(ctx,
src1->grad,
ggml_mul(ctx,
tensor->grad,
ggml_div(ctx, tensor, src1)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SQR:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_scale(ctx,
ggml_mul(ctx, src0, tensor->grad),
ggml_new_f32(ctx, 2.0f)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SQRT:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_scale(ctx,
ggml_div(ctx,
tensor->grad,
tensor),
ggml_new_f32(ctx, 0.5f)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_LOG:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_div(ctx,
tensor->grad,
src0),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SUM:
{
if (src0->grad) {
src0->grad =
- ggml_add1_impl(ctx,
+ ggml_add1_or_set(ctx,
src0->grad,
tensor->grad,
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SUM_ROWS:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_repeat(ctx,
tensor->grad,
src0->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_MEAN:
@@ -16243,20 +16939,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_repeat_back(ctx, tensor->grad, src0->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_REPEAT_BACK:
{
if (src0->grad) {
// TODO: test this
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_repeat(ctx, tensor->grad, src0->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_CONCAT:
@@ -16278,10 +16974,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_RMS_NORM_BACK:
@@ -16305,37 +17001,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
// ds1 = t.T.dot(dt)
- // tensor.shape [m,p]
- // src0.shape [n,m]
- // src1.shape [n,p]
+ // tensor.shape [m,p,qq,rr]
+ // src0.shape [n,m,q1,r1]
+ // src1.shape [n,p,qq,rr]
// necessary for llama
if (src0->grad) {
+ struct ggml_tensor * s1_tg =
+ ggml_out_prod(ctx, // [n,m,qq,rr]
+ src1, // [n,p,qq,rr]
+ tensor->grad); // [m,p,qq,rr]
+ const int64_t qq = s1_tg->ne[2];
+ const int64_t rr = s1_tg->ne[3];
+ const int64_t q1 = src0->ne[2];
+ const int64_t r1 = src0->ne[3];
+ const bool ne2_broadcasted = qq > q1;
+ const bool ne3_broadcasted = rr > r1;
+ if (ne2_broadcasted || ne3_broadcasted) {
+ // sum broadcast repetitions of s1_tg into shape of src0
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+ }
src0->grad =
- ggml_add_impl(ctx,
- src0->grad,
- ggml_out_prod(ctx, // [n,m]
- src1, // [n,p]
- tensor->grad), // [m,p]
- inplace);
+ ggml_add_or_set(ctx,
+ src0->grad, // [n,m,q1,r1]
+ s1_tg, // [n,m,q1,r1]
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
- src1->grad,
- // ggml_mul_mat(ctx, // [n,p]
- // ggml_cont(ctx, // [m,n]
- // ggml_transpose(ctx, src0)), // [m,n]
- // tensor->grad), // [m,p]
+ ggml_add_or_set(ctx,
+ src1->grad, // [n,p,qq,rr]
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
+ // ggml_cont(ctx, // [m,n,q1,r1]
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+ // tensor->grad), // [m,p,qq,rr]
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
// // avoid transpose of src0, rather transpose smaller tensor->grad
// // and then use ggml_out_prod
- ggml_out_prod(ctx, // [n,p]
- src0, // [n,m]
- ggml_transpose(ctx, // [p,m]
- tensor->grad)), // [m,p]
- inplace);
+ ggml_out_prod(ctx, // [n,p,qq,rr]
+ src0, // [n,m,q1,r1]
+ ggml_transpose(ctx, // [p,m,qq,rr]
+ tensor->grad)), // [m,p,qq,rr]
+ zero_table);
}
} break;
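The reworked MUL_MAT backward also covers broadcasting over dims 2 and 3: when src1 (and therefore tensor->grad) has more repetitions than src0 along those dims, the out_prod result s1_tg carries the broadcasted shape [n,m,qq,rr] and ggml_repeat_back sums those repetitions back into src0's shape [n,m,q1,r1]. A standalone sketch of that reduction, collapsed to a single broadcast dimension; the modulo mapping of repetitions onto source slices is used here only for illustration and is not taken from ggml's own indexing:

#include <stdio.h>

// Sum q_big/q_small broadcast repetitions of a [q_big, n] gradient back
// into a [q_small, n] gradient (row-major, q outermost).
static void sum_broadcast_back(const float * big, float * small,
                               int n, int q_small, int q_big) {
    // broadcast condition: q_big is an integer multiple of q_small
    for (int q = 0; q < q_small; ++q) {
        for (int i = 0; i < n; ++i) {
            small[q*n + i] = 0.0f;
        }
    }
    for (int q = 0; q < q_big; ++q) {
        const int qs = q % q_small; // repetition q maps onto source slice qs
        for (int i = 0; i < n; ++i) {
            small[qs*n + i] += big[q*n + i];
        }
    }
}

int main(void) {
    // gradient of shape [q_big=4, n=2], broadcast from a [q_small=2, n=2] tensor
    const float big[8] = {1,2, 3,4, 5,6, 7,8};
    float small[4];
    sum_broadcast_back(big, small, 2, 2, 4);
    printf("%g %g %g %g\n", small[0], small[1], small[2], small[3]); // 6 8 10 12
    return 0;
}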
case GGML_OP_OUT_PROD:
@@ -16347,17 +17055,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_scale_impl(ctx, tensor->grad, src1, false),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SET:
@@ -16384,23 +17092,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
}
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_acc_impl(ctx,
tensor->grad,
ggml_neg(ctx, tensor_grad_view),
nb1, nb2, nb3, offset, false),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_reshape(ctx,
ggml_cont(ctx, tensor_grad_view),
src1->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_CPY:
@@ -16411,7 +17119,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// tensor = src0 * 1 + src1 * 0
if (src0->grad) {
// dsrc0 = dtensor * 1
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
// dsrc1 = dtensor * 0 -> noop
@@ -16423,7 +17131,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
if (src0->grad) {
GGML_ASSERT(ggml_is_contiguous(src0->grad));
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_RESHAPE:
@@ -16431,9 +17139,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
- ggml_reshape(ctx, tensor->grad, src0->grad),
- inplace);
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_reshape(ctx,
+ ggml_is_contiguous(tensor->grad)
+ ? tensor->grad
+ : ggml_cont(ctx, tensor->grad),
+ src0->grad),
+ zero_table);
}
} break;
case GGML_OP_VIEW:
@@ -16462,7 +17174,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
nb3 = (nb3 / n0) * ng;
}
- src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
}
} break;
case GGML_OP_PERMUTE:
@@ -16480,14 +17192,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
axes_backward[axis2] = 2;
axes_backward[axis3] = 3;
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_permute(ctx,
tensor->grad,
axes_backward[0],
axes_backward[1],
axes_backward[2],
axes_backward[3]),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_TRANSPOSE:
@@ -16495,9 +17207,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_transpose(ctx, tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_GET_ROWS:
@@ -16505,9 +17217,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama (only for tokenizer)
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
+                             // the last ggml_get_rows_back argument src0->grad is only
+                             // needed to set up the correct output shape
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
- inplace);
+ zero_table);
}
if (src1->grad) {
// noop
@@ -16527,9 +17241,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_DIAG_MASK_ZERO:
@@ -16538,9 +17252,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SOFT_MAX:
@@ -16548,9 +17262,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_soft_max_back(ctx, tensor->grad, tensor),
- inplace);
+ zero_table);
}
} break;
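ggml_soft_max_back above is expected to apply the standard softmax Jacobian-vector product: with y = softmax(x) and upstream gradient dy, dx_i = y_i*(dy_i - sum_j y_j*dy_j). A per-row sketch of that formula (illustrative only, not the ggml kernel):

#include <stdio.h>

// Row-wise softmax JVP: dx = y * (dy - dot(y, dy)).
static void soft_max_back_row(const float * y, const float * dy, float * dx, int n) {
    float dot = 0.0f;
    for (int i = 0; i < n; ++i) {
        dot += y[i]*dy[i];
    }
    for (int i = 0; i < n; ++i) {
        dx[i] = y[i]*(dy[i] - dot);
    }
}

int main(void) {
    const float y [3] = {0.2f, 0.3f, 0.5f}; // softmax output (sums to 1)
    const float dy[3] = {1.0f, 0.0f, 0.0f}; // upstream gradient
    float dx[3];
    soft_max_back_row(y, dy, dx, 3);
    printf("%g %g %g\n", dx[0], dx[1], dx[2]); // 0.16 -0.06 -0.1
    return 0;
}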
@@ -16575,7 +17289,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
@@ -16587,7 +17301,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
freq_scale,
xpos_base,
xpos_down),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_ROPE_BACK:
@@ -16606,7 +17320,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rope_impl(ctx,
tensor->grad,
@@ -16619,7 +17333,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
xpos_base,
xpos_down,
false),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_ALIBI:
@@ -16670,145 +17384,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
masked);
}
- if (src0->grad) {
- struct ggml_tensor * grad_q = NULL;
- const size_t nb0 = flash_grad->nb[0];
- const size_t offset = 0;
- switch(src0->n_dims) {
- case 2:
- {
- grad_q = ggml_view_2d(ctx,
- flash_grad,
- src0->ne[0],
- src0->ne[1],
- nb0*src0->ne[0],
- offset);
- } break;
- case 3:
- {
- grad_q = ggml_view_3d(ctx,
- flash_grad,
- src0->ne[0],
- src0->ne[1],
- src0->ne[2],
- nb0*src0->ne[0],
- nb0*src0->ne[0]*src0->ne[1],
- offset);
- } break;
- case 4:
- {
- grad_q = ggml_view_4d(ctx,
- flash_grad,
- src0->ne[0],
- src0->ne[1],
- src0->ne[2],
- src0->ne[3],
- nb0*src0->ne[0],
- nb0*src0->ne[0]*src0->ne[1],
- nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
- offset);
- } break;
- }
+ struct ggml_tensor * src2 = tensor->src[2];
+ const int64_t elem_q = ggml_nelements(src0);
+ const int64_t elem_k = ggml_nelements(src1);
+ const int64_t elem_v = ggml_nelements(src2);
- src0->grad = ggml_add_impl(ctx,
+ enum ggml_type result_type = flash_grad->type;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+ if (src0->grad) {
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
grad_q,
- inplace);
+ zero_table);
}
-
if (src1->grad) {
- struct ggml_tensor * grad_k = NULL;
- const size_t nb0 = flash_grad->nb[0];
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
- switch(src1->n_dims) {
- case 2:
- {
- grad_k = ggml_view_2d(ctx,
- flash_grad,
- src1->ne[0],
- src1->ne[1],
- nb0*src1->ne[0],
- offset);
- } break;
- case 3:
- {
- grad_k = ggml_view_3d(ctx,
- flash_grad,
- src1->ne[0],
- src1->ne[1],
- src1->ne[2],
- nb0*src1->ne[0],
- nb0*src1->ne[0]*src1->ne[1],
- offset);
- } break;
- case 4:
- {
- grad_k = ggml_view_4d(ctx,
- flash_grad,
- src1->ne[0],
- src1->ne[1],
- src1->ne[2],
- src1->ne[3],
- nb0*src1->ne[0],
- nb0*src1->ne[0]*src1->ne[1],
- nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
- offset);
- } break;
- }
-
- src1->grad = ggml_add_impl(ctx,
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
+ src1->grad = ggml_add_or_set(ctx,
src1->grad,
grad_k,
- inplace);
+ zero_table);
}
-
- struct ggml_tensor * opt0 = tensor->src[2];
-
- if (opt0->grad) {
- struct ggml_tensor * grad_v = NULL;
- const size_t nb0 = flash_grad->nb[0];
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
- + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
- switch(opt0->n_dims) {
- case 2:
- {
- grad_v = ggml_view_2d(ctx,
- flash_grad,
- opt0->ne[0],
- opt0->ne[1],
- nb0*opt0->ne[0],
- offset);
- } break;
- case 3:
- {
- grad_v = ggml_view_3d(ctx,
- flash_grad,
- opt0->ne[0],
- opt0->ne[1],
- opt0->ne[2],
- nb0*opt0->ne[0],
- nb0*opt0->ne[0]*opt0->ne[1],
- offset);
- } break;
- case 4:
- {
- grad_v = ggml_view_4d(ctx,
- flash_grad,
- opt0->ne[0],
- opt0->ne[1],
- opt0->ne[2],
- opt0->ne[3],
- nb0*opt0->ne[0],
- nb0*opt0->ne[0]*opt0->ne[1],
- nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
- offset);
- } break;
- }
-
- opt0->grad = ggml_add_impl(ctx,
- opt0->grad,
+ if (src2->grad) {
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
+ src2->grad = ggml_add_or_set(ctx,
+ src2->grad,
grad_v,
- inplace);
+ zero_table);
}
} break;
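The rewritten FLASH_ATTN backward assumes flash_grad packs the q, k and v gradients back-to-back as flat sections, each starting on a GGML_MEM_ALIGN boundary, so plain 1-d views plus reshapes replace the old per-rank view switches. A standalone sketch of the offset arithmetic with hypothetical element counts (pad() mirrors the round-up that GGML_PAD performs; align must be a power of two here):

#include <stddef.h>
#include <stdio.h>

// Round x up to the next multiple of align (align must be a power of two).
static size_t pad(size_t x, size_t align) {
    return (x + align - 1) & ~(align - 1);
}

int main(void) {
    const size_t align = 16; // stand-in for GGML_MEM_ALIGN
    const size_t tsize = 4;  // sizeof(float), assuming an F32 result type

    // hypothetical element counts for q, k and v
    const size_t elem_q = 10, elem_k = 6, elem_v = 6;

    const size_t offs_q = 0;
    const size_t offs_k = offs_q + pad(elem_q*tsize, align);
    const size_t offs_v = offs_k + pad(elem_k*tsize, align);

    // each section starts on a 16-byte boundary: 0, 48, 80
    printf("offs_q=%zu offs_k=%zu offs_v=%zu (total %zu bytes)\n",
           offs_q, offs_k, offs_v, offs_v + pad(elem_v*tsize, align));
    return 0;
}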
case GGML_OP_FLASH_FF:
@@ -16828,12 +17439,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx,
ggml_sgn(ctx, src0),
tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_UNARY_OP_SGN:
@@ -16845,7 +17456,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_UNARY_OP_NEG:
{
if (src0->grad) {
- src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_UNARY_OP_STEP:
@@ -16865,12 +17476,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_UNARY_OP_RELU:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx,
ggml_step(ctx, src0),
tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_UNARY_OP_GELU:
@@ -16885,10 +17496,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_silu_back(ctx, src0, tensor->grad),
- inplace);
+ zero_table);
}
} break;
default:
@@ -16911,13 +17522,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_CROSS_ENTROPY_LOSS:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_cross_entropy_loss_back(ctx,
src0,
src1,
tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -16933,34 +17544,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
GGML_ASSERT(false);
} break;
}
-}
-
-static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
-
-static size_t hash(void * p) {
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static bool hash_insert(void * hash_table[], void * p) {
- size_t h = hash(p);
- // linear probing
- size_t i = h;
- while (hash_table[i] != NULL && hash_table[i] != p) {
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
- if (i == h) {
- // hash table is full
- GGML_ASSERT(false);
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (tensor->src[i] && tensor->src[i]->grad) {
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
}
}
-
- if (hash_table[i] == p) {
- return true;
- }
-
- // insert
- hash_table[i] = p;
- return false;
}
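Throughout this function the former ggml_*_impl(..., inplace) calls become *_or_set variants that take zero_table: gradients start out as zero-valued tensors recorded in that table, so the first contribution can simply replace the zero tensor instead of stacking an add node on top of it. A rough sketch of the dispatch, not the patch's actual ggml_add_or_set; hash_contains is a hypothetical lookup over the same table that hash_insert fills:

#include <stdbool.h>
#include "ggml.h" // for struct ggml_context, struct ggml_tensor, ggml_add

// Sketch only: take the new contribution b outright if a is still the
// initial zero gradient, otherwise accumulate with a regular add node.
static struct ggml_tensor * add_or_set_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,   // gradient accumulated so far
        struct ggml_tensor  * b,   // new gradient contribution
        void               ** zero_table,
        bool (*hash_contains)(void ** table, void * p)) {
    if (hash_contains(zero_table, a)) {
        return b;                  // first contribution: no add needed
    }
    return ggml_add(ctx, a, b);    // subsequent contributions accumulate
}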
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
@@ -16978,8 +17567,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
}
for (int i = 0; i < GGML_MAX_SRC; ++i) {
- if (node->src[i]) {
- ggml_visit_parents(cgraph, node->src[i]);
+ const int k =
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
+ /* unknown order, just fall back to using i*/ i;
+ if (node->src[k]) {
+ ggml_visit_parents(cgraph, node->src[k]);
}
}
@@ -17038,6 +17631,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
/*.hash_table =*/ { NULL },
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
@@ -17063,12 +17657,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
}
}
+ // remember original gradients which start with zero values
+ void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
+ memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+ for (int i = 0; i < gf->n_nodes; i++) {
+ if (gf->grads[i]) {
+ hash_insert(zero_table, gf->grads[i]);
+ }
+ }
+
for (int i = gf->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = gf->nodes[i];
- // because we detached the grad nodes from the original graph, we can afford inplace operations
+        // ggml_compute_backward no longer creates inplace operations to accumulate gradients;
+        // the allocator can later turn these adds into inplace operations automatically
if (node->grad) {
- ggml_compute_backward(ctx, node, keep);
+ ggml_compute_backward(ctx, node, zero_table);
}
}
@@ -17080,6 +17684,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
ggml_build_forward_expand(gb, node->grad);
}
}
+
+ free(zero_table);
}
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
@@ -17099,6 +17705,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
/*.hash_table =*/ { NULL },
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
@@ -17489,7 +18096,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
- case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;
@@ -17533,6 +18139,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
work_size = MAX(work_size, cur);
} break;
+ case GGML_OP_OUT_PROD:
+ {
+ n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
case GGML_OP_SCALE:
{
n_tasks = 1;
@@ -18624,7 +19242,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
}
static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
- int i = 0;
+ int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_nelements(ps[p]) ;
// TODO: add function to get all elements at once
@@ -18634,6 +19252,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
}
}
+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+ int64_t i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int64_t ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to get all elements at once
+ for (int64_t j = 0; j < ne; ++j) {
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+ }
+ }
+}
+
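ggml_opt_acc_grad above sums the parameter gradients of each accumulation step, pre-scaled by accum_norm = 1/n_gradient_accumulation; the optimizers below average the per-step losses the same way, so n_accum micro-steps behave like one larger batch. A standalone sketch of that pattern (grad_of_microbatch is a made-up stand-in for one ggml_graph_compute pass):

#include <stdio.h>

// Fake per-step loss and gradient so the accumulation loop is runnable.
static float grad_of_microbatch(int step, float * g, int nx) {
    for (int i = 0; i < nx; ++i) {
        g[i] = (float)(step + 1);
    }
    return (float)(step + 1);
}

int main(void) {
    enum { NX = 4 };
    const int   n_accum    = 4;
    const float accum_norm = 1.0f/(float)n_accum;

    float g_accum[NX] = {0};
    float g_step [NX];
    float fx = 0.0f;

    for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
        const float loss = grad_of_microbatch(accum_step, g_step, NX);
        for (int i = 0; i < NX; ++i) {
            g_accum[i] += g_step[i]*accum_norm; // what ggml_opt_acc_grad does
        }
        fx += loss;
    }
    fx *= accum_norm; // average loss over the accumulation steps

    printf("mean loss %g, mean grad[0] %g\n", fx, g_accum[0]); // 2.5, 2.5
    return 0;
}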
//
// ADAM
//
@@ -18682,26 +19311,43 @@ static enum ggml_opt_result ggml_opt_adam(
const float eps = params.adam.eps;
const float gclip = params.adam.gclip;
const int decay_min_ndim = params.adam.decay_min_ndim;
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+ float * g = opt->adam.g->data; // gradients
float * m = opt->adam.m->data; // first moment
float * v = opt->adam.v->data; // second moment
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
- if (callback) {
- callback(callback_data, &sched);
- }
-
- // compute the function value
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
- ggml_graph_compute(gb, &cplan);
- opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+ bool cancel = false;
+
+ // compute the function value
+ float fx = 0;
+ ggml_set_zero(opt->adam.g);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ if (cancel) {
+ return GGML_OPT_DID_NOT_CONVERGE;
+ }
+ fx *= accum_norm;
+
+ opt->adam.fx_prev = fx;
opt->adam.fx_best = opt->adam.fx_prev;
if (pf) {
pf[opt->iter % params.past] = opt->adam.fx_prev;
@@ -18724,6 +19370,9 @@ static enum ggml_opt_result ggml_opt_adam(
// run the optimizer
for (int t = 0; t < params.adam.n_iter; ++t) {
+ if (cancel) {
+ break;
+ }
opt->iter = iter0 + t + 1;
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
@@ -18746,12 +19395,8 @@ static enum ggml_opt_result ggml_opt_adam(
if (gclip > 0.0f) {
// gradient clipping
ggml_float sum = 0.0;
- for (int p = 0; p < np; ++p) {
- const int64_t ne = ggml_nelements(ps[p]);
- for (int64_t j = 0; j < ne; ++j) {
- float g = ggml_get_f32_1d(ps[p]->grad, j);
- sum += (ggml_float)(g*g);
- }
+ for (int64_t i = 0; i < nx; ++i) {
+ sum += (ggml_float)(g[i]*g[i]);
}
ggml_float norm = sqrt(sum);
if (norm > (ggml_float) gclip) {
@@ -18765,10 +19410,10 @@ static enum ggml_opt_result ggml_opt_adam(
const int64_t ne = ggml_nelements(ps[p]);
const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
for (int64_t j = 0; j < ne; ++j) {
- float x = ggml_get_f32_1d(ps[p], j);
- float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
- m[i] = m[i]*beta1 + g*(1.0f - beta1);
- v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
+ float x = ggml_get_f32_1d(ps[p], j);
+ float g_ = g[i]*gnorm;
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
float mh = m[i]*beta1h;
float vh = v[i]*beta2h;
vh = sqrtf(vh) + eps;
@@ -18779,16 +19424,26 @@ static enum ggml_opt_result ggml_opt_adam(
}
}
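For context, the per-parameter update in the hunk above follows the familiar Adam recipe, with beta1h/beta2h folding the learning rate, schedule and bias correction into the moment scaling, and p_decay applying decoupled weight decay. A textbook single-parameter AdamW-style step, not a verbatim copy of the code above:

#include <math.h>
#include <stdio.h>

// One textbook Adam step with decoupled weight decay for a scalar parameter.
static float adam_step(float x, float g, float * m, float * v,
                       float alpha, float beta1, float beta2,
                       float eps, float decay, int t) {
    *m = beta1*(*m) + (1.0f - beta1)*g;                  // first moment
    *v = beta2*(*v) + (1.0f - beta2)*g*g;                // second moment
    const float mh = *m/(1.0f - powf(beta1, (float)t));  // bias-corrected moments
    const float vh = *v/(1.0f - powf(beta2, (float)t));
    x *= 1.0f - alpha*decay;                             // decoupled weight decay
    return x - alpha*mh/(sqrtf(vh) + eps);
}

int main(void) {
    float x = 1.0f, m = 0.0f, v = 0.0f;
    for (int t = 1; t <= 3; ++t) {
        const float g = 2.0f*x; // gradient of f(x) = x^2
        x = adam_step(x, g, &m, &v, 1e-2f, 0.9f, 0.999f, 1e-8f, 0.0f, t);
        printf("t=%d x=%f\n", t, x);
    }
    return 0;
}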
- if (callback) {
- callback(callback_data, &sched);
+ fx = 0;
+ ggml_set_zero(opt->adam.g);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
}
+ if (cancel) {
+ break;
+ }
+ fx *= accum_norm;
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
- ggml_graph_compute(gb, &cplan);
-
- const float fx = ggml_get_f32_1d(f, 0);
opt->loss_after = fx;
@@ -18868,11 +19523,11 @@ static enum ggml_opt_result linesearch_backtracking(
float * step,
const float * xp,
struct ggml_tensor * f,
- struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
struct ggml_cplan * cplan,
const int np,
struct ggml_tensor * ps[],
+ bool * cancel,
ggml_opt_callback callback,
void * callback_data) {
int count = 0;
@@ -18886,6 +19541,9 @@ static enum ggml_opt_result linesearch_backtracking(
const float dec = 0.5f;
const float inc = 2.1f;
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
if (*step <= 0.f) {
return GGML_LINESEARCH_INVALID_PARAMETERS;
}
@@ -18902,13 +19560,7 @@ static enum ggml_opt_result linesearch_backtracking(
finit = *fx;
dgtest = params->lbfgs.ftol*dginit;
- while (true) {
- if (callback) {
- // LBFG-S does not support learning rate -> ignore learning schedule
- float sched = 0;
- callback(callback_data, &sched);
- }
-
+ while (!*cancel) {
ggml_vec_cpy_f32(nx, x, xp);
ggml_vec_mad_f32(nx, x, d, *step);
@@ -18916,14 +19568,28 @@ static enum ggml_opt_result linesearch_backtracking(
{
ggml_opt_set_params(np, ps, x);
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
- ggml_graph_compute(gb, cplan);
-
- ggml_opt_get_grad(np, ps, g);
+ *fx = 0;
+ memset(g, 0, sizeof(float)*nx);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+                    // L-BFGS does not support learning rate -> ignore learning schedule
+ float sched = 0;
+ callback(callback_data, accum_step, &sched, cancel);
+ if (*cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ *fx += ggml_get_f32_1d(f, 0);
+ }
+ if (*cancel) {
+ break;
+ }
+ *fx *= accum_norm;
- *fx = ggml_get_f32_1d(f, 0);
}
++count;
@@ -19024,6 +19690,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
float fx = 0.0f; // cost function value
float xnorm = 0.0f; // ||x||
float gnorm = 0.0f; // ||g||
@@ -19037,24 +19706,33 @@ static enum ggml_opt_result ggml_opt_lbfgs(
float * lm_s = opt->lbfgs.lms->data;
float * lm_y = opt->lbfgs.lmy->data;
- if (callback) {
- // LBFG-S does not support learning rate -> ignore learning schedule
- float sched = 0;
- callback(callback_data, &sched);
- }
+ bool cancel = false;
// evaluate the function value and its gradient
{
ggml_opt_set_params(np, ps, x);
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
- ggml_graph_compute(gb, &cplan);
-
- ggml_opt_get_grad(np, ps, g);
-
- fx = ggml_get_f32_1d(f, 0);
+ fx = 0;
+ memset(g, 0, sizeof(float)*nx);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+            // L-BFGS does not support learning rate -> ignore learning schedule
+ float sched = 0;
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ if (cancel) {
+ return GGML_OPT_DID_NOT_CONVERGE;
+ }
+ fx *= accum_norm;
opt->loss_before = fx;
opt->loss_after = fx;
@@ -19112,7 +19790,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_vec_cpy_f32(nx, xp, x);
ggml_vec_cpy_f32(nx, gp, g);
- ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
+ ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
+ if (!cancel) {
+ break;
+ }
if (ls < 0) {
// linesearch failed - go back to the previous point and return
@@ -19241,6 +19922,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
.print_forward_graph = true,
.print_backward_graph = true,
+ .n_gradient_accumulation = 1,
+
.adam = {
.n_iter = 10000,
.sched = 1.000f,
@@ -19269,6 +19952,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
.print_forward_graph = true,
.print_backward_graph = true,
+ .n_gradient_accumulation = 1,
+
.lbfgs = {
.m = 6,
.n_iter = 100,
@@ -19299,13 +19984,32 @@ GGML_API void ggml_opt_init(
opt->iter = 0;
opt->nx = nx;
opt->just_initialized = true;
+ if (opt->ctx == NULL) {
+ struct ggml_init_params ctx_opt_params;
+ if (opt->params.type == GGML_OPT_ADAM) {
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
+ if (opt->params.past > 0) {
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+ }
+ } else if (opt->params.type == GGML_OPT_LBFGS) {
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
+ if (opt->params.past > 0) {
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+ }
+ }
+ ctx_opt_params.mem_buffer = NULL;
+ ctx_opt_params.no_alloc = false;
+
+ opt->ctx = ggml_init(ctx_opt_params);
+ }
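opt->ctx is sized just large enough for the optimizer state: for Adam, three F32 tensors of nx elements (g, m, v) plus per-tensor overhead and alignment slack, with one more tensor of params.past elements when past function values are tracked. A rough budgeting sketch; overhead and align are placeholders for ggml_tensor_overhead() and GGML_MEM_ALIGN:

#include <stddef.h>
#include <stdio.h>

// Rough upper bound for the Adam optimizer context allocated above.
static size_t adam_ctx_size(size_t nx, size_t past, size_t overhead, size_t align) {
    size_t size = 3*(align + overhead + sizeof(float)*nx); // g, m, v
    if (past > 0) {
        size += align + overhead + sizeof(float)*past;     // pf
    }
    return size;
}

int main(void) {
    // e.g. one million parameters, no past values, hypothetical overhead/alignment
    printf("%zu bytes\n", adam_ctx_size(1u << 20, 0, 256, 16));
    return 0;
}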
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
- opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->adam.pf = params.past > 0
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL;
ggml_set_zero(opt->adam.m);
ggml_set_zero(opt->adam.v);
@@ -19315,18 +20019,18 @@ GGML_API void ggml_opt_init(
} break;
case GGML_OPT_LBFGS:
{
- opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.pf = params.past > 0
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL;
- opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
- opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
- opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
- opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
ggml_set_zero(opt->lbfgs.x);
ggml_set_zero(opt->lbfgs.xp);
ggml_set_zero(opt->lbfgs.g);