summary | refs | log | tree | commit | diff
path: root/ggml/src/vulkan-shaders/quantize_q8_1.comp
blob: e2e020fec2c6a9e247a50461b27e335327cb0a7a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require

// Push-constant block for this shader; accessed below as `p`.
layout (push_constant) uniform parameter
{
    // Upper bound used in quantize() to guard loads from data_a.
    // NOTE(review): presumably the number of input elements; the unit
    // (scalar floats vs. vec4s) is not visible here - confirm against the host code.
    uint ne;
} p;

#include "types.comp"

// Workgroup size along x: specialization constant 0, defaulting to 32.
layout(constant_id = 0) const uint GROUP_SIZE = 32;
// local_size_x is bound to the same specialization constant id (0) as GROUP_SIZE.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input buffer, read four floats (one vec4) at a time.
layout (binding = 0) readonly buffer A {vec4 data_a[];};
// Output buffer of quantized blocks. block_q8_1_packed32 is not declared in
// this file (presumably it comes from types.comp).
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};

// One float of scratch per invocation; quantize() uses it for its per-block
// max and sum reductions.
shared float shmem[GROUP_SIZE];

// Quantizes the float input in data_a into blocks written to data_b.
//
// Work layout, as fixed by the indexing below:
//   - every invocation loads one vec4 (4 floats) and emits one packed 32-bit word;
//   - 8 consecutive invocations cooperate on one output block (8 words, 32 values);
//   - a workgroup of GROUP_SIZE invocations covers GROUP_SIZE / 8 blocks.
//
// For each block the shader stores
//   qs[0..7] : the 32 values scaled by 1/d, rounded and packed as 4 x int8 per word
//   ds       : (d, d * sum of the rounded values), where d = max|x| / 127
//
// Precondition: GROUP_SIZE is a multiple of 8. The 8-wide reductions read
// shmem[tid + s], which only stays inside an invocation's own group of 8
// (and inside shmem) under that assumption.
void quantize() {
    const uint wgid = gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // Each thread handles a vec4, so 8 threads handle a block
    const uint blocks_per_group = GROUP_SIZE / 8;

    // Which of this workgroup's blocks the invocation belongs to
    const uint block_in_wg = tid / 8;

    // ib: global output block index, iqs: word index (0..7) inside that block
    const uint ib = wgid * blocks_per_group + block_in_wg;
    const uint iqs = tid % 8;

    // When GROUP_SIZE is a multiple of 8, ib is at most
    // gl_NumWorkGroups.x * blocks_per_group - 1, so this never fires and all
    // invocations reach the barriers below.
    if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
        return;
    }

    const uint a_idx = ib * 8 + iqs;

    // Loads past the bound are replaced by zeros so the tail block is padded.
    // NOTE(review): a_idx counts vec4 elements; if p.ne is a count of scalar
    // floats this bound would need to be p.ne / 4 - confirm against the host code.
    vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
    const vec4 abs_vals = abs(vals);

    // Find absolute max for each block: per-invocation max first, then a
    // tree reduction over the 8 invocations of the block (strides 4, 2, 1).
    shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
    barrier();
    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
        if (iqs < s) {
            shmem[tid] = max(shmem[tid], shmem[tid + s]);
        }
        barrier();
    }

    // The reduced maximum sits in the slot of the block's first invocation.
    const float amax = shmem[block_in_wg * 8];
    const float d = amax / 127.0;
    // An all-zero block has d == 0; use a zero inverse so the quants become 0
    // instead of dividing by zero.
    const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
    // vals now holds the integer-valued quants, so no further rounding is
    // needed before the int8 conversion.
    vals = round(vals * d_inv);
    data_b[ib].qs[iqs] = pack32(i8vec4(vals));
    // Every invocation must have read amax before shmem is overwritten below.
    barrier();

    // Calculate the sum for each block (same reduction shape as the max above)
    shmem[tid] = vals.x + vals.y + vals.z + vals.w;
    barrier();
    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
        if (iqs < s) {
            shmem[tid] += shmem[tid + s];
        }
        barrier();
    }
    // The block's first invocation holds the total and writes the scale pair.
    if (iqs == 0) {
        const float sum = shmem[tid];

        data_b[ib].ds = f16vec2(vec2(d, sum * d));
    }
}

// Shader entry point; all work is done in quantize().
void main() {
    quantize();
}