summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
blob: f9cde064887a8c1a0680431f2e93b8d363ab950a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

// Workgroup width is taken from specialization constant 0; y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Invocation-private accumulators (no `shared` qualifier): temp[j][n] holds this
// invocation's running partial dot product of B column j with A row (first_row + n).
// Zeroed in compute_outputs() and then handed to reduce_result()
// (defined in mul_mat_vec_base.comp — presumably a workgroup-wide reduction; confirm there).
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

// Accumulates one superblock's contribution into temp[][] for this invocation.
//
// For each of the num_rows rows starting at first_row, reads superblock i of that
// row of A and, for each of the NUM_COLS columns of B, adds
//     dall * (sx*sc0 + sy*sc1 + sz*sc4 + sw*sc5) - dmin * smin
// to temp[j][n]. Each invocation touches 16 values of the superblock: four
// consecutive values starting at each of y_offset, y_offset+32, y_offset+128
// and y_offset+160 (relative to the superblock start in B).
//
// Parameters:
//   a_offset, b_offset   offsets into A and B as produced by get_offsets()
//   v_im                 0 or 1; selects which scale words / which 32-byte half of qs is read
//   q_offset             byte offset into qs (read as qs[q_offset / 4] through the 32-bit view)
//   y_offset             element offset of this invocation's first B value within the superblock
//   i                    superblock index within the row
//   num_blocks_per_row   number of superblocks per row of A
//   first_row, num_rows  range of A rows to process
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
    const uint y1_idx = i * QUANT_K + y_offset;  // first B element of the low half
    const uint y2_idx = y1_idx + 128;            // matching element 128 positions later

    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        // Index of the first superblock of row (first_row + n).
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
        // d.x multiplies the quantized products, d.y multiplies the min correction (see temp update below).
        vec2 d = vec2(data_a[ib0 + i].d);
        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

        // Three 16-bit words of packed scale bytes, at word indices v_im, v_im+2, v_im+4.
        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
        const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];

        // Bytes of scale_0_4_l: [word0.lo, word0.hi, word2.lo, word2.hi].
        // Their low 6 bits give sc0..sc3 directly; their top 2 bits are moved down
        // to bits 4..5 so they can be OR'd onto the 4-bit values taken from word 4.
        const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
        const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
        const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
        // ((w << 12) | w) & 0x0F0F0F0F spreads the nibbles of word 4 into bytes
        // [lo nibble of w.lo, lo nibble of w.hi, hi nibble of w.lo, hi nibble of w.hi];
        // OR-ing scale_0_4_h supplies bits 4..5 of each, giving sc4..sc7.
        const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));

        // sc0, sc1, sc4, sc5 multiply the quantized sums sx, sy, sz, sw;
        // sc2, sc3, sc6, sc7 multiply the raw B values in smin.
        const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
        const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
        const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
        const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
        const FLOAT_TYPE sc4 = scale8_f.x;
        const FLOAT_TYPE sc5 = scale8_f.y;
        const FLOAT_TYPE sc6 = scale8_f.z;
        const FLOAT_TYPE sc7 = scale8_f.w;

        // Quant bytes q_offset..q_offset+3 and the same four bytes 64 later (16 words on).
        const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
        const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];

        // Split each byte into its low and high 4-bit values.
        const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
        const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
        const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
        const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;

        const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
        const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
        const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
        const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));

        // q4_0..3  (qs0 low)   pair with B at y1_idx      (by10)
        // q4_4..7  (qs0 high)  pair with B at y1_idx + 32 (by132)
        // q4_8..11 (qs64 low)  pair with B at y2_idx      (by20)
        // q4_12..15(qs64 high) pair with B at y2_idx + 32 (by232)
        const FLOAT_TYPE q4_0  = qs0_lo4.x;
        const FLOAT_TYPE q4_1  = qs0_lo4.y;
        const FLOAT_TYPE q4_2  = qs0_lo4.z;
        const FLOAT_TYPE q4_3  = qs0_lo4.w;
        const FLOAT_TYPE q4_4  = qs0_hi4.x;
        const FLOAT_TYPE q4_5  = qs0_hi4.y;
        const FLOAT_TYPE q4_6  = qs0_hi4.z;
        const FLOAT_TYPE q4_7  = qs0_hi4.w;
        const FLOAT_TYPE q4_8  = qs64_lo4.x;
        const FLOAT_TYPE q4_9  = qs64_lo4.y;
        const FLOAT_TYPE q4_10 = qs64_lo4.z;
        const FLOAT_TYPE q4_11 = qs64_lo4.w;
        const FLOAT_TYPE q4_12 = qs64_hi4.x;
        const FLOAT_TYPE q4_13 = qs64_hi4.y;
        const FLOAT_TYPE q4_14 = qs64_hi4.z;
        const FLOAT_TYPE q4_15 = qs64_hi4.w;

        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            // data_b_v4 is indexed in vec4 units: "/ 4" converts an element index, "+ 8" skips 32 elements.
            // NOTE(review): assumes j*p.batch_stride_b + b_offset is a multiple of 4 so the
            // truncating division does not misalign — confirm on the host side.
            vec4 by10 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4    ]);
            vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
            vec4 by20 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4    ]);
            vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);

            // Quant-times-B partial sums per 4-value group. The nesting order of the
            // fma calls fixes the rounding; reordering would change results.
            const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
            const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
            const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
            const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
            // Min correction: every B value weighted by the min scale of its group.
            const FLOAT_TYPE smin =
                fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
                fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
                fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
                fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
            // temp += dall * (scaled quant sums) - dmin * smin, folded into nested fma.
            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
        }
    }
}

// Computes this invocation's partial results for `num_rows` consecutive rows of A
// (starting at `first_row`) against NUM_COLS columns of B, then passes them to
// reduce_result() together with the output offset from get_offsets().
//
// Thread mapping within the workgroup:
//   - Invocations form groups of 16. Group ix walks the superblocks of each row
//     starting at ix with stride it_size = gl_WorkGroupSize.x / 16.
//   - Inside a group, itid (0..15) picks v_im (which half of the superblock),
//     v_in and ir, which together choose the 4-value slice this invocation reads.
// NOTE(review): this mapping assumes gl_WorkGroupSize.x is a non-zero multiple
// of 16. With fewer than 16 invocations it_size is 0 and the block loop never
// advances; with a non-multiple, the trailing partial group starts at a block
// index that another group also visits, so its work would be counted twice.
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    const uint num_blocks_per_row = p.ncols / QUANT_K;

    // 16 threads are used to process each block
    const uint it_size = gl_WorkGroupSize.x/16;
    const uint tid = gl_LocalInvocationID.x;
    const uint itid = tid%16;  // 0...15
    const uint ix = tid/16;

    const uint il = itid/4;                         // 0...3
    const uint ir = itid - 4*il;                    // 0...3
    const uint n =  4;                              // consecutive values read per sub-range (one vec4)

    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const uint v_in = il % 2;

    const uint l0 = n * (2 * ir + v_in);            // 0, 4, ..., 28: start of this invocation's 4 values within a 32-value range
    const uint q_offset = 32*v_im + l0;             // byte offset into qs (calc_superblock reads qs[q_offset / 4])
    const uint y_offset = 64*v_im + l0;             // element offset into the superblock's span of B

    // Clear the invocation-private accumulators before summing over superblocks.
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
            temp[j][i] = FLOAT_TYPE(0);
        }
    }

    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
        calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
    // Workgroups own chunks of NUM_ROWS rows; the z dimension extends the x range
    // when the row count exceeds what one dispatch dimension can cover.
    const uint row_start = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // Workgroups entirely past the last row have nothing to do.
    if (row_start >= p.stride_d) {
        return;
    }

    const uint rows_left = p.stride_d - row_start;
    if (rows_left >= NUM_ROWS) {
        // Full chunk: keep NUM_ROWS as the argument so the row loops see the constant.
        compute_outputs(row_start, NUM_ROWS);
    } else {
        // Tail chunk: only the rows that remain.
        compute_outputs(row_start, rows_left);
    }
}