summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/conv_transpose_1d.comp
blob: b17b4e83eec4b6202bdf34a3ffd9adc6324f9323 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#version 450

#include "types.comp"

// Storage buffers. The bracketed lists give the logical dimensions of each tensor;
// main() indexes all three as flat arrays using the element strides from the push constants.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};   // src0 - kernel:    [K, Cout, Cin] (read-only)
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};   // src1 - input:     [L, Cin]       (read-only)
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};     // dst  - result:    [KL, Cout]     (write-only)

// Workgroup shape: 128 invocations along x, 1 along y and z.
layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;

// Push constants for one dispatch. All nb* strides are used directly as array
// index multipliers in main(), i.e. they are counted in elements, not bytes.
layout (push_constant) uniform parameter {
    uint32_t Cout;   // output channel count (not read by the shader code in this file)
    uint32_t Cin;    // input channel count; bound of the reduction loop
    uint32_t K;      // kernel length; bound of the kernel-tap loop
    uint32_t L;      // input length; invocations with L_idx >= L contribute nothing
    uint32_t KL;     // output length; stores to data_d are guarded by KL_idx < KL

    uint32_t nb01;   // data_a stride multiplied by the output channel index
    uint32_t nb02;   // data_a stride multiplied by the input channel index
    uint32_t nb11;   // data_b stride multiplied by the input channel index
    uint32_t nb1;    // data_d stride multiplied by the output channel index

    int32_t s0;      // output step per input position: input L_idx, tap K_idx lands at L_idx*s0 + K_idx
} p;


// Output channel handled by this workgroup (workgroup index along x).
uint32_t Cout_idx = gl_WorkGroupID.x;
// Number of invocations per workgroup along x ("block size").
const uint32_t bs = gl_WorkGroupSize.x;
// Index of this invocation inside its workgroup.
uint32_t tid = gl_LocalInvocationID.x;
// Number of elements of tmp that are in use. One block of bs input positions
// touches at most (bs-1)*s0+K output positions; bs*s0+K is used instead because
// it keeps the indexing code simpler.
uint32_t tmp_len = bs*p.s0+p.K;
// Workgroup-shared accumulator for the output window currently being built.
// NOTE(review): capacity is fixed at 4096 elements and nothing visible here checks
// tmp_len against it - confirm the host only dispatches when bs*s0+K <= 4096.
shared D_TYPE tmp[4096];

// Number of passes a workgroup of bs invocations needs so that every one of
// workSize items is visited once, i.e. workSize / bs rounded up.
uint splitWork(uint workSize){
    uint padded = bs + workSize - 1;
    return padded / bs;
}

// Transposed 1D convolution. Each workgroup produces output channel Cout_idx.
//
// The input is processed in blocks of bs consecutive positions, one position per
// invocation. Input position L_idx and kernel tap K_idx contribute to output
// position L_idx*s0 + K_idx, so one block touches a window of tmp_len outputs that
// overlaps the next block's window by K elements. The window is accumulated in
// shared memory (tmp); after each block the first bs*s0 outputs are final and are
// stored, and the remaining tail is carried over into the next block.
//
// Every barrier() below is reached by all invocations of the workgroup: the loop
// bounds around them depend only on push constants and compile-time values.
void main(){
    // Loop-invariant values, computed once instead of in every loop condition.
    uint32_t tmp_chunks = splitWork(tmp_len);   // passes needed for bs invocations to cover tmp
    uint32_t window = bs*p.s0;                  // outputs completed by each input block

    // Clear the used part of the shared accumulator.
    for(uint32_t i = 0; i < tmp_chunks; i++){
        uint32_t idx = i*bs+tid;
        if(idx < tmp_len){
            tmp[idx] = 0.0;
        }
    }

    uint32_t L_blocks = splitWork(p.L);
    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
        if(L_block_id > 0){
            // Wait until every invocation has stored its part of the previous window.
            barrier();
            // Shift values in tmp to the current processing window.
            // An invocation only touches indices of the form i*bs+tid and window is a
            // multiple of bs, so the destination of each move is an element this same
            // invocation already handled in an earlier pass (assuming s0 >= 1); the
            // in-place shift therefore needs no extra synchronization.
            for(uint32_t i = 0; i < tmp_chunks; i++){
                uint32_t idx = i*bs+tid;
                if(idx >= window && idx < tmp_len){
                    // Carry the unfinished tail to the front and clear its old slot.
                    tmp[idx-window] = tmp[idx];
                    tmp[idx] = 0.0;
                }else if(idx >= p.K && idx < window){
                    // Slots that receive no carried value start the new window at zero.
                    tmp[idx] = 0.0;
                }
            }
        }
        // Make the cleared / shifted accumulator visible before anyone adds to it.
        barrier();

        // Save contributions of the block to tmp
        uint32_t L_idx = L_block_id*bs + tid;
        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
            // Dot product over input channels for this (input position, kernel tap).
            D_TYPE dp = 0.0;
            // Invocations past the end of the input skip the loads and add zero.
            if(L_idx < p.L){
                for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
                    A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
                    dp = fma(elemKrn, elemInp, dp);
                }
            }
            // For a fixed K_idx each invocation hits a distinct slot (tid*s0 + K_idx);
            // the barrier separates taps, whose slots can coincide across invocations.
            tmp[tid*p.s0 + K_idx] += dp;
            barrier();
        }

        // Save the computed values except the last block that can have different size
        uint32_t KLb_idx = L_block_id*window;
        if(L_block_id < L_blocks-1){
            // Each invocation stores the s0 consecutive outputs starting at tid*s0.
            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
                uint32_t sh_idx = p.s0*tid+s0_idx;
                uint32_t KL_idx = KLb_idx+sh_idx;
                if(KL_idx < p.KL){
                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
                }
            }
        }
    }

    // Store the whole window of the last block, including the tail that no later
    // block will complete; positions at or beyond KL are skipped.
    for(uint32_t i = 0; i < tmp_chunks; i++){
        uint32_t idx = i*bs+tid;
        uint32_t KL_idx = (L_blocks-1)*window+idx;
        if(KL_idx < p.KL){
            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
        }
    }
}