path: root/ggml.c
Diffstat (limited to 'ggml.c')
-rw-r--r--  ggml.c  77
1 file changed, 61 insertions(+), 16 deletions(-)
diff --git a/ggml.c b/ggml.c
index c522a101..e2687ef4 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
bool inplace) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ if (mask) {
+ GGML_ASSERT(ggml_is_contiguous(mask));
+ GGML_ASSERT(mask->ne[2] == 1);
+ GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+ }
+
bool is_node = false;
if (a->grad) {
@@ -4835,9 +4845,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ float params[] = { scale };
+ ggml_set_op_params(result, params, sizeof(params));
+
result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
+ result->src[1] = mask;
return result;
}
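
Note: the new scale factor is not a tensor operand; it is serialized into the result tensor's op_params blob at graph-build time via ggml_set_op_params and recovered with a memcpy in the forward kernel further down, while the optional mask travels as an ordinary source in src[1]. A minimal, self-contained sketch of that op_params round-trip, using a plain byte buffer as a hypothetical stand-in for the tensor field (the buffer name and size are illustrative only, not ggml API):

    #include <stdio.h>
    #include <string.h>

    // hypothetical stand-in for tensor->op_params: a small opaque byte buffer
    static char op_params[16];

    int main(void) {
        // graph build: pack the scale as raw bytes (what ggml_set_op_params does)
        float params[] = { 0.125f };
        memcpy(op_params, params, sizeof(params));

        // forward pass: read it back, with 1.0f as the default
        float scale = 1.0f;
        memcpy(&scale, (float *) op_params + 0, sizeof(float));

        printf("scale = %f\n", scale); // prints 0.125000
        return 0;
    }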
@@ -4845,13 +4859,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, false);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
}
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, true);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale) {
+ return ggml_soft_max_impl(ctx, a, mask, scale, false);
}
// ggml_soft_max_back
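
The new public entry point ggml_soft_max_ext(ctx, a, mask, scale) computes softmax(a*scale + mask) row by row, with the mask rows broadcast over the rows of a; the existing ggml_soft_max and ggml_soft_max_inplace become thin wrappers that pass a NULL mask and a scale of 1.0f. The typical caller is attention, where this fuses the previous separate scale, mask-add and softmax nodes into one op. A hedged end-to-end sketch, assuming the ggml.h that ships with this revision; the tensor sizes, the head dimension of 64 and the 16 MiB context are arbitrary choices for illustration:

    #include <math.h>
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // 4 attention scores per row, 2 rows; one mask row broadcast over both
        struct ggml_tensor * kq      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
        struct ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1);

        ggml_set_f32(kq,      1.0f);
        ggml_set_f32(kq_mask, 0.0f);
        ((float *) kq_mask->data)[3] = -INFINITY; // mask out the last position

        // fused: softmax(kq * 1/sqrt(d_head) + kq_mask)
        struct ggml_tensor * probs = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(64.0f));

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, probs);
        ggml_graph_compute_with_ctx(ctx, gf, 4);

        for (int i = 0; i < 4; ++i) {
            printf("p[%d] = %f\n", i, ggml_get_data_f32(probs)[i]); // last entry is 0
        }

        ggml_free(ctx);
        return 0;
    }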
@@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
- struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(ggml_is_contiguous(dst));
+ assert(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
+ float scale = 1.0f;
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
const int nth = params->nth;
+ const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
@@ -10575,29 +10602,40 @@ static void ggml_compute_forward_soft_max_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
for (int i1 = ir0; i1 < ir1; i1++) {
- float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+ // broadcast the mask across rows
+ float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+ ggml_vec_cpy_f32 (nc, wp, sp);
+ ggml_vec_scale_f32(nc, wp, scale);
+ if (mp) {
+ ggml_vec_acc_f32(nc, wp, mp);
+ }
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(sp[i]));
+ assert(!isnan(wp[i]));
}
#endif
float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, sp);
+ ggml_vec_max_f32(nc, &max, wp);
ggml_float sum = 0.0;
uint16_t scvt;
for (int i = 0; i < nc; i++) {
- if (sp[i] == -INFINITY) {
+ if (wp[i] == -INFINITY) {
dp[i] = 0.0f;
} else {
- // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
- ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+ // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
sum += (ggml_float)val;
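
In the kernel, each thread first copies its row into a private scratch slice wp inside params->wdata (offset by an extra CACHE_LINE_SIZE_F32 floats per thread to avoid false sharing), applies the scale, and adds the mask row selected with i1 % ne11, so a mask with fewer rows than src0 is broadcast across them; the rest is the usual numerically stable softmax, where entries that are still -INFINITY after masking end up with probability 0. A hedged single-row reference in plain C, with expf() standing in for ggml's fp16 exp lookup table and a fixed-size local array in place of wdata:

    #include <math.h>

    // reference for one row of the fused kernel: dp = softmax(scale*sp + mp)
    // assumes nc <= 1024 and at least one unmasked entry (the kernel asserts sum > 0)
    static void soft_max_row_ref(int nc, float * dp, const float * sp, const float * mp, float scale) {
        float wp[1024]; // stand-in for the per-thread slice of params->wdata

        for (int i = 0; i < nc; ++i) {
            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
        }

        float max = -INFINITY;
        for (int i = 0; i < nc; ++i) {
            if (wp[i] > max) {
                max = wp[i];
            }
        }

        double sum = 0.0;
        for (int i = 0; i < nc; ++i) {
            // fully masked positions contribute probability 0
            dp[i] = (wp[i] == -INFINITY) ? 0.0f : expf(wp[i] - max);
            sum += dp[i];
        }

        for (int i = 0; i < nc; ++i) {
            dp[i] /= (float) sum;
        }
    }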
@@ -10622,11 +10660,12 @@ static void ggml_compute_forward_soft_max_f32(
static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
- struct ggml_tensor * dst) {
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
- ggml_compute_forward_soft_max_f32(params, src0, dst);
+ ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
} break;
default:
{
@@ -13863,7 +13902,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX:
{
- ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
@@ -15899,6 +15938,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
+ case GGML_OP_SOFT_MAX:
+ {
+ n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+
+ cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+ } break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
GGML_ASSERT(node->src[0]->ne[3] == 1);
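
The planner hunk above reserves the scratch memory the new kernel needs: one row of ne[0] F32 values per participating thread, with the thread count capped at min(4, n_threads, number of rows). A hedged worked example with made-up sizes (512 columns, 4096 rows, 8 threads), showing the same arithmetic:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const long ne0       = 512;  // node->ne[0]
        const long nrows     = 4096; // ggml_nrows(node->src[0])
        const int  n_threads = 8;

        const long   n_tasks = MIN(MIN(4, n_threads), nrows); // -> 4
        const size_t cur     = sizeof(float)*ne0*n_tasks;     // -> 8192 bytes

        printf("n_tasks = %ld, soft_max work buffer = %zu bytes\n", n_tasks, cur);
        return 0;
    }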