blob: 11e07e66bc9374a21b35d956fe4e9e56692bd95a (
plain)
| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
 | #version 450
#include "dequant_head.comp"
// 1-D workgroup of 256 invocations; main() treats it as four 64-invocation
// slices (gl_LocalInvocationID.x / 64).
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Input: array of packed Q4_0 blocks, read-only.
// NOTE(review): block_q4_0 is presumably declared via dequant_head.comp — not visible here.
layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
// Output: dequantized values, 32 D_TYPE elements per input block, write-only.
// NOTE(review): D_TYPE is assumed to be defined by the header/build — confirm.
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
// Dequantize Q4_0 blocks into data_b.
//
// Each 64-invocation slice of the workgroup covers 32 consecutive blocks.
// Invocation `lane` of a slice owns block `32*group + lane % 32` and writes
// 16 of that block's 32 outputs: quant bytes [8*h, 8*h+8) where h = lane / 32.
// The low nibble of byte j lands at output j, the high nibble at output j+16.
// Each value is reconstructed as scale * (nibble - 8), written as
// scale * nibble + (-8 * scale).
void main() {
    // Four 64-lane slices per workgroup, each mapped to its own 32-block group.
    const uint group = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
    const uint lane  = gl_LocalInvocationID.x % 64;

    const uint blk   = 32 * group + lane % 32;
    // 0 -> first 8 quant bytes, 1 -> last 8 quant bytes of the block.
    const uint which = lane / 32;

    // Only whole 32-element blocks are processed.
    // NOTE(review): p.nel presumably holds the total element count — confirm.
    if (blk >= p.nel / 32) {
        return;
    }

    const float scale = float(data_a[blk].d);
    const float bias  = -8.0f * scale;

    const uint first_byte = 8 * which;
    // 32 outputs per block; this lane starts at its half's offset.
    const uint out_base   = 32 * blk + first_byte;

    [[unroll]] for (uint k = 0; k < 8; ++k) {
        const uint qbyte = uint(data_a[blk].qs[first_byte + k]);
        data_b[out_base + k]      = D_TYPE(scale * float(qbyte & 0xFu) + bias);
        data_b[out_base + k + 16] = D_TYPE(scale * float(qbyte >> 4) + bias);
    }
}
 |