summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/wkv7.comp
blob: 88c1c02b32b8c3fa7a027821a1c9b7857164772e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#version 450

// Provides the [[unroll]] loop attribute used on the fixed-length loops in main().
#extension GL_EXT_control_flow_attributes : require

// Workgroup width. main() also uses this value as head_size, so it fixes the
// length of the per-invocation state array and of every shared staging array.
#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Dispatch parameters. The notes describe how main() uses each field.
layout(push_constant) uniform Parameters {
    // Exclusive upper bound on the batch index (gl_WorkGroupID.x / H);
    // presumably the number of batches/sequences - confirm against the host code.
    uint B;
    // T / B is the number of time steps processed per batch, and T * C is the
    // element offset in dst where the final state is stored.
    uint T;
    // Element stride between consecutive time steps of the per-token buffers;
    // C * head_size is the size of one batch's state.
    uint C;
    // Heads per batch: the workgroup index is split into (x / H, x % H).
    uint H;
};

// Per-token inputs. main() reads all six at the same element offset.
// NOTE(review): A_TYPE is not defined in this file; presumably it is supplied
// when the shader is compiled. main() mixes it with vec4, so a float type is
// expected - confirm.
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
// Initial state; each invocation loads head_size consecutive values from it.
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
// Output: per-token results at the same offsets as the inputs, followed by the
// final state starting at element T * C.
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };

// Workgroup-shared copies of r, w, k, a and b for the current time step.
// Each invocation fills slot tid and, after a barrier, reads every slot.
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];

/*
 * One workgroup handles one (batch, head) pair. Every invocation owns one row
 * of head_dim state values, held in a private array for the whole sequence.
 *
 * For each time step, with v_own the invocation's own v value and r/w/k/a/b
 * the workgroup-wide vectors staged in shared memory:
 *     sa     = dot(row, a)
 *     row[j] = row[j] * w[j] + k[j] * v_own + sa * b[j]
 *     y      = dot(r, row)            (uses the updated row)
 * y is stored in dst at the same element offset the inputs were read from.
 * After the last step the row is stored in dst starting at element T * C,
 * using the same layout it was loaded with from state_in.
 */
void main() {
    const uint head_dim = BLOCK_SIZE;
    const uint local_idx = gl_LocalInvocationID.x;
    const uint group_idx = gl_WorkGroupID.x;
    const uint batch_idx = group_idx / H;
    const uint head_idx = group_idx % H;

    // The test depends only on the workgroup index, so a workgroup leaves as a
    // whole and no invocation is left waiting at a barrier below.
    const bool in_range = batch_idx < B && head_idx < H;
    if (!in_range) {
        return;
    }

    const uint tokens_per_batch = T / B;
    const uint batch_state_len = C * head_dim;

    // First element of this invocation's state row. The same offset is used
    // for the load from state_in and for the final store into dst.
    const uint row_base = batch_idx * batch_state_len
                        + head_idx * head_dim * head_dim
                        + local_idx * head_dim;

    A_TYPE row[BLOCK_SIZE];
    [[unroll]] for (uint c = 0; c < head_dim; c++) {
        row[c] = state_in[row_base + c];
    }

    // Element range walked by this invocation: one element per time step,
    // C elements apart, covering this batch's tokens.
    const uint batch_span = tokens_per_batch * C;
    const uint first_elem = batch_idx * batch_span + head_idx * head_dim + local_idx;
    const uint end_elem = first_elem + batch_span;

    for (uint elem = first_elem; elem < end_elem; elem += C) {
        // Do not overwrite the shared vectors until every invocation has
        // finished reading the previous step's values.
        barrier();
        _r[local_idx] = r[elem];
        _w[local_idx] = w[elem];
        _k[local_idx] = k[elem];
        _a[local_idx] = a[elem];
        _b[local_idx] = b[elem];
        // All slots must be written before anyone reads the full vectors.
        barrier();

        // sa = dot(row, a), accumulated four components at a time.
        A_TYPE sa = 0.0;
        [[unroll]] for (uint c = 0; c < head_dim; c += 4) {
            vec4 row4 = vec4(row[c], row[c+1], row[c+2], row[c+3]);
            vec4 a4 = vec4(_a[c], _a[c+1], _a[c+2], _a[c+3]);
            sa += dot(row4, a4);
        }

        const A_TYPE v_own = v[elem];
        A_TYPE y = 0.0;

        // Update the row in place and accumulate y against the updated values.
        [[unroll]] for (uint c = 0; c < head_dim; c += 4) {
            vec4 r4 = vec4(_r[c], _r[c+1], _r[c+2], _r[c+3]);
            vec4 w4 = vec4(_w[c], _w[c+1], _w[c+2], _w[c+3]);
            vec4 k4 = vec4(_k[c], _k[c+1], _k[c+2], _k[c+3]);
            vec4 b4 = vec4(_b[c], _b[c+1], _b[c+2], _b[c+3]);
            vec4 row4 = vec4(row[c], row[c+1], row[c+2], row[c+3]);

            row4 = row4 * w4 + k4 * v_own + sa * b4;
            y += dot(r4, row4);

            row[c] = row4.x;
            row[c+1] = row4.y;
            row[c+2] = row4.z;
            row[c+3] = row4.w;
        }

        dst[elem] = y;
    }

    // The final state is stored after the T * C per-token outputs.
    const uint state_out_base = T * C + row_base;
    [[unroll]] for (uint c = 0; c < head_dim; c++) {
        dst[state_out_base + c] = row[c];
    }
}