summary | refs | log | tree | commit | diff
path: root/ggml/src/vulkan-shaders
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders')
-rw-r--r-- ggml/src/vulkan-shaders/CMakeLists.txt | 2
-rw-r--r-- ggml/src/vulkan-shaders/add.comp | 6
-rw-r--r-- ggml/src/vulkan-shaders/clamp.comp | 8
-rw-r--r-- ggml/src/vulkan-shaders/concat.comp | 35
-rw-r--r-- ggml/src/vulkan-shaders/copy.comp | 8
-rw-r--r-- ggml/src/vulkan-shaders/div.comp | 6
-rw-r--r-- ggml/src/vulkan-shaders/gelu.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/gelu_quick.comp | 23
-rw-r--r-- ggml/src/vulkan-shaders/generic_binary_head.comp | 6
-rw-r--r-- ggml/src/vulkan-shaders/generic_unary_head.comp | 4
-rw-r--r-- ggml/src/vulkan-shaders/group_norm.comp | 66
-rw-r--r-- ggml/src/vulkan-shaders/im2col.comp | 57
-rw-r--r-- ggml/src/vulkan-shaders/leaky_relu.comp | 22
-rw-r--r-- ggml/src/vulkan-shaders/mul.comp | 6
-rw-r--r-- ggml/src/vulkan-shaders/mul_mat_vec.comp | 13
-rw-r--r-- ggml/src/vulkan-shaders/norm.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/pad.comp | 26
-rw-r--r-- ggml/src/vulkan-shaders/relu.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/rms_norm.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/scale.comp | 6
-rw-r--r-- ggml/src/vulkan-shaders/silu.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/soft_max.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/square.comp | 8
-rw-r--r-- ggml/src/vulkan-shaders/sum_rows.comp | 2
-rw-r--r-- ggml/src/vulkan-shaders/tanh.comp | 21
-rw-r--r-- ggml/src/vulkan-shaders/timestep_embedding.comp | 41
-rw-r--r-- ggml/src/vulkan-shaders/types.comp | 4
-rw-r--r-- ggml/src/vulkan-shaders/upscale.comp | 36
-rw-r--r-- ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp | 74
29 files changed, 452 insertions, 40 deletions
diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/vulkan-shaders/CMakeLists.txt
index 41551e00..10075db3 100644
--- a/ggml/src/vulkan-shaders/CMakeLists.txt
+++ b/ggml/src/vulkan-shaders/CMakeLists.txt
@@ -1,5 +1,7 @@
+find_package (Threads REQUIRED)
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
diff --git a/ggml/src/vulkan-shaders/add.comp b/ggml/src/vulkan-shaders/add.comp
index 8475b011..3974845d 100644
--- a/ggml/src/vulkan-shaders/add.comp
+++ b/ggml/src/vulkan-shaders/add.comp
@@ -4,9 +4,11 @@
#include "generic_binary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
}
diff --git a/ggml/src/vulkan-shaders/clamp.comp b/ggml/src/vulkan-shaders/clamp.comp
index ca272e22..7071302a 100644
--- a/ggml/src/vulkan-shaders/clamp.comp
+++ b/ggml/src/vulkan-shaders/clamp.comp
@@ -4,10 +4,12 @@
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
diff --git a/ggml/src/vulkan-shaders/concat.comp b/ggml/src/vulkan-shaders/concat.comp
new file mode 100644
index 00000000..08ab5514
--- /dev/null
+++ b/ggml/src/vulkan-shaders/concat.comp
@@ -0,0 +1,35 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+ const int dim = p.param3;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
+ const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
+ const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
+ const uint i2_offset = i2*p.ne21*p.ne20;
+ const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
+
+ uint o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
+
+ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+ const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
+ const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
+
+ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+ data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
+#else
+ data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+#endif
+}
diff --git a/ggml/src/vulkan-shaders/copy.comp b/ggml/src/vulkan-shaders/copy.comp
index efb55876..c26917c0 100644
--- a/ggml/src/vulkan-shaders/copy.comp
+++ b/ggml/src/vulkan-shaders/copy.comp
@@ -4,13 +4,15 @@
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
#else
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
+ data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
#endif
}
diff --git a/ggml/src/vulkan-shaders/div.comp b/ggml/src/vulkan-shaders/div.comp
index 8ee4bfc7..8cfce58b 100644
--- a/ggml/src/vulkan-shaders/div.comp
+++ b/ggml/src/vulkan-shaders/div.comp
@@ -4,9 +4,11 @@
#include "generic_binary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
}
diff --git a/ggml/src/vulkan-shaders/gelu.comp b/ggml/src/vulkan-shaders/gelu.comp
index 9fe807cc..4cc7a68c 100644
--- a/ggml/src/vulkan-shaders/gelu.comp
+++ b/ggml/src/vulkan-shaders/gelu.comp
@@ -13,7 +13,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
- const uint i = gl_GlobalInvocationID.x;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
diff --git a/ggml/src/vulkan-shaders/gelu_quick.comp b/ggml/src/vulkan-shaders/gelu_quick.comp
new file mode 100644
index 00000000..e6e6fcfd
--- /dev/null
+++ b/ggml/src/vulkan-shaders/gelu_quick.comp
@@ -0,0 +1,23 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const float GELU_QUICK_COEF = -1.702f;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float x = float(data_a[i]);
+ data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
+}
diff --git a/ggml/src/vulkan-shaders/generic_binary_head.comp b/ggml/src/vulkan-shaders/generic_binary_head.comp
index ab45d256..b6beaff1 100644
--- a/ggml/src/vulkan-shaders/generic_binary_head.comp
+++ b/ggml/src/vulkan-shaders/generic_binary_head.comp
@@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint d_offset;
- float param1; float param2;
+ float param1; float param2; int param3;
} p;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
@@ -16,6 +16,10 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+uint get_idx() {
+ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
uint src0_idx(uint idx) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
diff --git a/ggml/src/vulkan-shaders/generic_unary_head.comp b/ggml/src/vulkan-shaders/generic_unary_head.comp
index de08de7c..eacdefc7 100644
--- a/ggml/src/vulkan-shaders/generic_unary_head.comp
+++ b/ggml/src/vulkan-shaders/generic_unary_head.comp
@@ -14,6 +14,10 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+uint get_idx() {
+ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
uint src0_idx(uint idx) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
diff --git a/ggml/src/vulkan-shaders/group_norm.comp b/ggml/src/vulkan-shaders/group_norm.comp
new file mode 100644
index 00000000..5ad9b28d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/group_norm.comp
@@ -0,0 +1,66 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared float tmp[BLOCK_SIZE];
+
+void main() {
+ const uint group_size = p.KX;
+ const float eps = p.param1;
+
+ const uint tid = gl_LocalInvocationID.x;
+ const uint start = gl_WorkGroupID.x * group_size + tid;
+ const uint end = start + group_size;
+
+ tmp[tid] = 0.0f;
+
+ // Calculate mean
+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+ tmp[tid] += float(data_a[col]);
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+
+ const float mean = tmp[0] / group_size;
+ barrier();
+ tmp[tid] = 0.0f;
+
+ // Calculate variance
+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+ const float xi = float(data_a[col]) - mean;
+ data_d[col] = D_TYPE(xi);
+ tmp[tid] += xi * xi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+
+ const float variance = tmp[0] / group_size;
+ const float scale = inversesqrt(variance + eps);
+
+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+ data_d[col] *= D_TYPE(scale);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/im2col.comp b/ggml/src/vulkan-shaders/im2col.comp
new file mode 100644
index 00000000..4d48610a
--- /dev/null
+++ b/ggml/src/vulkan-shaders/im2col.comp
@@ -0,0 +1,57 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint batch_offset; uint offset_delta;
+ uint IC;
+ uint IW; uint IH;
+ uint OW; uint OH;
+ uint KW; uint KH;
+ uint pelements;
+ uint CHW;
+ int s0; int s1;
+ int p0; int p1;
+ int d0; int d1;
+} p;
+
+#include "types.comp"
+
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.x;
+ if (i >= p.pelements) {
+ return;
+ }
+
+ const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+ const uint kx = i / ksize;
+ const uint kd = kx * ksize;
+ const uint ky = (i - kd) / p.OW;
+ const uint ix = i % p.OW;
+
+ const uint oh = gl_GlobalInvocationID.y;
+ const uint batch = gl_GlobalInvocationID.z / p.IC;
+ const uint ic = gl_GlobalInvocationID.z % p.IC;
+
+ const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
+ const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
+
+ const uint offset_dst =
+ ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
+ (ic * (p.KW * p.KH) + ky * p.KW + kx);
+
+ if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
+ data_d[offset_dst] = D_TYPE(0.0f);
+ } else {
+ const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
+ data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/leaky_relu.comp b/ggml/src/vulkan-shaders/leaky_relu.comp
new file mode 100644
index 00000000..d90a99ae
--- /dev/null
+++ b/ggml/src/vulkan-shaders/leaky_relu.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float val = float(data_a[i]);
+ data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
+}
diff --git a/ggml/src/vulkan-shaders/mul.comp b/ggml/src/vulkan-shaders/mul.comp
index bbb0aa1d..bfb61c92 100644
--- a/ggml/src/vulkan-shaders/mul.comp
+++ b/ggml/src/vulkan-shaders/mul.comp
@@ -4,9 +4,11 @@
#include "generic_binary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec.comp b/ggml/src/vulkan-shaders/mul_mat_vec.comp
index 15d2a806..46a6369b 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec.comp
@@ -16,6 +16,13 @@ void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
const uint tid = gl_LocalInvocationID.x;
+ // There are not enough cols to use all threads
+ if (tid >= p.ncols) {
+ return;
+ }
+
+ const uint block_size = min(p.ncols, BLOCK_SIZE);
+
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@@ -23,8 +30,8 @@ void main() {
tmp[tid] = FLOAT_TYPE(0.0f);
- [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
- const uint col = i*BLOCK_SIZE + 2*tid;
+ [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
+ const uint col = i*block_size + 2*tid;
const uint ib = (row*p.ncols + col)/QUANT_K; // block index
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index
@@ -38,7 +45,7 @@ void main() {
// sum up partial sums and write back result
barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
diff --git a/ggml/src/vulkan-shaders/norm.comp b/ggml/src/vulkan-shaders/norm.comp
index 803dbdcb..6627a50b 100644
--- a/ggml/src/vulkan-shaders/norm.comp
+++ b/ggml/src/vulkan-shaders/norm.comp
@@ -14,7 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared vec2 sum[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.x;
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = vec2(0.0f, 0.0f);
diff --git a/ggml/src/vulkan-shaders/pad.comp b/ggml/src/vulkan-shaders/pad.comp
new file mode 100644
index 00000000..a465cd52
--- /dev/null
+++ b/ggml/src/vulkan-shaders/pad.comp
@@ -0,0 +1,26 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint i3 = idx / (p.ne12*p.ne11*p.ne10);
+ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
+ const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10);
+ const uint i2_offset = i2*p.ne11*p.ne10;
+ const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
+
+ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+ const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
+
+ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+
+ data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
+}
diff --git a/ggml/src/vulkan-shaders/relu.comp b/ggml/src/vulkan-shaders/relu.comp
index 7e5baa5b..52a19b62 100644
--- a/ggml/src/vulkan-shaders/relu.comp
+++ b/ggml/src/vulkan-shaders/relu.comp
@@ -11,7 +11,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
- const uint i = gl_GlobalInvocationID.x;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/vulkan-shaders/rms_norm.comp
index cfd08d34..b554400b 100644
--- a/ggml/src/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/vulkan-shaders/rms_norm.comp
@@ -14,7 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.x;
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp
index 510cb723..5cd2f668 100644
--- a/ggml/src/vulkan-shaders/scale.comp
+++ b/ggml/src/vulkan-shaders/scale.comp
@@ -4,9 +4,11 @@
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1));
}
diff --git a/ggml/src/vulkan-shaders/silu.comp b/ggml/src/vulkan-shaders/silu.comp
index 15920f06..4d36f88e 100644
--- a/ggml/src/vulkan-shaders/silu.comp
+++ b/ggml/src/vulkan-shaders/silu.comp
@@ -11,7 +11,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
- const uint i = gl_GlobalInvocationID.x;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp
index 1b8419c7..0bd51eca 100644
--- a/ggml/src/vulkan-shaders/soft_max.comp
+++ b/ggml/src/vulkan-shaders/soft_max.comp
@@ -28,7 +28,7 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
- const uint rowx = gl_WorkGroupID.x;
+ const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
float slope = 1.0f;
diff --git a/ggml/src/vulkan-shaders/square.comp b/ggml/src/vulkan-shaders/square.comp
index 8dd19333..1fa118c9 100644
--- a/ggml/src/vulkan-shaders/square.comp
+++ b/ggml/src/vulkan-shaders/square.comp
@@ -4,10 +4,12 @@
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
}
diff --git a/ggml/src/vulkan-shaders/sum_rows.comp b/ggml/src/vulkan-shaders/sum_rows.comp
index ce2f1e2f..961e5ffa 100644
--- a/ggml/src/vulkan-shaders/sum_rows.comp
+++ b/ggml/src/vulkan-shaders/sum_rows.comp
@@ -14,7 +14,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.x;
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
tmp[col] = FLOAT_TYPE(0.0f);
diff --git a/ggml/src/vulkan-shaders/tanh.comp b/ggml/src/vulkan-shaders/tanh.comp
new file mode 100644
index 00000000..74630dc7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/tanh.comp
@@ -0,0 +1,21 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ data_d[i] = D_TYPE(tanh(data_a[i]));
+}
diff --git a/ggml/src/vulkan-shaders/timestep_embedding.comp b/ggml/src/vulkan-shaders/timestep_embedding.comp
new file mode 100644
index 00000000..79e065a9
--- /dev/null
+++ b/ggml/src/vulkan-shaders/timestep_embedding.comp
@@ -0,0 +1,41 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint nb1;
+ uint dim;
+ uint max_period;
+} p;
+
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.y;
+ const uint j = gl_GlobalInvocationID.x;
+ const uint d_offset = i * p.nb1;
+
+ if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
+ data_d[d_offset + p.dim] = 0.f;
+ }
+
+ const uint half_dim = p.dim / 2;
+ if (j >= half_dim) {
+ return;
+ }
+
+ const float timestep = float(data_a[i]);
+ const float freq = float(exp(-log(p.max_period) * j / half_dim));
+ const float arg = timestep * freq;
+ data_d[d_offset + j] = D_TYPE(cos(arg));
+ data_d[d_offset + j + half_dim] = D_TYPE(sin(arg));
+}
diff --git a/ggml/src/vulkan-shaders/types.comp b/ggml/src/vulkan-shaders/types.comp
index d24c172c..21dce72f 100644
--- a/ggml/src/vulkan-shaders/types.comp
+++ b/ggml/src/vulkan-shaders/types.comp
@@ -6,7 +6,7 @@
#define QUANT_K 1
#define QUANT_R 1
-#ifndef LOAD_VEC_A
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float
#elif LOAD_VEC_A == 4
#define A_TYPE vec4
@@ -19,7 +19,7 @@
#define QUANT_K 1
#define QUANT_R 1
-#ifndef LOAD_VEC_A
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float16_t
#elif LOAD_VEC_A == 4
#define A_TYPE f16vec4
diff --git a/ggml/src/vulkan-shaders/upscale.comp b/ggml/src/vulkan-shaders/upscale.comp
new file mode 100644
index 00000000..511a086e
--- /dev/null
+++ b/ggml/src/vulkan-shaders/upscale.comp
@@ -0,0 +1,36 @@
+#version 450
+
+layout (push_constant) uniform parameter
+{
+ uint ne; uint d_offset;
+ uint nb00; uint nb01; uint nb02; uint nb03;
+ uint ne10; uint ne11; uint ne12; uint ne13;
+ float sf0; float sf1; float sf2; float sf3;
+} p;
+
+#include "types.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint i10 = idx % p.ne10;
+ const uint i11 = (idx / p.ne10) % p.ne11;
+ const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
+ const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
+
+ const uint i00 = uint(i10 / p.sf0);
+ const uint i01 = uint(i11 / p.sf1);
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+
+ data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
+}
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
index c5be3754..a792e203 100644
--- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -22,6 +22,7 @@
#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
+ #include <algorithm> // For std::replace on w64devkit
#else
#include <unistd.h>
#include <sys/wait.h>
@@ -179,11 +180,7 @@ bool string_ends_with(const std::string& str, const std::string& suffix) {
return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}
-#ifdef _WIN32
- static const char path_separator = '\\';
-#else
- static const char path_separator = '/';
-#endif
+static const char path_separator = '/';
std::string join_paths(const std::string& path1, const std::string& path2) {
return path1 + path_separator + path2;
@@ -198,7 +195,11 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname);
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
+ #ifdef _WIN32
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
+ #else
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
+ #endif
for (const auto& define : defines) {
cmd.push_back("-D" + define.first + "=" + define.second);
}
@@ -269,9 +270,12 @@ void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmu
for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ // For unaligned, load one at a time for f32/f16, or two at a time for quants
+ std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
+ // For aligned matmul loads
std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
+ string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
@@ -341,6 +345,9 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
@@ -357,6 +364,9 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+ }));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -383,14 +393,41 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
}));
tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -424,6 +461,17 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
}
void write_output_files() {
@@ -435,10 +483,16 @@ void write_output_files() {
for (const auto& pair : shader_fnames) {
const std::string& name = pair.first;
- const std::string& path = pair.second;
+ #ifdef _WIN32
+ std::string path = pair.second;
+ std::replace(path.begin(), path.end(), '/', '\\' );
+ #else
+ const std::string& path = pair.second;
+ #endif
+
FILE* spv = fopen(path.c_str(), "rb");
if (!spv) {
- std::cerr << "Error opening SPIR-V file: " << path << "\n";
+ std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
continue;
}
@@ -450,7 +504,7 @@ void write_output_files() {
size_t read_size = fread(data.data(), 1, size, spv);
fclose(spv);
if (read_size != size) {
- std::cerr << "Error reading SPIR-V file: " << path << "\n";
+ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
continue;
}