path: root/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
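
// Matrix-vector multiply over Q2_K-quantized data. A Q2_K super-block packs
// QUANT_K = 256 weights as 2-bit quants plus sixteen 4-bit scale/min pairs
// (one per 16-weight sub-block) and two fp16 super-block factors d and dmin;
// dequantization is x = d*scale*q - dmin*min.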

#include "mul_mat_vec_base.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

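// Per-16-thread-group caches of the sub-block scales (sccache1, low nibble of
// scales[itid]) and mins (sccache2, high nibble), double-buffered via csel so
// a single barrier per block iteration suffices.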
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;

void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
    const uint y_idx = i * QUANT_K + y_offset;

    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
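        // flip the double-buffer index; other threads may still be reading the
        // previous block's scales from the other buffer, so no barrier is
        // needed before the writes below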
        csel ^= 1;

        if (!all_threads) { // when we don't have enough blocks to use all threads
            if (i < num_blocks_per_row) {
                const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
                sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
                sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
            }
            barrier();

            if (i >= num_blocks_per_row)
                continue;
        } else {
            const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
            sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
            sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
            barrier();
        }

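        // combine two 16-bit words of 2-bit quants (16 bytes apart in qs) into
        // one u32, then mask out the four 2-bit planes (shifts 0/2/4/6) as vec4s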
        const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

        vec2 d = vec2(data_a[ib0 + i].d);
        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
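            // load this thread's 16 y values for column j: 8 vec2 loads at an
            // element stride of 16, within the 128-element half chosen by v_im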
            vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
            vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
            vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
            vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
            vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
            vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
            vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
            vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
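            // Q2_K dequantizes as x = dall*sc*q - dmin*m: sum1 gathers the
            // sc*q*y terms and sum2 the m*y terms of the dot product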
            [[unroll]] for (int l = 0; l < 2; ++l) {
                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[csel][ix][    8*v_im] * qs_u32_0[l  ],
                       fma(FLOAT_TYPE(b16[l]),  sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
                       fma(FLOAT_TYPE(b32[l]),  sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l  ],
                       fma(FLOAT_TYPE(b48[l]),  sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
                       fma(FLOAT_TYPE(b64[l]),  sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l  ],
                       fma(FLOAT_TYPE(b80[l]),  sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
                       fma(FLOAT_TYPE(b96[l]),  sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l  ],
                       fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[csel][ix][    8*v_im],
                       fma(FLOAT_TYPE(b16[l]),  sccache2[csel][ix][1 + 8*v_im],
                       fma(FLOAT_TYPE(b32[l]),  sccache2[csel][ix][2 + 8*v_im],
                       fma(FLOAT_TYPE(b48[l]),  sccache2[csel][ix][3 + 8*v_im],
                       fma(FLOAT_TYPE(b64[l]),  sccache2[csel][ix][4 + 8*v_im],
                       fma(FLOAT_TYPE(b80[l]),  sccache2[csel][ix][5 + 8*v_im],
                       fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ix][6 + 8*v_im],
                       fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
            }
            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
        }
    }
}

void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    const uint num_blocks_per_row = p.ncols / QUANT_K;

    // 16 threads are used to process each block
    const uint it_size = gl_WorkGroupSize.x/16;
    const uint tid = gl_LocalInvocationID.x;
    const uint itid = tid%16;  // 0...15
    const uint ix = tid/16;

    const uint v_im = itid/8;                                // 0 or 1. 0 computes elements 0..127 of the super-block, 1 computes 128..255
    const uint v_in = itid - 8*v_im;                         // 0...7

    const uint l0 = 2*v_in;                                  // 0, 2, ..., 14
    const uint q_offset = 32*v_im + l0;
    const uint y_offset = 128*v_im + l0;

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
            temp[j][i] = FLOAT_TYPE(0);
        }
    }

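    // run whole passes in which every 16-thread group has a block, then one
    // final ragged pass for the remaining num_blocks_per_row % it_size blocks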
    const uint nbr_par_th = num_blocks_per_row%it_size;
    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
    uint i0 = 0;
    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
        calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
    calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
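    // rows are tiled NUM_ROWS at a time across the x and z workgroup dimensions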
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // do NUM_ROWS at a time, unless there aren't enough remaining rows
    if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
    } else {
        if (first_row >= p.stride_d) {
            return;
        }
        compute_outputs(first_row, p.stride_d - first_row);
    }
}