summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/conv2d_dw.comp
blob: 938c74da5007476446fef614377a43504a502b40 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#version 450

#include "types.comp"

// Push constants describing one depthwise 2D convolution dispatch.
// Field order and types form the interface with the host code; do not reorder.
layout (push_constant) uniform parameter
{
    uint ne;        // upper bound on the flat output index; invocations with idx >= ne do nothing
    uint batches;   // presumably the batch count; not referenced by the shader code in this file
    uint channels;  // number of channels (each channel is convolved with its own kernel)
    uint dst_w;     // output width
    uint dst_h;     // output height
    uint src_w;     // input width
    uint src_h;     // input height
    uint knl_w;     // kernel width
    uint knl_h;     // kernel height
    int stride_x;   // horizontal step in the input per output column
    int stride_y;   // vertical step in the input per output row
    int pad_x;      // horizontal padding, subtracted when mapping output x to input x
    int pad_y;      // vertical padding, subtracted when mapping output y to input y
    int dilation_x; // horizontal spacing between kernel taps in the input
    int dilation_y; // vertical spacing between kernel taps in the input
} p;

// Binding 0: kernel weights (read only), indexed per channel by the conv functions.
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
// Binding 1: input image data (read only).
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
// Binding 2: output data (write only); main() writes one element per invocation.
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};

// 512 invocations per workgroup along x; main() relies on this value
// when flattening the global invocation id.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Computes one output element of a depthwise 2D convolution where the flat
// index varies fastest along x, then y, then channel, then batch.
//
// idx: flat index of the output element.
// Returns the sum of src * kernel over all kernel taps that land inside the
// source image; taps that fall outside contribute nothing.
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
    // Peel the flat index apart, innermost dimension first.
    const uint out_x = idx % p.dst_w;
    const uint rows  = idx / p.dst_w;
    const uint out_y = rows % p.dst_h;
    const uint plane = rows / p.dst_h;
    const uint ch    = plane % p.channels;
    const uint batch = plane / p.channels;

    // Start of this (batch, channel) image plane and of this channel's kernel.
    const uint src_base = batch * p.channels * p.src_h * p.src_w + ch * p.src_h * p.src_w;
    const uint knl_base = ch * p.knl_h * p.knl_w;

    FLOAT_TYPE acc = 0.0;
    for (uint ky = 0; ky < p.knl_h; ++ky) {
        // Unsigned arithmetic: a position above the image wraps to a huge
        // value, so a single "< src_h" test rejects both out-of-range sides.
        const uint in_y = out_y * p.stride_y + ky * p.dilation_y - p.pad_y;
        if (in_y < p.src_h) {
            const uint src_line = src_base + in_y * p.src_w;
            const uint knl_line = knl_base + ky * p.knl_w;
            for (uint kx = 0; kx < p.knl_w; ++kx) {
                // Same wrap-around trick for the horizontal coordinate.
                const uint in_x = out_x * p.stride_x + kx * p.dilation_x - p.pad_x;
                if (in_x < p.src_w) {
                    const FLOAT_TYPE pixel  = FLOAT_TYPE(src_data[src_line + in_x]);
                    const FLOAT_TYPE weight = FLOAT_TYPE(knl_data[knl_line + kx]);
                    acc = fma(pixel, weight, acc);
                }
            }
        }
    }
    return acc;
}

// Computes one output element of a depthwise 2D convolution where the flat
// index varies fastest along channel, then x, then y, then batch
// (channels are interleaved per pixel in both source and kernel).
//
// idx: flat index of the output element.
// Returns the sum of src * kernel over all kernel taps that land inside the
// source image; taps that fall outside contribute nothing.
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
    // Peel the flat index apart, innermost dimension first.
    const uint ch     = idx % p.channels;
    const uint pixels = idx / p.channels;
    const uint out_x  = pixels % p.dst_w;
    const uint lines  = pixels / p.dst_w;
    const uint out_y  = lines % p.dst_h;
    const uint batch  = lines / p.dst_h;

    // Start of this batch's image, and the element pitch of one row in the
    // source and in the kernel (each pixel holds p.channels values).
    const uint src_base  = batch * p.channels * p.src_h * p.src_w;
    const uint src_pitch = p.src_w * p.channels;
    const uint knl_pitch = p.knl_w * p.channels;

    FLOAT_TYPE acc = 0.0;
    for (uint ky = 0; ky < p.knl_h; ++ky) {
        // Unsigned arithmetic: a position above the image wraps to a huge
        // value, so a single "< src_h" test rejects both out-of-range sides.
        const uint in_y = out_y * p.stride_y + ky * p.dilation_y - p.pad_y;
        if (in_y < p.src_h) {
            const uint src_line = src_base + in_y * src_pitch + ch;
            const uint knl_line = ky * knl_pitch + ch;
            for (uint kx = 0; kx < p.knl_w; ++kx) {
                // Same wrap-around trick for the horizontal coordinate.
                const uint in_x = out_x * p.stride_x + kx * p.dilation_x - p.pad_x;
                if (in_x < p.src_w) {
                    const FLOAT_TYPE pixel  = FLOAT_TYPE(src_data[src_line + in_x * p.channels]);
                    const FLOAT_TYPE weight = FLOAT_TYPE(knl_data[knl_line + kx * p.channels]);
                    acc = fma(pixel, weight, acc);
                }
            }
        }
    }
    return acc;
}

// Entry point: each invocation produces at most one output element.
void main() {
    // Flatten the 3D invocation id into a single index: y is weighted by 512
    // and z by 262144 (= 512 * 512).
    const uvec3 gid = gl_GlobalInvocationID;
    const uint idx = gid.z * 262144 + gid.y * 512 + gid.x;

    // Invocations past the end of the output have nothing to do.
    if (idx < p.ne) {
        // The memory layout variant is chosen at shader compile time.
#ifdef WHCN
        dst_data[idx] = D_TYPE(conv_2d_dw_whcn(idx));
#else
        dst_data[idx] = D_TYPE(conv_2d_dw_cwhn(idx));
#endif
    }
}