summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/glu_main.comp
blob: 85cf65a9ecac8d8f1076ec36200a9090581e060b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// Kernel entry point: each invocation produces (at most) one element of data_d
// by combining two input values with `op` (declared outside this file).
//
// Modes (selected by p.mode):
//   0      - fused input: both operands live in the same data_a row, first half
//            and second half; op(first, second).
//   1      - fused input, operands swapped: op(second, first).
//   other  - split input: first operand from data_a, second from data_b, both
//            laid out exactly like the output.
void main() {
    // Linearize the 3-D invocation ID. 262144 == 512 * 512, so this assumes the
    // host dispatch treats x and y as 512-wide slabs.
    // NOTE(review): confirm these factors match the host-side dispatch layout.
    const uint elem = gl_GlobalInvocationID.x
                    + gl_GlobalInvocationID.y * 512
                    + gl_GlobalInvocationID.z * 262144;

    // Invocations past the element count do nothing.
    if (elem >= p.N) {
        return;
    }

    // Split the flat index into (row, column) of the output, whose rows are
    // p.ne20 elements long.
    const uint out_row = elem / p.ne20;
    const uint out_col = elem - out_row * p.ne20;

    // Position of this column inside the corresponding source row (rows of
    // p.ne00 elements). Shared by every mode.
    // NOTE(review): uint math wraps at 2^32; presumably the host keeps tensors
    // small enough that out_row * p.ne00 cannot overflow.
    const uint src = out_row * p.ne00 + out_col;

    if (p.mode != 0 && p.mode != 1) {
        // Split: data_a, data_b and data_d share one layout, so `src` indexes all three.
        data_d[src] = D_TYPE(op(float(data_a[src]), float(data_b[src])));
        return;
    }

    // Fused (mode 0 or 1): the second operand sits half a source row further on,
    // and output rows are half as long as source rows.
    const uint half_len = p.ne00 / 2;
    const uint dst = out_row * half_len + out_col;

    const float first_half  = float(data_a[src]);
    const float second_half = float(data_a[src + half_len]);

    if (p.mode == 0) {
        data_d[dst] = D_TYPE(op(first_half, second_half));
    } else {
        data_d[dst] = D_TYPE(op(second_half, first_half));
    }
}