summaryrefslogtreecommitdiff
path: root/kompute-shaders/op_mul_mat_q6_k.comp
blob: c9baebdf4baac6a0006d869449a85bf522ee6590 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#version 450

#include "common.comp"

// Size of one q6_K block (sizeof_block_q6_k is provided by common.comp).
// main() multiplies block indices by it to index the uint8_t inA buffer,
// so it is a byte count.
#define SIZE_OF_BLOCK sizeof_block_q6_k

// Workgroup X and Y sizes come from specialization constants 0 and 1;
// Z is fixed at 1.
layout(local_size_x_id = 0) in;
layout(local_size_y_id = 1) in;
layout(local_size_z = 1) in;

// inA:  quantized matrix A, addressed byte-wise as consecutive q6_K blocks.
// inB:  f32 operand B.
// out_: f32 result.
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

// Push constants. Field roles below are taken from how main() uses them;
// the names look like ggml's ne<tensor><dim> convention
// (NOTE(review): naming-based inference -- confirm against the host code).
layout (push_constant) uniform parameter {
    uint inAOff;   // added to a byte index into inA, i.e. a byte offset
    uint inBOff;   // added to a float index into inB, i.e. an element offset
    uint outOff;   // added to a float index into out_, i.e. an element offset
    int ne00;      // row length of A in elements; main() uses ne00/QK_K whole blocks
    int ne10;      // stride between consecutive rows of B, in floats
    int ne0;       // output row length; also A's rows per batch (per-batch stride nb*ne0 blocks)
    int ne1;       // rows of B / of the output per batch (batch strides ne00*ne1 and ne0*ne1)
    int ne01;      // presumably A's row count (src0 ne[1]) -- TODO confirm with host
    int gqa;       // batch broadcast ratio: batch r2 reads A's batch r2/gqa
} pcs;

// Q6_K x f32 matrix-vector product.
//
// Each subgroup produces one output element: the dot product of one row of
// the quantized matrix A (inA, packed q6_K blocks) with one row of B (inB,
// f32). Per-invocation partial sums are reduced with subgroupAdd and a single
// elected invocation stores the result.
//
// Work mapping, as established by the indexing below:
//   row = gl_WorkGroupID.x * gl_NumSubgroups + gl_SubgroupID
//         -> row of A, i.e. position within the output row
//   r1  = gl_WorkGroupID.y -> row of B / output row
//   r2  = gl_WorkGroupID.z -> batch; A's batch is r2/gqa
//
// q6_K block layout as addressed here (byte offsets from the block start):
//   [0, QK_K/2)                        ql: low 4 bits of each quant, two nibbles per byte
//   [QK_K/2, QK_K/2 + QK_K/4)          qh: high 2 bits of each quant, four fields per byte
//   [QK_K/2 + QK_K/4, ... + QK_K/16)   scales: QK_K/16 signed 8-bit sub-block scales
//   QK_K/2 + QK_K/4 + QK_K/16          d: fp16 block scale
// Each 6-bit quant q contributes d * scale * (q - 32).
void main() {
    const uint8_t kmask1 = uint8_t(0x03);   // qh bits 0-1
    const uint8_t kmask2 = uint8_t(0x0C);   // qh bits 2-3
    const uint8_t kmask3 = uint8_t(0x30);   // qh bits 4-5
    const uint8_t kmask4 = uint8_t(0xC0);   // qh bits 6-7

    // Whole q6_K blocks per row of A; any ne00 % QK_K tail is not processed.
    const uint nb = pcs.ne00/QK_K;

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint r2 = gl_WorkGroupID.z;

    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);

    // Bounds guard. ne0 is both A's rows per batch (see offset0) and the
    // output row length (see the store at the end). If the dispatch launches
    // more subgroups than that (NOTE(review): e.g. an odd row count with two
    // subgroups per workgroup -- confirm against the host dispatch), the extra
    // subgroup would read past A's rows and store to out_[... + ne0]. That
    // slot is the first element of the next output row, so the store would
    // race with the subgroup that legitimately owns it.
    // `row` depends only on workgroup and subgroup IDs, so the whole subgroup
    // returns together and the subgroupAdd below is unaffected.
    if (row >= uint(pcs.ne0)) {
        return;
    }

    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);   // first block of A's batch r2/gqa
    const uint x = row * nb + offset0; // Based from inA without base offset
    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB

    float sumf = 0;

    // bits of invocation ID for gl_SubgroupSize=32:
    //  x   x   x   x   x
    //  4   3   2   1   0
    // (     tid     ) ix
    //  ip (   il    )

    // NOTE(review): assumes gl_SubgroupSize is a multiple of 16 (16, 32, 64, ...).
    // Below 16, block_stride is 0 and the block loop never advances; other
    // non-multiples push tid past 15 and ip past 1.
    const uint block_stride = gl_SubgroupSize / 16;         // number of blocks each subgroup processes
    const uint tid  = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
    const uint ix   = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
    const uint ip   = tid/8;        // first or second half of block (0 or 1)
    const uint il   = tid%8;        // each half has 8 parts, one per scale
    const uint n    = 4;            // 4 scales at a time (and 4 sums)
    const uint l0   = n*il;         // offset into half-block, 0..28
    const uint is   = 8*ip + l0/16; // 0, 1, 8, 9

    const uint y_offset = 128*ip + l0;
    const uint q_offset_l = 64*ip + l0;
    const uint q_offset_h = 32*ip + l0;

    for (uint i = ix; i < nb; i += block_stride) {
        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;   // byte offset of block i
        const uint qhBase    = baseIndex + QK_K/2;                     // start of qh
        const uint scBase    = baseIndex + QK_K/2 + QK_K/4;            // start of scales

        const uint qlIndex = q_offset_l;
        const uint q2Index = qlIndex + QK_K/8;
        const uint qhIndex = q_offset_h;
        const uint y = yy + i * QK_K + y_offset;

        // Each step rebuilds four 6-bit quants (low nibble from ql, two high
        // bits from qh), removes the +32 bias, and accumulates against B.
        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (uint l = 0; l < n; ++l) {
            const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
            const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
            const uint8_t currentQh = inA[qhBase + qhIndex + l];

            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
            sums[2] += inB[y+l+64] * (int8_t((currentQ1  >> 4) | ((currentQh & kmask3) << 0)) - 32);
            sums[3] += inB[y+l+96] * (int8_t((currentQ2  >> 4) | ((currentQh & kmask4) >> 2)) - 32);
        }

        // Apply the per-sub-block int8 scales, then the fp16 block scale d.
        const float d = u8BufToFloat16(inA, scBase + QK_K/16);
        sumf += d * (sums[0] * int8_t(inA[scBase + is])
                   + sums[1] * int8_t(inA[scBase + 2 + is])
                   + sums[2] * int8_t(inA[scBase + 4 + is])
                   + sums[3] * int8_t(inA[scBase + 6 + is]));
    }

    // Reduce the per-invocation partials; one elected invocation stores the sum.
    const float tot = subgroupAdd(sumf);
    if (subgroupElect()) {
        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
    }
}