summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/conv_transpose_1d.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/conv_transpose_1d.comp')
-rw-r--r--ggml/src/vulkan-shaders/conv_transpose_1d.comp98
1 files changed, 98 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/vulkan-shaders/conv_transpose_1d.comp
new file mode 100644
index 00000000..b17b4e83
--- /dev/null
+++ b/ggml/src/vulkan-shaders/conv_transpose_1d.comp
@@ -0,0 +1,98 @@
+#version 450
+
+#include "types.comp"
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
+
+layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
+
+layout (push_constant) uniform parameter {
+ uint32_t Cout;
+ uint32_t Cin;
+ uint32_t K;
+ uint32_t L;
+ uint32_t KL;
+
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb11;
+ uint32_t nb1;
+
+ int32_t s0;
+} p;
+
+
+uint32_t Cout_idx = gl_WorkGroupID.x;
+const uint32_t bs = gl_WorkGroupSize.x;
+uint32_t tid = gl_LocalInvocationID.x;
+// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
+uint32_t tmp_len = bs*p.s0+p.K;
+shared D_TYPE tmp[4096];
+
+uint splitWork(uint workSize){
+ return (bs + workSize -1) / bs;
+}
+
+void main(){
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
+ uint32_t idx = i*bs+tid;
+ if(idx < tmp_len){
+ tmp[idx] = 0.0;
+ }
+ }
+
+ uint32_t L_blocks = splitWork(p.L);
+ for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
+ if(L_block_id > 0){
+ barrier();
+ // Shift values in tmp to the current processing window
+ for(int i = 0; i < splitWork(tmp_len); i++){
+ uint32_t idx = i*bs+tid;
+ if(idx >= bs*p.s0 && idx < tmp_len){
+ tmp[idx-bs*p.s0] = tmp[idx];
+ tmp[idx] = 0.0;
+ }else if(idx >= p.K && idx < bs*p.s0){
+ tmp[idx] = 0.0;
+ }
+ }
+ }
+ barrier();
+
+ // Save contributions of the block to tmp
+ uint32_t L_idx = L_block_id*bs + tid;
+ for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
+ D_TYPE dp = 0.0;
+ for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
+ A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
+ if(L_idx < p.L){
+ B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
+ dp = fma(elemKrn, elemInp, dp);
+ }
+ }
+ tmp[tid*p.s0 + K_idx] += dp;
+ barrier();
+ }
+
+ // Save the computed values except the last block that can have different size
+ uint32_t KLb_idx = L_block_id*bs*p.s0;
+ if(L_block_id < L_blocks-1){
+ for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
+ uint32_t sh_idx = p.s0*tid+s0_idx;
+ uint32_t KL_idx = KLb_idx+sh_idx;
+ if(KL_idx < p.KL){
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
+ }
+ }
+ }
+ }
+
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
+ uint32_t idx = i*bs+tid;
+ uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
+ if(KL_idx < p.KL){
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
+ }
+ }
+}