summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/roll.comp
blob: b9abe8dedcf865636b3f24e821c91dc09d42aa09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

uint wrap_idx(int i, uint ne) {
    if (i < 0) {
        return i + ne;
    } else if (i >= ne) {
        return i - ne;
    }
    return i;
}

void main() {
    const uint idx = get_idx();
    if (idx >= p.ne) {
        return;
    }

    const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
    const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
    const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
    const uint i2_offset = i2*p.ne11*p.ne10;
    const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;

    const uint p1 = floatBitsToUint(p.param1);
    const uint p2 = floatBitsToUint(p.param2);
    const int s0 = int(p1 >> 16)    - 0x8000;
    const int s1 = int(p1 & 0xFFFF) - 0x8000;
    const int s2 = int(p2 >> 16)    - 0x8000;
    const int s3 = int(p2 & 0xFFFF) - 0x8000;

    const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
    const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
    const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
    const uint i03 = wrap_idx(int(i3) - s3, p.ne13);

    const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
    const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;

    data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
}