path: root/ggml/src/vulkan-shaders
Diffstat (limited to 'ggml/src/vulkan-shaders')
-rw-r--r--  ggml/src/vulkan-shaders/CMakeLists.txt                  |   2
-rw-r--r--  ggml/src/vulkan-shaders/conv2d_dw.comp                  | 105
-rw-r--r--  ggml/src/vulkan-shaders/copy_to_quant.comp              |  65
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn.comp                 |   8
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_base.comp            |  15
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_cm1.comp             |   8
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_cm2.comp             |   9
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp  |  56
-rw-r--r--  ggml/src/vulkan-shaders/geglu.comp                      |  13
-rw-r--r--  ggml/src/vulkan-shaders/geglu_erf.comp                  |  27
-rw-r--r--  ggml/src/vulkan-shaders/geglu_quick.comp                |  11
-rw-r--r--  ggml/src/vulkan-shaders/gelu_erf.comp                   |  39
-rw-r--r--  ggml/src/vulkan-shaders/glu_head.comp                   |  15
-rw-r--r--  ggml/src/vulkan-shaders/glu_main.comp                   |  29
-rw-r--r--  ggml/src/vulkan-shaders/l2_norm.comp                    |  41
-rw-r--r--  ggml/src/vulkan-shaders/mul_mm.comp                     | 200
-rw-r--r--  ggml/src/vulkan-shaders/mul_mm_cm2.comp                 |  47
-rw-r--r--  ggml/src/vulkan-shaders/reglu.comp                      |   9
-rw-r--r--  ggml/src/vulkan-shaders/rms_norm.comp                   |  15
-rw-r--r--  ggml/src/vulkan-shaders/roll.comp                       |  46
-rw-r--r--  ggml/src/vulkan-shaders/rope_multi.comp                 |  16
-rw-r--r--  ggml/src/vulkan-shaders/rope_neox.comp                  |  16
-rw-r--r--  ggml/src/vulkan-shaders/rope_norm.comp                  |  16
-rw-r--r--  ggml/src/vulkan-shaders/scale.comp                      |   2
-rw-r--r--  ggml/src/vulkan-shaders/soft_max.comp                   |  24
-rw-r--r--  ggml/src/vulkan-shaders/swiglu.comp                     |   9
-rw-r--r--  ggml/src/vulkan-shaders/upscale.comp                    |  74
-rw-r--r--  ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp          |  57
-rw-r--r--  ggml/src/vulkan-shaders/wkv7.comp                       |  91
29 files changed, 900 insertions, 165 deletions
diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/vulkan-shaders/CMakeLists.txt
index a22ea817..e1f613fb 100644
--- a/ggml/src/vulkan-shaders/CMakeLists.txt
+++ b/ggml/src/vulkan-shaders/CMakeLists.txt
@@ -27,5 +27,5 @@ endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
diff --git a/ggml/src/vulkan-shaders/conv2d_dw.comp b/ggml/src/vulkan-shaders/conv2d_dw.comp
new file mode 100644
index 00000000..938c74da
--- /dev/null
+++ b/ggml/src/vulkan-shaders/conv2d_dw.comp
@@ -0,0 +1,105 @@
+#version 450
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+ uint ne;
+ uint batches;
+ uint channels;
+ uint dst_w;
+ uint dst_h;
+ uint src_w;
+ uint src_h;
+ uint knl_w;
+ uint knl_h;
+ int stride_x;
+ int stride_y;
+ int pad_x;
+ int pad_y;
+ int dilation_x;
+ int dilation_y;
+} p;
+
+layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
+layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
+ uint i0 = idx / p.dst_w;
+ uint dst_x = idx - i0 * p.dst_w;
+ uint i1 = i0 / p.dst_h;
+ uint dst_y = i0 - i1 * p.dst_h;
+ uint n = i1 / p.channels;
+ uint c = i1 - n * p.channels;
+
+ uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
+ uint knl_i = c * p.knl_h * p.knl_w;
+
+ FLOAT_TYPE sum = 0.0;
+ for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+ if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+ continue;
+ }
+ for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+ if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+ continue;
+ }
+ FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
+ FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
+ sum = fma(v, k, sum);
+ }
+ }
+ return sum;
+}
+
+FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
+ uint i0 = idx / p.channels;
+ uint c = idx - i0 * p.channels;
+ uint i1 = i0 / p.dst_w;
+ uint dst_x = i0 - i1 * p.dst_w;
+ uint n = i1 / p.dst_h;
+ uint dst_y = i1 - n * p.dst_h;
+
+ uint src_i = n * p.channels * p.src_h * p.src_w;
+ uint src_row = p.src_w * p.channels;
+ uint knl_row = p.knl_w * p.channels;
+
+ FLOAT_TYPE sum = 0.0;
+ for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+ uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+ if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+ continue;
+ }
+ for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+ uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+ if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+ continue;
+ }
+ FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
+ FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
+ sum = fma(v, k, sum);
+ }
+ }
+ return sum;
+}
+
+void main() {
+ uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+ if (idx >= p.ne) {
+ return;
+ }
+
+ FLOAT_TYPE result =
+#ifdef WHCN
+ conv_2d_dw_whcn(idx);
+#else
+ conv_2d_dw_cwhn(idx);
+#endif
+ dst_data[idx] = D_TYPE(result);
+}
+
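
The shader above flattens the output index into (batch, channel, y, x) and performs the bounds check with a single unsigned comparison, relying on negative coordinates wrapping to large values. A minimal CPU sketch of the WHCN path, with dense layouts and illustrative names (not part of the patch):

    #include <cstdint>
    #include <vector>

    // CPU reference for the WHCN depthwise path, illustration only.
    // Dense layout assumed: src[n][c][y][x], knl[c][ky][kx], dst[n][c][y][x].
    void conv2d_dw_whcn_ref(const std::vector<float>& src, const std::vector<float>& knl,
                            std::vector<float>& dst,
                            uint32_t batches, uint32_t channels,
                            uint32_t src_w, uint32_t src_h,
                            uint32_t knl_w, uint32_t knl_h,
                            uint32_t dst_w, uint32_t dst_h,
                            int stride_x, int stride_y, int pad_x, int pad_y,
                            int dilation_x, int dilation_y) {
        for (uint32_t idx = 0; idx < batches * channels * dst_h * dst_w; ++idx) {
            // same flat-index decomposition as the shader
            uint32_t i0 = idx / dst_w, dst_x = idx - i0 * dst_w;
            uint32_t i1 = i0 / dst_h, dst_y = i0 - i1 * dst_h;
            uint32_t n  = i1 / channels, c = i1 - n * channels;

            float sum = 0.0f;
            for (uint32_t ky = 0; ky < knl_h; ++ky) {
                int src_y = int(dst_y) * stride_y + int(ky) * dilation_y - pad_y;
                if (src_y < 0 || src_y >= int(src_h)) continue; // shader folds this into one unsigned compare
                for (uint32_t kx = 0; kx < knl_w; ++kx) {
                    int src_x = int(dst_x) * stride_x + int(kx) * dilation_x - pad_x;
                    if (src_x < 0 || src_x >= int(src_w)) continue;
                    sum += src[((n * channels + c) * src_h + uint32_t(src_y)) * src_w + uint32_t(src_x)]
                         * knl[(c * knl_h + ky) * knl_w + kx];
                }
            }
            dst[idx] = sum;
        }
    }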
diff --git a/ggml/src/vulkan-shaders/copy_to_quant.comp b/ggml/src/vulkan-shaders/copy_to_quant.comp
index 9c76437d..e06547e4 100644
--- a/ggml/src/vulkan-shaders/copy_to_quant.comp
+++ b/ggml/src/vulkan-shaders/copy_to_quant.comp
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
#endif // RTE16
#include "types.comp"
-#include "generic_unary_head.comp"
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
-layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#if defined(SET_ROWS) && QUANT_K == 1
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 512;
#else
-layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 32;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
+
+#if defined(SET_ROWS)
+#include "generic_binary_head.comp"
+layout (binding = 1) readonly buffer C {uvec2 data_i[];};
+layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
+#else
+#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
+#endif
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+void quantize(uint dst_idx, uint src_idx)
+{
+ data_q[dst_idx] = A_TYPE(data_s[src_idx]);
+}
+#endif
+
+#if defined(DATA_A_BF16)
+void quantize(uint dst_idx, uint src_idx)
+{
+ data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
+}
+#endif
+
+#if defined(SET_ROWS)
+
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
- if (gl_LocalInvocationIndex.x != 0) {
+#endif
+
+ const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
+
+ if (idx >= p.ne) {
return;
}
+
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ uint i12 = fastmod(i03, p.ne12);
+ uint i11 = fastmod(i02, p.ne11);
+ uint i10 = i01;
+
+ uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
+
+ uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
+ uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
+
+ quantize(dst_idx, src0_idx);
+}
+
+#else
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
#endif
- const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+ const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
@@ -240,3 +289,5 @@ void main() {
quantize(dst_idx, src_idx);
}
+
+#endif
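
The new SET_ROWS path redirects each source row to the destination row given by the index tensor bound at binding 1, converting or quantizing elements on the way. A rough CPU picture of the 2-D, f32, QUANT_K == 1 case (names are illustrative, not the ggml host code):

    #include <cstdint>
    #include <vector>

    // Each source row r is written into destination row ids[r] (data_i in the shader),
    // with quantize() reduced to an identity copy for the f32 -> f32 case.
    void set_rows_ref(const std::vector<float>& src, const std::vector<int64_t>& ids,
                      std::vector<float>& dst, size_t n_rows, size_t n_cols) {
        for (size_t r = 0; r < n_rows; ++r) {
            const int64_t dst_row = ids[r];
            for (size_t c = 0; c < n_cols; ++c) {
                dst[size_t(dst_row) * n_cols + c] = src[r * n_cols + c];
            }
        }
    }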
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp
index 454b3411..45c6e773 100644
--- a/ggml/src/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/vulkan-shaders/flash_attn.comp
@@ -100,6 +100,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
+ uint32_t m_offset = 0;
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ }
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -145,13 +149,13 @@ void main() {
}
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
- masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+ masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
}
}
barrier();
diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp
index 1d3e6387..7defe72b 100644
--- a/ggml/src/vulkan-shaders/flash_attn_base.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_base.comp
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
+ uint32_t nem2;
+ uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
- uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
- uint32_t mask;
- uint32_t n_head_log2;
+ uint32_t mask_n_head_log2;
float m0;
float m1;
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
+
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+ uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+
+ const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
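
The former `mask` and `n_head_log2` push constants are now packed into one 32-bit word: bit 16 enables the mask and the low 16 bits carry n_head_log2. A host-side packing sketch, assuming n_head_log2 always fits in 16 bits:

    #include <cstdint>

    constexpr uint32_t MASK_ENABLE_BIT = 1u << 16;
    constexpr uint32_t N_LOG2_MASK     = 0xFFFFu;

    uint32_t pack_mask_n_head_log2(bool has_mask, uint32_t n_head_log2) {
        return (has_mask ? MASK_ENABLE_BIT : 0u) | (n_head_log2 & N_LOG2_MASK);
    }

    // Shader side:
    //   mask enabled: (p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0
    //   n_head_log2:   p.mask_n_head_log2 & N_LOG2_MASK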
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
index ad7594fe..486735fe 100644
--- a/ggml/src/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
@@ -125,6 +125,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
+ uint32_t m_offset = 0;
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ }
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -178,12 +182,12 @@ void main() {
barrier();
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
barrier();
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
index 91caa184..274f48fc 100644
--- a/ggml/src/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
@@ -130,6 +130,11 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
+ uint32_t m_offset = 0;
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+ }
+
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -148,14 +153,14 @@ void main() {
}
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
- coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
diff --git a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp
index a7e39568..0a17a9df 100644
--- a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -2,9 +2,9 @@
#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 32
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -12,48 +12,80 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
+ uint ne3;
uint k_num;
} p;
+shared float tmpsh[BLOCK_SIZE];
+
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
+ const uint iq3 = gl_WorkGroupID.z;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
- uint l_offset = D * N * k_num + n;
- uint m_offset = D * N * k_num + N + n;
+ uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
+ uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
- [[unroll]] for (uint k = 0; k < k_num; ++k) {
- float m = data_a[m_offset + k * lm_stride];
+ for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+ float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m);
}
+ // reduce across the workgroup
+ tmpsh[tid] = m_max;
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ m_max = max(m_max, tmpsh[tid + s]);
+ tmpsh[tid] = m_max;
+ }
+ barrier();
+ }
+ m_max = tmpsh[0];
+
+ barrier();
+
// Compute L based on m_max
float L = 0;
- [[unroll]] for (uint k = 0; k < k_num; ++k) {
- float l = data_a[l_offset + k * lm_stride];
- float m = data_a[m_offset + k * lm_stride];
+ for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+ float l = data_a[l_offset + (k + tid) * lm_stride];
+ float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l;
}
+ // reduce across the workgroup
+ tmpsh[tid] = L;
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ L += tmpsh[tid + s];
+ tmpsh[tid] = L;
+ }
+ barrier();
+ }
+ L = tmpsh[0];
+
L = 1.0 / L;
+ // D dimension is split across workgroups in the y dimension
+ uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory
- for (uint d = tid; d < D; d += BLOCK_SIZE) {
+ if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
- uint o_offset = D * N * k + D * n + d;
+ uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
O *= L;
- data_d[D * n + d] = O;
+ data_d[iq3 * D * N + D * n + d] = O;
}
}
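
The reduce shader now uses a whole workgroup per row: each thread accumulates a strided subset of the k_num partial results and a shared-memory tree reduction combines them. The math is the usual numerically stable split-k merge; a serial CPU reference for one output element (illustrative shapes):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // O_k are the partial outputs for one (row, d) element, m_k the per-split running
    // maxima and l_k the per-split softmax denominators.
    float combine_split_k(const std::vector<float>& O_k, const std::vector<float>& m_k,
                          const std::vector<float>& l_k) {
        float m_max = -INFINITY;
        for (float m : m_k) m_max = std::max(m_max, m);

        float L = 0.0f;
        for (size_t k = 0; k < l_k.size(); ++k) L += std::exp(m_k[k] - m_max) * l_k[k];

        float O = 0.0f;
        for (size_t k = 0; k < O_k.size(); ++k) O += std::exp(m_k[k] - m_max) * O_k[k];
        return O / L;
    }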
diff --git a/ggml/src/vulkan-shaders/geglu.comp b/ggml/src/vulkan-shaders/geglu.comp
new file mode 100644
index 00000000..f4268ed2
--- /dev/null
+++ b/ggml/src/vulkan-shaders/geglu.comp
@@ -0,0 +1,13 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_COEF_A = 0.044715f;
+const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+float op(float a, float b) {
+ const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
+ return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
+}
+
+#include "glu_main.comp"
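
The expression above is an algebraic rewrite of the tanh-based GELU, since tanh(v) = 1 - 2/(exp(2v) + 1), so the op computes gelu_tanh(a) * b. A scalar reference in the more familiar form:

    #include <cmath>

    float geglu_ref(float a, float b) {
        const float GELU_COEF_A    = 0.044715f;
        const float SQRT_2_OVER_PI = 0.7978845608028654f;
        const float v = SQRT_2_OVER_PI * a * (1.0f + GELU_COEF_A * a * a);
        // the shader's 2 - 2/(exp(2v)+1) equals 1 + tanh(v)
        return 0.5f * a * (1.0f + std::tanh(v)) * b;
    }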
diff --git a/ggml/src/vulkan-shaders/geglu_erf.comp b/ggml/src/vulkan-shaders/geglu_erf.comp
new file mode 100644
index 00000000..cbd4cb36
--- /dev/null
+++ b/ggml/src/vulkan-shaders/geglu_erf.comp
@@ -0,0 +1,27 @@
+#version 450
+
+#include "glu_head.comp"
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+const float p_erf = 0.3275911f;
+const float a1_erf = 0.254829592f;
+const float a2_erf = -0.284496736f;
+const float a3_erf = 1.421413741f;
+const float a4_erf = -1.453152027f;
+const float a5_erf = 1.061405429f;
+
+const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+float op(float a, float b) {
+ const float a_div_sqr2 = a * SQRT_2_INV;
+ const float sign_x = sign(a_div_sqr2);
+ const float x = abs(a_div_sqr2);
+ const float t = 1.0f / (1.0f + p_erf * x);
+ const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+ const float erf_approx = sign_x * y;
+
+ return 0.5f * a * (1.0f + erf_approx) * b;
+}
+
+#include "glu_main.comp"
diff --git a/ggml/src/vulkan-shaders/geglu_quick.comp b/ggml/src/vulkan-shaders/geglu_quick.comp
new file mode 100644
index 00000000..3a2a6897
--- /dev/null
+++ b/ggml/src/vulkan-shaders/geglu_quick.comp
@@ -0,0 +1,11 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_QUICK_COEF = -1.702f;
+
+float op(float a, float b) {
+ return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
+}
+
+#include "glu_main.comp"
diff --git a/ggml/src/vulkan-shaders/gelu_erf.comp b/ggml/src/vulkan-shaders/gelu_erf.comp
new file mode 100644
index 00000000..5fd5a5e7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/gelu_erf.comp
@@ -0,0 +1,39 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+ // ref: https://www.johndcook.com/blog/python_erf/
+ const float p_erf = 0.3275911f;
+ const float a1_erf = 0.254829592f;
+ const float a2_erf = -0.284496736f;
+ const float a3_erf = 1.421413741f;
+ const float a4_erf = -1.453152027f;
+ const float a5_erf = 1.061405429f;
+
+ const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float a = float(data_a[i]);
+ const float a_div_sqr2 = a * SQRT_2_INV;
+ const float sign_x = sign(a_div_sqr2);
+ const float x = abs(a_div_sqr2);
+ const float t = 1.0f / (1.0f + p_erf * x);
+ const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+ const float erf_approx = sign_x * y;
+
+ data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
+}
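
The polynomial is the Abramowitz and Stegun 7.1.26 approximation of erf, so the shader approximates the exact erf-based GELU. A small reference that is useful when validating the approximation (the polynomial's maximum absolute error is about 1.5e-7):

    #include <cmath>
    #include <cstdio>

    // Exact formulation the shader approximates.
    float gelu_erf_exact(float a) {
        return 0.5f * a * (1.0f + std::erf(a * 0.70710678118654752440f));
    }

    int main() {
        for (float a = -4.0f; a <= 4.0f; a += 0.5f) {
            std::printf("gelu_erf(%.1f) = %.6f\n", a, gelu_erf_exact(a));
        }
    }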
diff --git a/ggml/src/vulkan-shaders/glu_head.comp b/ggml/src/vulkan-shaders/glu_head.comp
new file mode 100644
index 00000000..41a29889
--- /dev/null
+++ b/ggml/src/vulkan-shaders/glu_head.comp
@@ -0,0 +1,15 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (push_constant) uniform parameter
+{
+ uint N;
+ uint ne00;
+ uint ne20;
+ uint mode;
+} p;
diff --git a/ggml/src/vulkan-shaders/glu_main.comp b/ggml/src/vulkan-shaders/glu_main.comp
new file mode 100644
index 00000000..85cf65a9
--- /dev/null
+++ b/ggml/src/vulkan-shaders/glu_main.comp
@@ -0,0 +1,29 @@
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.N) {
+ return;
+ }
+
+ const uint row = i / p.ne20;
+ const uint col = i - row * p.ne20;
+
+ if (p.mode == 0) {
+ // Default
+ const uint offset = p.ne00 / 2;
+ const uint idx = row * p.ne00 + col;
+
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+ } else if (p.mode == 1) {
+ // Swapped
+ const uint offset = p.ne00 / 2;
+ const uint idx = row * p.ne00 + col;
+
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+ } else {
+ // Split
+ const uint idx = row * p.ne00 + col;
+
+ data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+ }
+}
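
The three modes mirror how ggml feeds GLU ops: both halves of one tensor (in either order) or two separate tensors. A CPU reference that mirrors the indexing above (op is the gating function, e.g. reglu's max(a, 0) * b; for the split mode ne20 is assumed equal to ne00):

    #include <vector>

    void glu_ref(const std::vector<float>& a, const std::vector<float>& b, std::vector<float>& d,
                 size_t n_rows, size_t ne00, size_t ne20, int mode,
                 float (*op)(float, float)) {
        for (size_t row = 0; row < n_rows; ++row) {
            for (size_t col = 0; col < ne20; ++col) {
                const size_t idx = row * ne00 + col;
                if (mode == 0) {        // default: op(first half, second half) within a row of a
                    d[row * ne20 + col] = op(a[idx], a[idx + ne00 / 2]);
                } else if (mode == 1) { // swapped: halves exchanged
                    d[row * ne20 + col] = op(a[idx + ne00 / 2], a[idx]);
                } else {                // split: op(a, b) with separate tensors
                    d[idx] = op(a[idx], b[idx]);
                }
            }
        }
    }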
diff --git a/ggml/src/vulkan-shaders/l2_norm.comp b/ggml/src/vulkan-shaders/l2_norm.comp
new file mode 100644
index 00000000..deba8c39
--- /dev/null
+++ b/ggml/src/vulkan-shaders/l2_norm.comp
@@ -0,0 +1,41 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+ sum[tid] += xi * xi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum[tid] += sum[tid + s];
+ }
+ barrier();
+ }
+
+ const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+ }
+}
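
This computes a row-wise L2 normalization, y = x / sqrt(max(sum(x^2), eps)), with eps passed in param1. A serial reference:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    void l2_norm_ref(const std::vector<float>& x, std::vector<float>& y,
                     size_t n_rows, size_t n_cols, float eps) {
        for (size_t r = 0; r < n_rows; ++r) {
            float ss = 0.0f;
            for (size_t c = 0; c < n_cols; ++c) ss += x[r * n_cols + c] * x[r * n_cols + c];
            const float scale = 1.0f / std::sqrt(std::max(ss, eps));
            for (size_t c = 0; c < n_cols; ++c) y[r * n_cols + c] = x[r * n_cols + c] * scale;
        }
    }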
diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp
index 26163b16..f4815499 100644
--- a/ggml/src/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/vulkan-shaders/mul_mm.comp
@@ -18,6 +18,7 @@
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
+uint _ne1;
+#ifdef COOPMAT
+shared uint _ne1_sh;
+#endif
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
- uint _ne1 = 0;
+#ifdef COOPMAT
+ // Spread the search across all elements in the first subgroup
+ if (gl_SubgroupID == 0) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+
+ uint ids[16];
+ uint iter = 0;
+
+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
+ bool in_range = i < num_elements;
+ uint ii1 = i / p.nei0;
+ uint ii0 = i % p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_SubgroupInvocationID;
+ bool in_range = i < num_elements;
+ uint ii1 = i / p.nei0;
+ uint ii0 = i % p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+ uint idx = subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx) {
+ row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+ }
+ _ne1 += subgroupBallotBitCount(ballot);
+ iter &= 15;
+ }
+ _ne1_sh = _ne1;
+ }
+
+ barrier();
+
+ _ne1 = _ne1_sh;
+#else
+ _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
}
barrier();
+#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
@@ -500,10 +546,9 @@ void main() {
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint ib32 = (idx % 128) / 16; // 0..7
- const uint ib8 = (idx % 128) / 4;
- const int i8 = 2 * int(idx % 4);
+ const uint ib = idx / 32; // 8 values per idx
+ const uint ib32 = (idx % 32) / 4; // 0..7
+ const uint ib8 = idx % 32;
const float d = float(data_a[ib].d);
const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +557,16 @@ void main() {
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
- const ivec2 gvec = ivec2(
- bitfieldExtract(grid, 2 * (i8), 2),
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
- );
- const vec2 v = dl * (vec2(gvec) + delta);
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ [[unroll]] for (int k = 0; k < 8; ++k) {
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
+ }
#elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint ib8 = (idx % 128) / 4;
+ const uint ib = idx / 32; // 8 values per idx
+ const uint ib8 = idx % 32;
const uint ib16 = ib8 / 2;
- const int i8 = 2 * int(idx % 4);
const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +577,17 @@ void main() {
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
- const ivec2 gvec = ivec2(
- bitfieldExtract(grid, 2 * (i8), 2),
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
- );
- const vec2 v = dl * (vec2(gvec) + delta);
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ [[unroll]] for (int k = 0; k < 8; ++k) {
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
+ }
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint ib32 = (idx % 128) / 16; // 0..7
- const uint ib8 = (idx / 4) % 4;
+ const uint ib = idx / 32; // 8 values per idx
+ const uint ib32 = (idx % 32) / 4; // 0..7
+ const uint ib8 = idx % 4;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +597,81 @@ void main() {
data_a[ib].qs[8*ib32 + 6],
data_a[ib].qs[8*ib32 + 7]
));
- const float db = d * 0.25 * (0.5 + (signs >> 28));
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ const uint sign = sign7 | (bitCount(sign7) << 7);
+ const uvec2 grid = iq2xxs_grid[qs];
+ const vec4 grid0 = vec4(unpack8(grid.x));
+ const vec4 grid1 = vec4(unpack8(grid.y));
+
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint ib32 = (idx % 128) / 16; // 0..7
- const uint ib8 = (idx / 4) % 4; // 0..3
+ const uint ib = idx / 32; // 8 values per idx
+ const uint ib32 = (idx % 32) / 4; // 0..7
+ const uint ib8 = idx % 4; // 0..3
const float d = float(data_a[ib].d);
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
- const float db = d * 0.25 * (0.5 + scale);
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
const uint sign7 = qs >> 9;
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ const uint sign = sign7 | (bitCount(sign7) << 7);
+ const uvec2 grid = iq2xs_grid[qs & 511];
+ const vec4 grid0 = vec4(unpack8(grid.x));
+ const vec4 grid1 = vec4(unpack8(grid.y));
+
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint ib8 = (idx % 128) / 4; // 0..31
- const uint ib32 = ib8 / 4; // 0..7
+ const uint ib = idx / 32; // 8 values per idx
+ const uint ib8 = idx % 32; // 0..31
+ const uint ib32 = ib8 / 4; // 0..7
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
const float d = float(data_a[ib].d);
- const float db = d * 0.25 * (0.5 + scale);
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
- const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
+ const vec4 grid0 = vec4(unpack8(grid.x));
+ const vec4 grid1 = vec4(unpack8(grid.y));
+
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = (idx % 128) / 2; // 0..63
+ const uint ib = idx / 64; // 4 values per idx
+ const uint iqs = idx % 64; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const float d = float(data_a[ib].d);
@@ -631,33 +684,36 @@ void main() {
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
+ const uint grid = iq3xxs_grid[qs];
+ const vec4 v = db * vec4(unpack8(grid));
+
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = (idx % 128) / 2; // 0..63
+ const uint ib = idx / 64; // 4 values per idx
+ const uint iqs = idx % 64; // 0..63
const uint iqh = iqs / 8;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint qh = data_a[ib].qh[iqh];
- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
const uint scale = data_a[ib].scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
+ const vec4 v = db * vec4(unpack8(grid));
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
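
The reworked iq1/iq2/iq3 loads now decode 8 (or 4) values per index instead of 2, which removes the per-pair shifts. The sign handling keeps the same trick as before: the quant stores only 7 sign bits per group of 8, and the 8th is implied by even parity, which is what `sign7 | (bitCount(sign7) << 7)` reconstructs. A scalar sketch of that expansion (helper name is illustrative):

    #include <cstdint>

    // Only bit 7 of the shifted popcount is ever tested; it is set exactly when the
    // 7 stored sign bits have odd parity, keeping the total number of negatives even.
    static int popcount7(uint32_t x) { int c = 0; while (x) { c += int(x & 1); x >>= 1; } return c; }

    void expand_iq2_signs(uint32_t sign7, const float grid[8], float out[8]) {
        const uint32_t sign = sign7 | (uint32_t(popcount7(sign7) & 1) << 7);
        for (int i = 0; i < 8; ++i) {
            out[i] = (sign & (1u << i)) ? -grid[i] : grid[i];
        }
    }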
diff --git a/ggml/src/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/vulkan-shaders/mul_mm_cm2.comp
index 91846575..29e4b5c9 100644
--- a/ggml/src/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/vulkan-shaders/mul_mm_cm2.comp
@@ -162,17 +162,32 @@ void main() {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
- for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
+ uint ids[16];
+ uint iter = 0;
+
+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
+ bool in_range = i < num_elements;
+ uint ii1 = i / p.nei0;
+ uint ii0 = i % p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
- uint ii0 = i % p.nei0;
uint ii1 = i / p.nei0;
- uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ uint ii0 = i % p.nei0;
+ uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
}
_ne1 += subgroupBallotBitCount(ballot);
+ iter &= 15;
}
_ne1_sh = _ne1;
}
@@ -414,17 +429,31 @@ void main() {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+#endif
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
}
// Convert from ACC_TYPE to D_TYPE
diff --git a/ggml/src/vulkan-shaders/reglu.comp b/ggml/src/vulkan-shaders/reglu.comp
new file mode 100644
index 00000000..0073d8f7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/reglu.comp
@@ -0,0 +1,9 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+ return max(a, 0.0f) * b;
+}
+
+#include "glu_main.comp"
diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/vulkan-shaders/rms_norm.comp
index deb8ee99..6428ca7b 100644
--- a/ggml/src/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/vulkan-shaders/rms_norm.comp
@@ -1,11 +1,13 @@
#version 450
-#include "generic_unary_head.comp"
+#include "generic_binary_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
+layout (constant_id = 1) const bool do_multiply = false;
+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
const uint stride_sample = p.nb03;
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+ uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,13 @@ void main() {
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
- [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
- data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+ if (do_multiply) {
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+ }
+ } else {
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+ }
}
}
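
rms_norm.comp now optionally fuses the following multiply: when the new `do_multiply` specialization constant is set, the normalized row is also scaled by the second input. A serial reference of both paths for a single row:

    #include <cmath>
    #include <vector>

    //   y = x * 1/sqrt(mean(x^2) + eps)        (do_multiply == false)
    //   y = x * 1/sqrt(mean(x^2) + eps) * w    (do_multiply == true, fused with the mul)
    void rms_norm_ref(const std::vector<float>& x, const std::vector<float>& w,
                      std::vector<float>& y, size_t n_cols, float eps, bool do_multiply) {
        float ss = 0.0f;
        for (size_t c = 0; c < n_cols; ++c) ss += x[c] * x[c];
        const float scale = 1.0f / std::sqrt(ss / n_cols + eps);
        for (size_t c = 0; c < n_cols; ++c) {
            y[c] = do_multiply ? x[c] * scale * w[c] : x[c] * scale;
        }
    }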
diff --git a/ggml/src/vulkan-shaders/roll.comp b/ggml/src/vulkan-shaders/roll.comp
new file mode 100644
index 00000000..b9abe8de
--- /dev/null
+++ b/ggml/src/vulkan-shaders/roll.comp
@@ -0,0 +1,46 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+uint wrap_idx(int i, uint ne) {
+ if (i < 0) {
+ return i + ne;
+ } else if (i >= ne) {
+ return i - ne;
+ }
+ return i;
+}
+
+void main() {
+ const uint idx = get_idx();
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
+ const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
+ const uint i2_offset = i2*p.ne11*p.ne10;
+ const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
+
+ const uint p1 = floatBitsToUint(p.param1);
+ const uint p2 = floatBitsToUint(p.param2);
+ const int s0 = int(p1 >> 16) - 0x8000;
+ const int s1 = int(p1 & 0xFFFF) - 0x8000;
+ const int s2 = int(p2 >> 16) - 0x8000;
+ const int s3 = int(p2 & 0xFFFF) - 0x8000;
+
+ const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
+ const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
+ const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
+ const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
+
+ const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+ const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
+
+ data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
+}
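
The four roll shifts arrive through the generic unary push constants: each pair of 16-bit shifts is packed, biased by 0x8000, into the bit pattern of param1/param2 and decoded with floatBitsToUint above. A packing sketch that mirrors the decode (the actual host code in ggml-vulkan may differ):

    #include <cstdint>
    #include <cstring>

    float pack_two_shifts(int s_hi, int s_lo) {
        const uint32_t u = (uint32_t(s_hi + 0x8000) << 16) | uint32_t(s_lo + 0x8000);
        float f;
        std::memcpy(&f, &u, sizeof(f)); // the shader reverses this with floatBitsToUint
        return f;
    }
    // param1 = pack_two_shifts(s0, s1);  param2 = pack_two_shifts(s2, s3);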
diff --git a/ggml/src/vulkan-shaders/rope_multi.comp b/ggml/src/vulkan-shaders/rope_multi.comp
index 4f5b1a0e..5808710c 100644
--- a/ggml/src/vulkan-shaders/rope_multi.comp
+++ b/ggml/src/vulkan-shaders/rope_multi.comp
@@ -14,21 +14,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x;
- if (i0 >= p.n_dims) {
- const uint i = row_dst*ne0 + i0;
-
- data_d[i + 0] = data_a[i + 0];
- data_d[i + 1] = data_a[i + 1];
-
- return;
- }
-
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
+ if (i0 >= p.n_dims) {
+ data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
+ data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+
+ return;
+ }
+
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;
diff --git a/ggml/src/vulkan-shaders/rope_neox.comp b/ggml/src/vulkan-shaders/rope_neox.comp
index db775c45..366a7b1c 100644
--- a/ggml/src/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/vulkan-shaders/rope_neox.comp
@@ -13,21 +13,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x;
- if (i0 >= p.n_dims) {
- const uint i = row_dst*ne0 + i0;
-
- data_d[i + 0] = data_a[i + 0];
- data_d[i + 1] = data_a[i + 1];
-
- return;
- }
-
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
+ if (i0 >= p.n_dims) {
+ data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
+ data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+
+ return;
+ }
+
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
diff --git a/ggml/src/vulkan-shaders/rope_norm.comp b/ggml/src/vulkan-shaders/rope_norm.comp
index 4ad35e54..9643bca9 100644
--- a/ggml/src/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/vulkan-shaders/rope_norm.comp
@@ -13,21 +13,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x;
- if (i0 >= p.n_dims) {
- const uint i = row_dst*ne0 + i0;
-
- data_d[i + 0] = data_a[i + 0];
- data_d[i + 1] = data_a[i + 1];
-
- return;
- }
-
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
+ if (i0 >= p.n_dims) {
+ data_d[idst + 0] = data_a[ix + 0];
+ data_d[idst + 1] = data_a[ix + 1];
+
+ return;
+ }
+
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp
index 4663428d..f10b0a02 100644
--- a/ggml/src/vulkan-shaders/scale.comp
+++ b/ggml/src/vulkan-shaders/scale.comp
@@ -18,7 +18,7 @@ void main() {
continue;
}
- data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
+ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
idx += num_threads;
}
}
diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp
index 51fc2dc7..5bcd3b1e 100644
--- a/ggml/src/vulkan-shaders/soft_max.comp
+++ b/ggml/src/vulkan-shaders/soft_max.comp
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
{
uint KX;
uint KY;
+ uint ne00;
+ uint ne01;
+ uint ne02;
+ uint ne12;
+ uint ne13;
+ uint nb11;
+ uint nb12;
+ uint nb13;
float scale;
float max_bias;
float m0;
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
void soft_max(uint num_iters) {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
+
+ const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+ const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+ const uint32_t i01 = rowx % p.ne01;
+
+ uint rowy_start = 0;
+ if (p.KY > 0) {
+ rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+ }
if (rowx >= p.nrows_x) {
return;
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
// ALiBi
if (p.max_bias > 0.0f) {
- const uint h = rowx/p.KY; // head index
+ const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
FLOAT_TYPE b = FLOAT_TYPE(0);
if (p.KY > 0 && col < p.KX) {
- b = data_b[rowy * p.KX + col];
+ b = data_b[rowy_start + col];
}
FLOAT_TYPE v = a * p.scale + slope * b;
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
if (idx < DATA_CACHE_SIZE) {
val = exp(data_cache[idx] - max_val);
} else {
- val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+ val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
}
sum += val;
if (idx < DATA_CACHE_SIZE) {
diff --git a/ggml/src/vulkan-shaders/swiglu.comp b/ggml/src/vulkan-shaders/swiglu.comp
new file mode 100644
index 00000000..a28e7c6c
--- /dev/null
+++ b/ggml/src/vulkan-shaders/swiglu.comp
@@ -0,0 +1,9 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+ return a / (1.0f + exp(-a)) * b;
+}
+
+#include "glu_main.comp"
diff --git a/ggml/src/vulkan-shaders/upscale.comp b/ggml/src/vulkan-shaders/upscale.comp
index 6f607380..74771def 100644
--- a/ggml/src/vulkan-shaders/upscale.comp
+++ b/ggml/src/vulkan-shaders/upscale.comp
@@ -3,6 +3,7 @@
layout (push_constant) uniform parameter
{
uint ne; uint a_offset; uint d_offset;
+ uint ne00; uint ne01;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
@@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
+#define NEAREST 0
+#define BILINEAR 1
+#define ALIGN_CORNERS (1 << 8)
+
+layout (constant_id = 0) const uint scale_mode = 0;
+
+float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
+ const uint i00 = uint(i10 / p.sf0);
+ const uint i01 = uint(i11 / p.sf1);
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+
+ return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
+}
+
+float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+ const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
+
+ const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
+ const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
+ const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
+ const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
+
+ return
+ v00 * (1.0-d.x) * (1.0-d.y) +
+ v01 * d.x * (1.0-d.y) +
+ v10 * (1.0-d.x) * d.y +
+ v11 * d.x * d.y;
+}
+
+float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
+ const ivec2 ne0 = ivec2(p.ne00, p.ne01);
+
+ const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
+ const vec2 c0f = floor(c);
+ const vec2 d = c - c0f;
+ const ivec2 c0 = max(ivec2(c0f), 0);
+ const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
+
+ return fetch_bilinear(c0, c1, d, i12, i13);
+}
+
+float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
+ const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
+ const vec2 c0f = floor(c);
+ const vec2 d = c - c0f;
+ const ivec2 c0 = ivec2(c0f);
+ const ivec2 c1 = c0 + 1;
+
+ return fetch_bilinear(c0, c1, d, i12, i13);
+}
+
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
@@ -27,10 +83,18 @@ void main() {
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
- const uint i00 = uint(i10 / p.sf0);
- const uint i01 = uint(i11 / p.sf1);
- const uint i02 = uint(i12 / p.sf2);
- const uint i03 = uint(i13 / p.sf3);
+ float result;
+ switch (scale_mode) {
+ case NEAREST:
+ result = fetch_nearest(i10, i11, i12, i13);
+ break;
+ case BILINEAR:
+ result = interpolate_bilinear(i10, i11, i12, i13);
+ break;
+ case BILINEAR | ALIGN_CORNERS:
+ result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
+ break;
+ }
- data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
+ data_d[p.d_offset + idx] = D_TYPE(result);
}
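
The upscale shader now selects the mode through a specialization constant: nearest keeps the old integer mapping, bilinear uses half-pixel centers, and the ALIGN_CORNERS variant maps destination indices directly through the scale factors. A scalar reference of the coordinate mapping and the 2x2 blend:

    #include <cstdint>

    // Destination-to-source coordinate for one axis (sf = scale factor along that axis).
    float src_coord(uint32_t i, float sf, bool align_corners) {
        return align_corners ? i / sf                  // corners map exactly onto corners
                             : (i + 0.5f) / sf - 0.5f; // half-pixel-center convention
    }

    // 2x2 bilinear blend; v01 is the sample advanced in x, v10 advanced in y.
    float bilerp(float v00, float v01, float v10, float v11, float dx, float dy) {
        return v00 * (1.0f - dx) * (1.0f - dy) + v01 * dx * (1.0f - dy)
             + v10 * (1.0f - dx) * dy          + v11 * dx * dy;
    }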
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
index 65dd82de..293aa644 100644
--- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
- if ((tname == "q4_0") || (tname == "q4_1"))
+ if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
- else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+ else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
load_vec_quant = "4";
if (tname == "bf16") {
@@ -497,9 +497,9 @@ void process_shaders() {
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@@ -518,6 +518,11 @@ void process_shaders() {
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
+ for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
+ string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+ }
+
auto get_type_str = [](bool f16) {
return f16 ? "float16_t" : "float";
};
@@ -572,17 +577,10 @@ void process_shaders() {
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-
- string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -594,6 +592,17 @@ void process_shaders() {
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -637,8 +646,30 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+ string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+ string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+ string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+
+ string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+ // ============================== ik_llama.cpp
+ //
+ string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+ string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ //
+ // ============================== end ik_llama.cpp
+
for (auto &c : compiles) {
c.wait();
}
diff --git a/ggml/src/vulkan-shaders/wkv7.comp b/ggml/src/vulkan-shaders/wkv7.comp
new file mode 100644
index 00000000..88c1c02b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/wkv7.comp
@@ -0,0 +1,91 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+ uint B;
+ uint T;
+ uint C;
+ uint H;
+};
+
+layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
+layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
+layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
+layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
+
+void main() {
+ const uint head_size = BLOCK_SIZE;
+ const uint batch_id = gl_WorkGroupID.x / H;
+ const uint head_id = gl_WorkGroupID.x % H;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint state_size = C * head_size;
+ const uint n_seq_tokens = T / B;
+
+ if (batch_id >= B || head_id >= H) {
+ return;
+ }
+
+ A_TYPE state[BLOCK_SIZE];
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ + tid * head_size + i];
+ }
+
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+ for (uint t = start_t; t < end_t; t += C) {
+ barrier();
+ _r[tid] = r[t];
+ _w[tid] = w[t];
+ _k[tid] = k[t];
+ _a[tid] = a[t];
+ _b[tid] = b[t];
+ barrier();
+
+ A_TYPE sa = 0.0;
+ [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+ vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+ sa += dot(s_vec, a_vec);
+ }
+
+ const A_TYPE v_val = v[t];
+ A_TYPE y = 0.0;
+
+ [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+ vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+ vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ vec4 kv = k_vec * v_val;
+ s_vec = s_vec * w_vec + kv + sa * b_vec;
+ y += dot(r_vec, s_vec);
+
+ state[j] = s_vec.x;
+ state[j+1] = s_vec.y;
+ state[j+2] = s_vec.z;
+ state[j+3] = s_vec.w;
+ }
+
+ dst[t] = y;
+ }
+
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ + tid * head_size + i] = state[i];
+ }
+}
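
The new wkv7 shader evaluates the RWKV7 recurrence with one workgroup per (batch, head) and one thread per output channel: each state row is decayed by w, receives k[j]*v[i] plus sa*b[j] where sa is the dot product of the state row with a, and the output is the dot product with r. A serial CPU reference of one time step for a single head (illustrative, head_size = N):

    #include <vector>

    // state is an N x N matrix indexed [i][j]; r, w, k, a, b, v are the per-token
    // vectors for this head. Mirrors the per-thread loop above, where tid plays i.
    void wkv7_step_ref(std::vector<std::vector<float>>& state,
                       const std::vector<float>& r, const std::vector<float>& w,
                       const std::vector<float>& k, const std::vector<float>& v,
                       const std::vector<float>& a, const std::vector<float>& b,
                       std::vector<float>& y, size_t N) {
        for (size_t i = 0; i < N; ++i) {
            float sa = 0.0f;
            for (size_t j = 0; j < N; ++j) sa += state[i][j] * a[j];

            float yi = 0.0f;
            for (size_t j = 0; j < N; ++j) {
                state[i][j] = state[i][j] * w[j] + k[j] * v[i] + sa * b[j];
                yi += r[j] * state[i][j];
            }
            y[i] = yi;
        }
    }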