summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c116
1 files changed, 78 insertions, 38 deletions
diff --git a/ggml.c b/ggml.c
index 264cfd70..e94024c6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5096,16 +5096,28 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
+ struct ggml_tensor * pos,
float scale,
+ float max_bias,
bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a));
+
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
- GGML_ASSERT(mask->ne[2] == 1);
- GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
}
+ if (pos) {
+ GGML_ASSERT(ggml_is_vector(pos));
+ GGML_ASSERT(pos->type == GGML_TYPE_F32);
+ GGML_ASSERT(pos->ne[0] == a->ne[0]);
+ }
+
+ if (max_bias > 0.0f) {
+ GGML_ASSERT(pos);
+ }
+
bool is_node = false;
if (a->grad) {
@@ -5114,13 +5126,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- float params[] = { scale };
+ float params[] = { scale, max_bias };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = mask;
+ result->src[2] = pos;
return result;
}
@@ -5128,21 +5141,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
+ return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
}
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+ return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
}
struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
- float scale) {
- return ggml_soft_max_impl(ctx, a, mask, scale, false);
+ struct ggml_tensor * pos,
+ float scale,
+ float max_bias) {
+ return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
}
// ggml_soft_max_back
@@ -11495,6 +11510,7 @@ static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
+ const struct ggml_tensor * src2,
struct ggml_tensor * dst) {
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
@@ -11503,16 +11519,29 @@ static void ggml_compute_forward_soft_max_f32(
return;
}
- float scale = 1.0f;
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
const int nth = params->nth;
+ GGML_TENSOR_UNARY_OP_LOCALS
+
const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ // TODO: is this supposed to be ceil instead of floor?
+ // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+ const uint32_t n_head_kv = ne02;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
@@ -11525,6 +11554,9 @@ static void ggml_compute_forward_soft_max_f32(
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
+ float * pos = src2 ? (float *) src2->data : src0->data;
+
for (int i1 = ir0; i1 < ir1; i1++) {
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
@@ -11538,6 +11570,16 @@ static void ggml_compute_forward_soft_max_f32(
ggml_vec_acc_f32(nc, wp, mp);
}
+ // ALiBi bias
+ if (max_bias > 0.0f) {
+ const uint32_t h = (i1/ne01)%ne02; // head
+ const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
+
+ for (int i = 0; i < nc; i++) {
+ wp[i] = wp[i] + slope*pos[i];
+ }
+ }
+
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
@@ -11582,11 +11624,12 @@ static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
+ const struct ggml_tensor * src2,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
- ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
+ ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
} break;
default:
{
@@ -11730,22 +11773,20 @@ static void ggml_compute_forward_alibi_f32(
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
- for (int64_t i = 0; i < ne0; i++) {
- for (int64_t j = 0; j < ne1; j++) {
- for (int64_t k = 0; k < ne2_ne3; k++) {
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
-
- // TODO: k*nb2 or k*nb3
+ for (int64_t k = 0; k < ne2_ne3; k++) {
+ // TODO: k*nb2 or k*nb3
+ float m_k;
- float m_k;
-
- if (k < n_heads_log2_floor) {
- m_k = powf(m0, k + 1);
- } else {
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
+ if (k < n_heads_log2_floor) {
+ m_k = powf(m0, k + 1);
+ } else {
+ m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+ }
+ for (int64_t i = 0; i < ne0; i++) {
+ for (int64_t j = 0; j < ne1; j++) {
+ float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+ float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
pdst[0] = i * m_k + src[0];
}
}
@@ -11790,21 +11831,20 @@ static void ggml_compute_forward_alibi_f16(
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
- for (int i = 0; i < ne0; i++) {
- for (int j = 0; j < ne1; j++) {
- for (int k = 0; k < ne2_ne3; k++) {
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
-
- // TODO: k*nb2 or k*nb3
+ for (int k = 0; k < ne2_ne3; k++) {
+ // TODO: k*nb2 or k*nb3
+ float m_k;
- float m_k;
+ if (k < n_heads_log2_floor) {
+ m_k = powf(m0, k + 1);
+ } else {
+ m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+ }
- if (k < n_heads_log2_floor) {
- m_k = powf(m0, k + 1);
- } else {
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
+ for (int i = 0; i < ne0; i++) {
+ for (int j = 0; j < ne1; j++) {
+ ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+ float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
// we return F32
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
@@ -15116,7 +15156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX:
{
- ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{