diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 77 |
1 files changed, 61 insertions, 16 deletions
@@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace( static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, bool inplace) { + GGML_ASSERT(ggml_is_contiguous(a)); + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + } + bool is_node = false; if (a->grad) { @@ -4835,9 +4845,13 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + result->op = GGML_OP_SOFT_MAX; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = mask; return result; } @@ -4845,13 +4859,21 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true); +} + +struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale) { + return ggml_soft_max_impl(ctx, a, mask, scale, false); } // ggml_soft_max_back @@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero( static void ggml_compute_forward_soft_max_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(ggml_is_contiguous(dst)); + assert(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + // TODO: handle transposed/permuted matrices const int ith = params->ith; const int nth = params->nth; + const int64_t ne11 = src1 ? src1->ne[1] : 1; + const int nc = src0->ne[0]; const int nr = ggml_nrows(src0); @@ -10575,29 +10602,40 @@ static void ggml_compute_forward_soft_max_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + for (int i1 = ir0; i1 < ir1; i1++) { - float *sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + + // broadcast the mask across rows + float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (mp) { + ggml_vec_acc_f32(nc, wp, mp); + } #ifndef NDEBUG for (int i = 0; i < nc; ++i) { //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(sp[i])); + assert(!isnan(wp[i])); } #endif float max = -INFINITY; - ggml_vec_max_f32(nc, &max, sp); + ggml_vec_max_f32(nc, &max, wp); ggml_float sum = 0.0; uint16_t scvt; for (int i = 0; i < nc; i++) { - if (sp[i] == -INFINITY) { + if (wp[i] == -INFINITY) { dp[i] = 0.0f; } else { - // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); - ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max); + // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); sum += (ggml_float)val; @@ -10622,11 +10660,12 @@ static void ggml_compute_forward_soft_max_f32( static void ggml_compute_forward_soft_max( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - struct ggml_tensor * dst) { + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_soft_max_f32(params, src0, dst); + ggml_compute_forward_soft_max_f32(params, src0, src1, dst); } break; default: { @@ -13863,7 +13902,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SOFT_MAX: { - ggml_compute_forward_soft_max(params, tensor->src[0], tensor); + ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_SOFT_MAX_BACK: { @@ -15899,6 +15938,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } } break; + case GGML_OP_SOFT_MAX: + { + n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0])); + + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } break; case GGML_OP_CONV_TRANSPOSE_1D: { GGML_ASSERT(node->src[0]->ne[3] == 1); |