summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders')
-rw-r--r--ggml/src/vulkan-shaders/CMakeLists.txt24
-rw-r--r--ggml/src/vulkan-shaders/acc.comp29
-rw-r--r--ggml/src/vulkan-shaders/add.comp25
-rw-r--r--ggml/src/vulkan-shaders/argmax.comp51
-rw-r--r--ggml/src/vulkan-shaders/argsort.comp10
-rw-r--r--ggml/src/vulkan-shaders/clamp.comp6
-rw-r--r--ggml/src/vulkan-shaders/concat.comp10
-rw-r--r--ggml/src/vulkan-shaders/contig_copy.comp49
-rw-r--r--ggml/src/vulkan-shaders/conv_transpose_1d.comp98
-rw-r--r--ggml/src/vulkan-shaders/copy.comp11
-rw-r--r--ggml/src/vulkan-shaders/copy_from_quant.comp51
-rw-r--r--ggml/src/vulkan-shaders/copy_to_quant.comp242
-rw-r--r--ggml/src/vulkan-shaders/cos.comp17
-rw-r--r--ggml/src/vulkan-shaders/count_equal.comp31
-rw-r--r--ggml/src/vulkan-shaders/dequant_funcs.comp422
-rw-r--r--ggml/src/vulkan-shaders/dequant_funcs_cm2.comp699
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq1_m.comp42
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq1_s.comp35
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq2_s.comp44
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq2_xs.comp43
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq2_xxs.comp48
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq3_s.comp39
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq3_xxs.comp49
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq4_nl.comp2
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq4_xs.comp34
-rw-r--r--ggml/src/vulkan-shaders/dequant_q4_k.comp64
-rw-r--r--ggml/src/vulkan-shaders/dequant_q5_k.comp68
-rw-r--r--ggml/src/vulkan-shaders/diag_mask_inf.comp2
-rw-r--r--ggml/src/vulkan-shaders/div.comp23
-rw-r--r--ggml/src/vulkan-shaders/flash_attn.comp337
-rw-r--r--ggml/src/vulkan-shaders/flash_attn_base.comp162
-rw-r--r--ggml/src/vulkan-shaders/flash_attn_cm1.comp360
-rw-r--r--ggml/src/vulkan-shaders/flash_attn_cm2.comp267
-rw-r--r--ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp59
-rw-r--r--ggml/src/vulkan-shaders/generic_binary_head.comp62
-rw-r--r--ggml/src/vulkan-shaders/generic_unary_head.comp55
-rw-r--r--ggml/src/vulkan-shaders/get_rows.comp17
-rw-r--r--ggml/src/vulkan-shaders/get_rows_quant.comp10
-rw-r--r--ggml/src/vulkan-shaders/group_norm.comp2
-rw-r--r--ggml/src/vulkan-shaders/im2col.comp87
-rw-r--r--ggml/src/vulkan-shaders/mul.comp23
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp31
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec.comp176
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_base.comp51
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq1_m.comp82
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq1_s.comp79
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq2_s.comp90
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq2_xs.comp87
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq2_xxs.comp87
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq3_s.comp90
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_iq3_xxs.comp88
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_nc.comp69
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_p021.comp125
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp157
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp148
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp203
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp230
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp153
-rw-r--r--ggml/src/vulkan-shaders/mul_mm.comp601
-rw-r--r--ggml/src/vulkan-shaders/mul_mm_cm2.comp441
-rw-r--r--ggml/src/vulkan-shaders/mul_mmq.comp442
-rw-r--r--ggml/src/vulkan-shaders/mul_mmq_funcs.comp99
-rw-r--r--ggml/src/vulkan-shaders/opt_step_adamw.comp42
-rw-r--r--ggml/src/vulkan-shaders/pad.comp4
-rw-r--r--ggml/src/vulkan-shaders/pool2d.comp74
-rw-r--r--ggml/src/vulkan-shaders/quantize_q8_1.comp77
-rw-r--r--ggml/src/vulkan-shaders/relu.comp2
-rw-r--r--ggml/src/vulkan-shaders/repeat.comp26
-rw-r--r--ggml/src/vulkan-shaders/repeat_back.comp37
-rw-r--r--ggml/src/vulkan-shaders/rms_norm.comp32
-rw-r--r--ggml/src/vulkan-shaders/rms_norm_back.comp55
-rw-r--r--ggml/src/vulkan-shaders/rope_head.comp14
-rw-r--r--ggml/src/vulkan-shaders/rope_multi.comp60
-rw-r--r--ggml/src/vulkan-shaders/rope_neox.comp34
-rw-r--r--ggml/src/vulkan-shaders/rope_norm.comp34
-rw-r--r--ggml/src/vulkan-shaders/rope_vision.comp47
-rw-r--r--ggml/src/vulkan-shaders/scale.comp20
-rw-r--r--ggml/src/vulkan-shaders/sigmoid.comp20
-rw-r--r--ggml/src/vulkan-shaders/silu_back.comp26
-rw-r--r--ggml/src/vulkan-shaders/sin.comp17
-rw-r--r--ggml/src/vulkan-shaders/soft_max.comp115
-rw-r--r--ggml/src/vulkan-shaders/soft_max_back.comp50
-rw-r--r--ggml/src/vulkan-shaders/square.comp6
-rw-r--r--ggml/src/vulkan-shaders/sub.comp29
-rw-r--r--ggml/src/vulkan-shaders/tanh.comp3
-rw-r--r--ggml/src/vulkan-shaders/test_bfloat16_support.comp7
-rw-r--r--ggml/src/vulkan-shaders/test_coopmat2_support.comp7
-rw-r--r--ggml/src/vulkan-shaders/test_coopmat_support.comp7
-rw-r--r--ggml/src/vulkan-shaders/test_integer_dot_support.comp7
-rw-r--r--ggml/src/vulkan-shaders/types.comp1285
-rw-r--r--ggml/src/vulkan-shaders/upscale.comp4
-rw-r--r--ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp588
-rw-r--r--ggml/src/vulkan-shaders/wkv6.comp87
93 files changed, 9035 insertions, 1019 deletions
diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/vulkan-shaders/CMakeLists.txt
index 10075db3..a22ea817 100644
--- a/ggml/src/vulkan-shaders/CMakeLists.txt
+++ b/ggml/src/vulkan-shaders/CMakeLists.txt
@@ -1,5 +1,29 @@
+cmake_minimum_required(VERSION 3.19)
+project("vulkan-shaders-gen" C CXX)
+
find_package (Threads REQUIRED)
+if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+ add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+ message(STATUS "Enabling coopmat glslc support")
+endif()
+if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ message(STATUS "Enabling coopmat2 glslc support")
+endif()
+if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ message(STATUS "Enabling dot glslc support")
+endif()
+if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+ add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+ message(STATUS "Enabling bfloat16 glslc support")
+endif()
+if (GGML_VULKAN_SHADER_DEBUG_INFO)
+ add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
+ message(STATUS "Enabling shader debug info")
+endif()
+
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
diff --git a/ggml/src/vulkan-shaders/acc.comp b/ggml/src/vulkan-shaders/acc.comp
new file mode 100644
index 00000000..d896f1ef
--- /dev/null
+++ b/ggml/src/vulkan-shaders/acc.comp
@@ -0,0 +1,29 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.x;
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint offset = p.param3;
+ const uint src1_i = idx - offset;
+ const uint oz = src1_i / p.nb02;
+ const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
+ const uint ox = src1_i % p.nb01;
+
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+ } else {
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
+ }
+}
+
diff --git a/ggml/src/vulkan-shaders/add.comp b/ggml/src/vulkan-shaders/add.comp
index 3974845d..2b4085c4 100644
--- a/ggml/src/vulkan-shaders/add.comp
+++ b/ggml/src/vulkan-shaders/add.comp
@@ -1,14 +1,29 @@
#version 450
+#extension GL_EXT_shader_16bit_storage : require
+
#include "types.comp"
#include "generic_binary_head.comp"
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
}
diff --git a/ggml/src/vulkan-shaders/argmax.comp b/ggml/src/vulkan-shaders/argmax.comp
new file mode 100644
index 00000000..eaf4da34
--- /dev/null
+++ b/ggml/src/vulkan-shaders/argmax.comp
@@ -0,0 +1,51 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
+shared uint tmp[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint col = gl_LocalInvocationID.x;
+
+ if (col >= p.KX) {
+ return;
+ }
+ A_TYPE amax = data_a[row*p.KX + col];
+ tmp[col] = col;
+
+ for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
+ A_TYPE val = data_a[row*p.KX + i];
+ if (val > amax) {
+ amax = val;
+ tmp[col] = i;
+ }
+ }
+ tmpmax[col] = amax;
+
+ barrier();
+ [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
+ if (col < s && col + s < p.KX) {
+ if (tmpmax[col] < tmpmax[col + s]) {
+ tmpmax[col] = tmpmax[col + s];
+ tmp[col] = tmp[col + s];
+ }
+ }
+ barrier();
+ }
+
+ if (col == 0) {
+ data_d[row] = D_TYPE(tmp[0]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/vulkan-shaders/argsort.comp
index e55414b0..d4fa45b1 100644
--- a/ggml/src/vulkan-shaders/argsort.comp
+++ b/ggml/src/vulkan-shaders/argsort.comp
@@ -29,20 +29,18 @@ void main() {
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
- if (col >= p.ncols_pad) {
- return;
- }
-
const uint row_offset = row * p.ncols;
// initialize indices
- dst_row[col] = col;
+ if (col < p.ncols_pad) {
+ dst_row[col] = col;
+ }
barrier();
for (uint k = 2; k <= p.ncols_pad; k *= 2) {
for (uint j = k / 2; j > 0; j /= 2) {
const uint ixj = col ^ j;
- if (ixj > col) {
+ if (col < p.ncols_pad && ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= p.ncols ||
(dst_row[ixj] < p.ncols && (p.order == ASC ?
diff --git a/ggml/src/vulkan-shaders/clamp.comp b/ggml/src/vulkan-shaders/clamp.comp
index 7071302a..1e5cb8da 100644
--- a/ggml/src/vulkan-shaders/clamp.comp
+++ b/ggml/src/vulkan-shaders/clamp.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_unary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = get_idx();
@@ -10,6 +12,6 @@ void main() {
return;
}
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
diff --git a/ggml/src/vulkan-shaders/concat.comp b/ggml/src/vulkan-shaders/concat.comp
index 08ab5514..9ee2f1fa 100644
--- a/ggml/src/vulkan-shaders/concat.comp
+++ b/ggml/src/vulkan-shaders/concat.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_binary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const int dim = p.param3;
@@ -28,8 +30,12 @@ void main() {
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
+ data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
#else
- data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+ if (is_src0) {
+ data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
+ } else {
+ data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
+ }
#endif
}
diff --git a/ggml/src/vulkan-shaders/contig_copy.comp b/ggml/src/vulkan-shaders/contig_copy.comp
new file mode 100644
index 00000000..6567a8c5
--- /dev/null
+++ b/ggml/src/vulkan-shaders/contig_copy.comp
@@ -0,0 +1,49 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+#extension GL_EXT_control_flow_attributes : require
+
+const uint num_threads = 128;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ uint idx = get_idx();
+
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 4;
+
+ // fast path for when all four iterations are in-bounds
+ if (idx + (num_iter-1)*num_threads < p.ne) {
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+
+#if defined(DATA_D_BF16)
+ float f = float(data_a[get_aoffset() + idx]);
+ data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
+ data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
+#else
+ data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
+#endif
+ idx += num_threads;
+ }
+ } else {
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+
+#if defined(DATA_D_BF16)
+ float f = float(data_a[get_aoffset() + idx]);
+ data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
+ data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
+#else
+ data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
+#endif
+ idx += num_threads;
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/vulkan-shaders/conv_transpose_1d.comp
new file mode 100644
index 00000000..b17b4e83
--- /dev/null
+++ b/ggml/src/vulkan-shaders/conv_transpose_1d.comp
@@ -0,0 +1,98 @@
+#version 450
+
+#include "types.comp"
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
+
+layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
+
+layout (push_constant) uniform parameter {
+ uint32_t Cout;
+ uint32_t Cin;
+ uint32_t K;
+ uint32_t L;
+ uint32_t KL;
+
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb11;
+ uint32_t nb1;
+
+ int32_t s0;
+} p;
+
+
+uint32_t Cout_idx = gl_WorkGroupID.x;
+const uint32_t bs = gl_WorkGroupSize.x;
+uint32_t tid = gl_LocalInvocationID.x;
+// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
+uint32_t tmp_len = bs*p.s0+p.K;
+shared D_TYPE tmp[4096];
+
+uint splitWork(uint workSize){
+ return (bs + workSize -1) / bs;
+}
+
+void main(){
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
+ uint32_t idx = i*bs+tid;
+ if(idx < tmp_len){
+ tmp[idx] = 0.0;
+ }
+ }
+
+ uint32_t L_blocks = splitWork(p.L);
+ for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
+ if(L_block_id > 0){
+ barrier();
+ // Shift values in tmp to the current processing window
+ for(int i = 0; i < splitWork(tmp_len); i++){
+ uint32_t idx = i*bs+tid;
+ if(idx >= bs*p.s0 && idx < tmp_len){
+ tmp[idx-bs*p.s0] = tmp[idx];
+ tmp[idx] = 0.0;
+ }else if(idx >= p.K && idx < bs*p.s0){
+ tmp[idx] = 0.0;
+ }
+ }
+ }
+ barrier();
+
+ // Save contributions of the block to tmp
+ uint32_t L_idx = L_block_id*bs + tid;
+ for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
+ D_TYPE dp = 0.0;
+ for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
+ A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
+ if(L_idx < p.L){
+ B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
+ dp = fma(elemKrn, elemInp, dp);
+ }
+ }
+ tmp[tid*p.s0 + K_idx] += dp;
+ barrier();
+ }
+
+ // Save the computed values except the last block that can have different size
+ uint32_t KLb_idx = L_block_id*bs*p.s0;
+ if(L_block_id < L_blocks-1){
+ for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
+ uint32_t sh_idx = p.s0*tid+s0_idx;
+ uint32_t KL_idx = KLb_idx+sh_idx;
+ if(KL_idx < p.KL){
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
+ }
+ }
+ }
+ }
+
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
+ uint32_t idx = i*bs+tid;
+ uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
+ if(KL_idx < p.KL){
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/copy.comp b/ggml/src/vulkan-shaders/copy.comp
index c26917c0..f476a2e3 100644
--- a/ggml/src/vulkan-shaders/copy.comp
+++ b/ggml/src/vulkan-shaders/copy.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_unary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = get_idx();
@@ -10,9 +12,12 @@ void main() {
return;
}
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
+#if defined(DATA_D_BF16)
+ float f = float(data_a[get_aoffset() + src0_idx(idx)]);
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
- data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
+ data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
#endif
}
diff --git a/ggml/src/vulkan-shaders/copy_from_quant.comp b/ggml/src/vulkan-shaders/copy_from_quant.comp
new file mode 100644
index 00000000..dbc7daa3
--- /dev/null
+++ b/ggml/src/vulkan-shaders/copy_from_quant.comp
@@ -0,0 +1,51 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+#include "dequant_funcs.comp"
+
+#if defined(DATA_A_IQ4_NL)
+// 16 invocations needed for init_iq4nl_shmem
+layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#else
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+#endif
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+ if (gl_LocalInvocationIndex.x != 0) {
+ return;
+ }
+#endif
+
+ const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ uint dst_idx = get_doffset() + dst_idx(idx);
+ uint src_idx = src0_idx_quant(idx, QUANT_K);
+
+ const uint a_offset = 0;
+ const uint ib = src_idx;
+ const vec2 dm = get_dm(ib, a_offset);
+
+ [[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
+ vec4 v = dequantize4(ib, j / QUANT_R, a_offset);
+ v = v * dm.x + vec4(dm.y);
+
+#if QUANT_R == 2
+ data_d[dst_idx + j/2 + 0] = v[0];
+ data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];
+ data_d[dst_idx + j/2 + 1] = v[2];
+ data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];
+#else
+ data_d[dst_idx + j + 0] = v[0];
+ data_d[dst_idx + j + 1] = v[1];
+ data_d[dst_idx + j + 2] = v[2];
+ data_d[dst_idx + j + 3] = v[3];
+#endif
+ }
+}
diff --git a/ggml/src/vulkan-shaders/copy_to_quant.comp b/ggml/src/vulkan-shaders/copy_to_quant.comp
new file mode 100644
index 00000000..9c76437d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/copy_to_quant.comp
@@ -0,0 +1,242 @@
+#version 450
+
+#if RTE16
+#extension GL_EXT_spirv_intrinsics : enable
+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
+#endif // RTE16
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+#if defined(DATA_A_IQ4_NL)
+// 16 invocations needed for init_iq4nl_shmem
+layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#else
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+#endif
+
+layout (binding = 0) readonly buffer S {float data_s[];};
+layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
+
+#if defined(DATA_A_Q4_0)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0;
+ float vmax = 0.0;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
+ const float v = data_s[src_idx + j];
+ if (amax < abs(v)) {
+ amax = abs(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -8;
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
+ const float x0 = data_s[src_idx + 0 + j]*id;
+ const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
+
+ const uint xi0 = min(15, int(x0 + 8.5));
+ const uint xi1 = min(15, int(x1 + 8.5));
+
+ data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
+ }
+}
+#endif
+
+#if defined(DATA_A_Q4_1)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float vmin = 1.0/0.0;
+ float vmax = -vmin;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
+ const float v = data_s[src_idx + j];
+
+ if (v < vmin) vmin = v;
+ if (v > vmax) vmax = v;
+ }
+
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+ data_q[dst_idx].m = float16_t(vmin);
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
+ const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
+ const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
+
+ const uint xi0 = min(15, int(x0 + 0.5));
+ const uint xi1 = min(15, int(x1 + 0.5));
+
+ data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
+ }
+}
+#endif
+
+#if defined(DATA_A_Q5_0)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0;
+ float vmax = 0.0;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
+ const float v = data_s[src_idx + j];
+ if (amax < abs(v)) {
+ amax = abs(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -16;
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+
+ uint32_t qh = 0;
+ [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
+ const float x0 = data_s[src_idx + 0 + j]*id;
+ const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
+
+ const uint xi0 = min(31, int(x0 + 16.5));
+ const uint xi1 = min(31, int(x1 + 16.5));
+
+ data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
+ }
+ data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
+ data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
+}
+#endif
+
+#if defined(DATA_A_Q5_1)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float min = data_s[src_idx + 0];
+ float max = min;
+
+ [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
+ const float v = data_s[src_idx + j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = (d != 0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+ data_q[dst_idx].m = float16_t(min);
+
+ uint32_t qh = 0;
+ [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
+ const float x0 = (data_s[src_idx + 0 + j] - min)*id;
+ const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
+
+ const uint xi0 = uint(x0 + 0.5);
+ const uint xi1 = uint(x1 + 0.5);
+
+ data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
+ }
+ data_q[dst_idx].qh = qh;
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0; // absolute max
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
+ const float v = data_s[src_idx + j];
+ amax = max(amax, abs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
+ const float x0 = data_s[src_idx + j]*id;
+
+ data_q[dst_idx].qs[j] = int8_t(round(x0));
+ }
+}
+#endif
+
+#if defined(DATA_A_IQ4_NL)
+uint best_index(float x) {
+ if (x <= kvalues_iq4nl[0]) return 0;
+ if (x >= kvalues_iq4nl[15]) return 15;
+ int ml = 0, mu = 15;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
+ }
+ return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
+}
+
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0;
+ float vmax = 0.0;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
+ const float v = data_s[src_idx + j];
+ if (amax < abs(v)) {
+ amax = abs(v);
+ vmax = v;
+ }
+ }
+
+ float d = vmax / kvalues_iq4nl[0];
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ float sumqx = 0, sumq2 = 0;
+ [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
+ const float x0 = data_s[src_idx + 0 + j]*id;
+ const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
+ const uint xi0 = best_index(x0);
+ const uint xi1 = best_index(x1);
+ data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
+ const float v0 = kvalues_iq4nl[xi0];
+ const float v1 = kvalues_iq4nl[xi1];
+ const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
+ const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
+ sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+ }
+
+ data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
+
+}
+#endif
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+ if (gl_LocalInvocationIndex.x != 0) {
+ return;
+ }
+#endif
+
+ const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ uint dst_idx = dst_idx_quant(idx, QUANT_K);
+ uint src_idx = get_aoffset() + src0_idx(idx);
+
+ quantize(dst_idx, src_idx);
+}
diff --git a/ggml/src/vulkan-shaders/cos.comp b/ggml/src/vulkan-shaders/cos.comp
new file mode 100644
index 00000000..0b8d02f5
--- /dev/null
+++ b/ggml/src/vulkan-shaders/cos.comp
@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
+}
diff --git a/ggml/src/vulkan-shaders/count_equal.comp b/ggml/src/vulkan-shaders/count_equal.comp
new file mode 100644
index 00000000..d9345497
--- /dev/null
+++ b/ggml/src/vulkan-shaders/count_equal.comp
@@ -0,0 +1,31 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.comp"
+#include "generic_head.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+const uint CHUNK_SIZE = 512;
+
+void main() {
+ const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
+ const uint col = gl_LocalInvocationID.x;
+
+ uint count = 0;
+ [[unroll]]
+ for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
+ const uint idx = base + i + col;
+ if (idx >= p.KX) {
+ break;
+ }
+ count += uint(data_a[idx] == data_b[idx]);
+ }
+
+ atomicAdd(data_d[0], D_TYPE(count));
+}
diff --git a/ggml/src/vulkan-shaders/dequant_funcs.comp b/ggml/src/vulkan-shaders/dequant_funcs.comp
index d5b98973..0d9739d4 100644
--- a/ggml/src/vulkan-shaders/dequant_funcs.comp
+++ b/ggml/src/vulkan-shaders/dequant_funcs.comp
@@ -2,6 +2,15 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
+#include "types.comp"
+
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -14,55 +23,440 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
}
#endif
+#if defined(DATA_A_BF16)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
+}
+#endif
+
#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+ return (vec2(vui & 0xF, vui >> 4) - 8.0f);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+ return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
}
#endif
#if defined(DATA_A_Q4_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const float m = float(data_a[a_offset + ib].m);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(vui & 0xF, vui >> 4) * d + m;
+ return vec2(vui & 0xF, vui >> 4);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+ return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
}
#endif
#if defined(DATA_A_Q5_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+ return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
+ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+ return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f);
}
#endif
#if defined(DATA_A_Q5_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const float m = float(data_a[a_offset + ib].m);
const uint uint_qh = data_a[a_offset + ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+ return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint uint_qh = data_a_packed16[a_offset + ib].qh;
+ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+ return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y);
}
#endif
#if defined(DATA_A_Q8_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
+ return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
+ const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
+ return vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#if defined(DATA_A_IQ1_S)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint ib8 = iqs / 8;
+ const int i8 = int(iqs % 8);
+ const uint qh = data_a[a_offset + ib].qh[ib32];
+ const uint qs = data_a[a_offset + ib].qs[ib8];
+ const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+ const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
+ const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
+ // Signed bitfield extract.
+ const ivec2 gvec = ivec2(
+ bitfieldExtract(grid, 2 * (i8), 2),
+ bitfieldExtract(grid, 2 * (i8 + 1), 2)
+ );
+ return dl * (vec2(gvec) + delta);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint ib8 = iqs / 8;
+ const int i8 = int(iqs % 8);
+ const uint qh = data_a[a_offset + ib].qh[ib32];
+ const uint qs = data_a[a_offset + ib].qs[ib8];
+ const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+ const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
+ // Signed bitfield extract.
+ const ivec4 gvec = ivec4(
+ bitfieldExtract(grid, 2 * (i8), 2),
+ bitfieldExtract(grid, 2 * (i8 + 1), 2),
+ bitfieldExtract(grid, 2 * (i8 + 2), 2),
+ bitfieldExtract(grid, 2 * (i8 + 3), 2)
+ );
+ return dl * (vec4(gvec) + delta);
+}
+#endif
+
+#if defined(DATA_A_IQ1_M)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint ib8 = iqs / 8;
+ const uint ib16 = iqs / 16;
+ const int i8 = int(iqs % 8);
+ const uint sc = data_a[a_offset + ib].scales[iqs / 64];
+ const uint qs = data_a[a_offset + ib].qs[ib8];
+ const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
+ const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
+ // Signed bitfield extract.
+ const ivec2 gvec = ivec2(
+ bitfieldExtract(grid, 2 * (i8), 2),
+ bitfieldExtract(grid, 2 * (i8 + 1), 2)
+ );
+ return dl * (vec2(gvec) + delta);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib8 = iqs / 8;
+ const uint ib16 = iqs / 16;
+ const int i8 = int(iqs % 8);
+ const uint sc = data_a[a_offset + ib].scales[iqs / 64];
+ const uint qs = data_a[a_offset + ib].qs[ib8];
+ const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
+ const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
+ // Signed bitfield extract.
+ const ivec4 gvec = ivec4(
+ bitfieldExtract(grid, 2 * (i8), 2),
+ bitfieldExtract(grid, 2 * (i8 + 1), 2),
+ bitfieldExtract(grid, 2 * (i8 + 2), 2),
+ bitfieldExtract(grid, 2 * (i8 + 3), 2)
+ );
+ return dl * (vec4(gvec) + delta);
+}
+#endif
+
+#if defined(DATA_A_IQ2_XXS)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint ib8 = (iqs / 8) % 4;
+ const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
+ // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
+ const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
+ data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
+ const float db = 0.25 * (0.5 + (signs >> 28));
+ const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
+ // Add parity bit
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const uint sign = sign8 >> (iqs % 8);
+ const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ return db * vec2(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0)
+ );
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint ib8 = (iqs / 8) % 4;
+ const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
+ // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
+ const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
+ data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
+ const float db = 0.25 * (0.5 + (signs >> 28));
+ const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
+ // Add parity bit
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const uint sign = sign8 >> (iqs % 8);
+ const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ bool sign2 = (sign & 4) != 0;
+ bool sign3 = (sign & 8) != 0;
+ return db * vec4(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0),
+ grid.z * (sign2 ? -1.0 : 1.0),
+ grid.w * (sign3 ? -1.0 : 1.0)
+ );
+}
+#endif
+
+#if defined(DATA_A_IQ2_XS)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
+ const uint qs = data_a[a_offset + ib].qs[iqs / 8];
+ const float db = 0.25 * (0.5 + scale);
+ const uint sign7 = qs >> 9;
+ // Add parity bit
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const uint sign = sign8 >> (iqs % 8);
+ const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ return db * vec2(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0)
+ );
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
+ const uint qs = data_a[a_offset + ib].qs[iqs / 8];
+ const float db = 0.25 * (0.5 + scale);
+ const uint sign7 = qs >> 9;
+ // Add parity bit
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const uint sign = sign8 >> (iqs % 8);
+ const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ bool sign2 = (sign & 4) != 0;
+ bool sign3 = (sign & 8) != 0;
+ return db * vec4(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0),
+ grid.z * (sign2 ? -1.0 : 1.0),
+ grid.w * (sign3 ? -1.0 : 1.0)
+ );
+}
+#endif
+
+#if defined(DATA_A_IQ2_S)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint ib8 = iqs / 8;
+
+ const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
+ const uint qs = data_a[a_offset + ib].qs[ib8];
+ const uint qh = data_a[a_offset + ib].qh[ib32];
+ const uint qhshift = 2 * (ib8 % 4);
+ const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
+
+ const float db = 0.25 * (0.5 + scale);
+ const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ return db * vec2(
+ grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
+ grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
+ );
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint ib8 = iqs / 8;
+
+ const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
+ const uint qs = data_a[a_offset + ib].qs[ib8];
+ const uint qh = data_a[a_offset + ib].qh[ib32];
+ const uint qhshift = 2 * (ib8 % 4);
+ const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
+
+ const float db = 0.25 * (0.5 + scale);
+ const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ bool sign2 = (sign & 4) != 0;
+ bool sign3 = (sign & 8) != 0;
+ return db * vec4(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0),
+ grid.z * (sign2 ? -1.0 : 1.0),
+ grid.w * (sign3 ? -1.0 : 1.0)
+ );
+}
+#endif
+
+#if defined(DATA_A_IQ3_XXS)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint ib4 = iqs / 4;
+ const uint ib32 = iqs / 32;
+ const uint is = QUANT_K / 4 + 4 * ib32;
+ const uint qs = data_a[a_offset + ib].qs[ib4];
+ // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
+ const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
+ data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
+ const float db = 0.5 * (0.5 + (signs >> 28));
+ const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
+ // Add parity bit
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const uint sign = sign8 >> (iqs % 8);
+ const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ return db * vec2(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0)
+ );
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib4 = iqs / 4;
+ const uint ib32 = iqs / 32;
+ const uint is = QUANT_K / 4 + 4 * ib32;
+ const uint qs = data_a[a_offset + ib].qs[ib4];
+ const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
+ data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
+ const float db = 0.5 * (0.5 + (signs >> 28));
+ const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
+ // Add parity bit
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const uint sign = sign8 >> (iqs % 8);
+ const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ bool sign2 = (sign & 4) != 0;
+ bool sign3 = (sign & 8) != 0;
+ return db * vec4(
+ grid.x * (sign0 ? -1.0 : 1.0),
+ grid.y * (sign1 ? -1.0 : 1.0),
+ grid.z * (sign2 ? -1.0 : 1.0),
+ grid.w * (sign3 ? -1.0 : 1.0)
+ );
+}
+#endif
+
+#if defined(DATA_A_IQ3_S)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint qs = data_a[a_offset + ib].qs[iqs / 4];
+ const uint qh = data_a[a_offset + ib].qh[iqs / 32];
+ const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
+ const uint scale = data_a[a_offset + ib].scales[iqs / 64];
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
+ return db * vec2(
+ int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
+ int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
+ );
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib4 = iqs / 4;
+ const uint ib32 = iqs / 32;
+ const uint qs = data_a[a_offset + ib].qs[ib4];
+ const uint qh = data_a[a_offset + ib].qh[ib32];
+ const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
+ const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
+ bool sign0 = (sign & 1) != 0;
+ bool sign1 = (sign & 2) != 0;
+ bool sign2 = (sign & 4) != 0;
+ bool sign3 = (sign & 8) != 0;
+ const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
+ return db * vec4(
+ int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
+ int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
+ int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
+ int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
+ );
+}
+#endif
+
+#if defined(DATA_A_IQ4_XS)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint iq = 16 * ib32 + (iqs % 16);
+
+ const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+ const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
+ const uint qshift = (iqs & 16) >> 2;
+ u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
+ qs = (qs >> qshift) & uint8_t(0xF);
+
+ const float dl = float(int(sl | (sh << 4)) - 32);
+ return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint ib32 = iqs / 32;
+ const uint iq = 16 * ib32 + (iqs % 16);
+
+ const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+ const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
+ const uint qshift = (iqs & 16) >> 2;
+ u8vec4 qs = u8vec4(
+ data_a[a_offset + ib].qs[iq + 0],
+ data_a[a_offset + ib].qs[iq + 1],
+ data_a[a_offset + ib].qs[iq + 2],
+ data_a[a_offset + ib].qs[iq + 3]
+ );
+ qs = (qs >> qshift) & uint8_t(0xF);
+
+ const float dl = float(int(sl | (sh << 4)) - 32);
+ return dl * vec4(
+ kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
+ kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
}
#endif
#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
+ return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+ return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
+}
+#endif
+
+#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
+vec2 get_dm(uint ib, uint a_offset) {
+ return vec2(0, 0);
+}
+#endif
+
+#if defined(DATA_A_IQ1_M)
+vec2 get_dm(uint ib, uint a_offset) {
+ const uint16_t[4] scales = data_a[a_offset + ib].scales;
+ const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
+ const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
+ return vec2(d, 0);
+}
+#endif
+
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+vec2 get_dm(uint ib, uint a_offset) {
+ return vec2(float(data_a[a_offset + ib].d), 0);
+}
+#endif
+
+#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
+vec2 get_dm(uint ib, uint a_offset) {
+ return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
diff --git a/ggml/src/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/vulkan-shaders/dequant_funcs_cm2.comp
new file mode 100644
index 00000000..9cb7da2d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_funcs_cm2.comp
@@ -0,0 +1,699 @@
+
+#include "types.comp"
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
+ block_q4_0_packed16 block;
+};
+
+float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
+ qs >>= shift;
+ qs &= 0x0F0F;
+ qs = unpack8(qs)[idx & 1];
+ float16_t ret = (float16_t(qs) - float16_t(8)) * d;
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
+ block_q4_1 block;
+};
+
+float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const float16_t m = bl.block.m;
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx & 0xF;
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = bl.block.qs[iqs];
+ qs >>= shift;
+ qs &= 0xF;
+ float16_t ret = float16_t(qs) * d + m;
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
+ block_q5_0 block;
+};
+
+float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx & 0xF;
+
+ const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
+ const uint qh = ((uint_qh >> idx) << 4) & 0x10;
+
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = bl.block.qs[iqs];
+ qs >>= shift;
+ qs &= 0xF;
+
+ float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
+ block_q5_1 block;
+};
+
+float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const float16_t m = bl.block.m;
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx & 0xF;
+
+ const uint uint_qh = bl.block.qh;
+ const uint qh = ((uint_qh >> idx) << 4) & 0x10;
+
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = bl.block.qs[iqs];
+ qs >>= shift;
+ qs &= 0xF;
+
+ float16_t ret = float16_t(qs | qh) * d + m;
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
+ block_q8_0_packed16 block;
+};
+
+float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx;
+
+ // Load 16b and select the byte for this element
+ int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
+ float16_t ret = float16_t(qs) * d;
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
+ block_q2_K block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
+ block_q2_K_packed16 block;
+};
+
+float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
+ const f16vec2 d = bl.block.d;
+ const uint idx = coordInBlock[1];
+
+ const uint scalesi = (idx & 0xF0) >> 4; // 0..15
+ const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
+
+ uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
+ qs = (qs >> qsshift) & 0x0303;
+ qs = unpack8(qs)[idx & 1];
+
+ const uint scales = bl.block.scales[scalesi];
+ float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
+ block_q3_K block;
+};
+
+float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx;
+
+ const uint n = iqs / 128; // 0,1
+ const uint qsi = n * 32 + (iqs % 32); // 0..63
+ const uint hmi = (iqs % 32); // 0..31
+ const uint j = (iqs % 128) / 8; // 0..15
+ const uint is = iqs / 16; // 0..15
+ const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
+ const uint qsshift = halfsplit * 2; // 0,2,4,6
+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
+
+ uint32_t scaleidx0 = (is < 8) ? is : (is-8);
+ uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
+ uint32_t scaleidx1 = is + 8 - (is/4)*4;
+ uint32_t scaleidx1shift = (is/4)*2;
+
+ const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
+
+ const float16_t dl = bl.block.d * float16_t(us - 32);
+
+ float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
+
+ return ret;
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
+ block_q4_K block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
+ block_q4_K_packed16 block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
+ block_q4_K_packed128 block;
+};
+
+#if defined(IS_MUL_MM2)
+
+// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
+// into shared memory and then process the whole tile using those scales.
+// There is a fetch function that loads into private variables and then a store
+// function that stores into shared memory.
+// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
+// the part that fetches from the structure (which has a different block layout).
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+const uint shAscales_stride = (BM + 2);
+// 1 scale per 32 elements -> 8 scales per block, per row
+shared vec2 shAscales[8 * shAscales_stride];
+uvec4 row_v;
+#endif
+
+#if defined(DATA_A_Q4_K)
+layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
+
+void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
+{
+ uint tids_per_row = BLOCK_SIZE / BM;
+ uint is_per_tid = 8 / tids_per_row;
+ uint is_start = is_per_tid * (tid % tids_per_row);
+ uint tid_row = tid / tids_per_row;
+
+ uint row = ir_BM + tid_row;
+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
+ if (in_bounds || row < p.M) {
+ row_v = data_a_q4_k_packed128[block_index].q4k[0];
+ }
+}
+#endif
+#if defined(DATA_A_Q5_K)
+layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
+
+void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
+{
+ uint tids_per_row = BLOCK_SIZE / BM;
+ uint is_per_tid = 8 / tids_per_row;
+ uint is_start = is_per_tid * (tid % tids_per_row);
+ uint tid_row = tid / tids_per_row;
+
+ uint row = ir_BM + tid_row;
+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
+ if (in_bounds || row < p.M) {
+ row_v = data_a_q5_k_packed128[block_index].q5k[0];
+ }
+}
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+void store_scalesQ4_K(uint tid)
+{
+ barrier();
+
+ uint tids_per_row = BLOCK_SIZE / BM;
+ uint is_per_tid = 8 / tids_per_row;
+ uint is_start = is_per_tid * (tid % tids_per_row);
+ uint tid_row = tid / tids_per_row;
+
+ [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
+ uint is = idx + is_start;
+ uvec4 v = row_v;
+ const vec2 loadd = vec2(unpackFloat2x16(v.x));
+
+ uint32_t sc;
+ uint32_t mbyte;
+
+ uint32_t scale0 = v.y;
+ uint32_t scale4 = v.z;
+ uint32_t scale8 = v.w;
+
+ uint32_t sc_lo = scale0;
+ uint32_t mb_lo = scale4;
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
+
+ sc = is < 4 ? sc_lo : sc_hi;
+ mbyte = is < 4 ? mb_lo : mb_hi;
+ sc = sc >> (8 * (is & 3));
+ mbyte = mbyte >> (8 * (is & 3));
+ sc &= 0x3F;
+ mbyte &= 0x3F;
+
+ const float d = loadd.x * float(sc);
+ const float m = loadd.y * float(mbyte);
+ shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
+ }
+
+ barrier();
+}
+#endif
+
+#endif
+
+float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
+ decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
+ const uint idx = coordInBlock[1];
+
+ const uint b = (idx & 0x20) >> 5; // 0,1
+ const uint is = (idx & 0xE0) >> 5; // 0..7
+
+#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
+ float d = v.x;
+ float m = v.y;
+#else
+ uvec4 v = bl128.block.q4k[0];
+ const vec2 loadd = vec2(unpackFloat2x16(v.x));
+
+ uint32_t sc;
+ uint32_t mbyte;
+
+ uint32_t scale0 = v.y;
+ uint32_t scale4 = v.z;
+ uint32_t scale8 = v.w;
+
+ uint32_t sc_lo = scale0;
+ uint32_t mb_lo = scale4;
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
+
+ sc = is < 4 ? sc_lo : sc_hi;
+ mbyte = is < 4 ? mb_lo : mb_hi;
+ sc = sc >> (8 * (is & 3));
+ mbyte = mbyte >> (8 * (is & 3));
+ sc &= 0x3F;
+ mbyte &= 0x3F;
+
+ const float d = loadd.x * float(sc);
+ const float m = loadd.y * float(mbyte);
+#endif
+
+ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
+ qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
+
+ float ret = d * float(qs) - m;
+
+ return float16_t(ret);
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
+ block_q5_K block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
+ block_q5_K_packed16 block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
+ block_q5_K_packed128 block;
+};
+
+float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
+ decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
+ const uint idx = coordInBlock[1];
+
+ const uint b = (idx & 0x20) >> 5; // 0,1
+ const uint is = (idx & 0xE0) >> 5; // 0..7
+
+#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
+ float d = v.x;
+ float m = v.y;
+#else
+ uvec4 v = bl128.block.q5k[0];
+
+ const f16vec2 loadd = unpackFloat2x16(v.x);
+
+ uint32_t sc;
+ uint32_t mbyte;
+
+ uint32_t scale0 = v.y;
+ uint32_t scale4 = v.z;
+ uint32_t scale8 = v.w;
+
+ uint32_t sc_lo = scale0;
+ uint32_t mb_lo = scale4;
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
+
+ sc = is < 4 ? sc_lo : sc_hi;
+ mbyte = is < 4 ? mb_lo : mb_hi;
+ sc = sc >> (8 * (is & 3));
+ mbyte = mbyte >> (8 * (is & 3));
+ sc &= 0x3F;
+ mbyte &= 0x3F;
+
+ const float16_t d = loadd.x * float16_t(sc);
+ const float16_t m = loadd.y * float16_t(mbyte);
+#endif
+
+ uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
+ qh = ((qh >> is) & 0x101) << 4;
+
+ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
+ qs = (qs >> (b * 4)) & 0x0F0F;
+ qs = unpack8(qs | qh)[idx & 1];
+
+ float ret = d * float(qs) - m;
+
+ return float16_t(ret);
+}
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
+ block_q6_K block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
+ block_q6_K_packed16 block;
+};
+
+float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
+ const uint idx = coordInBlock[1];
+
+ const uint b = (idx & 0x40) >> 6; // 0,1
+ const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
+ const uint is = (idx & 0xF0) >> 4; // 0..15
+
+ const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
+
+ uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
+ ql = (ql >> (b * 4)) & 0x0F0F;
+
+ uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
+ qh = ((qh >> qhshift) & 0x0303) << 4;
+
+ int q = unpack8(ql | qh)[idx & 1];
+
+ float16_t ret = dscale * float16_t(q - 32);
+
+ return ret;
+}
+
+#if defined(DATA_A_IQ1_S)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
+ block_iq1_s block;
+};
+
+float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+
+ const uint ib32 = (idx & 0xE0) >> 5;
+ const uint ib8 = (idx & 0xF8) >> 3;
+
+ const uint qh = bl.block.qh[ib32];
+ const uint qs = bl.block.qs[ib8];
+ const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+ const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
+
+ float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
+ return ret;
+}
+#endif
+
+#if defined(DATA_A_IQ1_M)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
+ block_iq1_m block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
+ block_iq1_m_packed64 block;
+};
+
+float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
+ const uint idx = coordInBlock[1];
+
+ uvec2 scales = unpack32(bl64.block.scales);
+ const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
+
+ const uint ib8 = (idx & 0xF8) >> 3;
+ const uint ib16 = (idx & 0xF0) >> 4;
+ const int i8 = int(idx % 8);
+ const uint sc = bl.block.scales[ib8 / 8];
+ const uint qs = bl.block.qs[ib8];
+ const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
+ const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
+ const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+ const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
+
+ float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
+ return ret;
+}
+#endif
+
+#if defined(DATA_A_IQ2_XXS)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
+ block_iq2_xxs block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
+ block_iq2_xxs_packed16 block;
+};
+
+float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+
+ const uint ib32 = (idx & 0xE0) >> 5; // 0..7
+ const uint ib8 = (idx & 0x18) >> 3; // 0..3
+ const uint iqs = 8 * ib32 + ib8;
+
+ const uint qs = bl.block.qs[iqs];
+ const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
+
+ const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
+ uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
+ sign |= bitCount(sign) << 7;
+
+ uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
+ g2 >>= (idx & 2) * 8;
+ const vec2 g = vec2(unpack8(g2));
+
+ vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
+ return float16_t(ret[idx & 1]);
+}
+#endif
+
+#if defined(DATA_A_IQ2_XS)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
+ block_iq2_xs block;
+};
+
+float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+
+ const uint is = (idx & 0xE0) >> 5; // 0..8
+ const uint sshift = (idx & 0x10) >> 2; // 0,4
+ const uint iqs = (idx & 0xF8) >> 3; // 0..63
+
+ const uint16_t qs = bl.block.qs[iqs];
+ const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
+
+ uint sign = uint(qs >> 9);
+ sign |= bitCount(sign) << 7;
+ uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
+ g2 >>= (idx & 2) * 8;
+ const vec2 g = vec2(unpack8(g2));
+
+ vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
+ return float16_t(ret[idx & 1]);
+}
+#endif
+
+#if defined(DATA_A_IQ2_S)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
+ block_iq2_s block;
+};
+
+float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ uint idx = coordInBlock[1];
+
+ const uint ib32 = (idx & 0xE0) >> 5; // 0..7
+ const uint ib8 = (idx & 0xF8) >> 3; // 0..31
+ const uint qhshift = 2 * (ib8 % 4);
+
+ const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
+ const uint qs = bl.block.qs[ib8];
+ const uint qh = bl.block.qh[ib32];
+ const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
+
+ const float d = float(bl.block.d);
+ const float db = d * 0.25 * (0.5 + scale);
+ const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
+ uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
+ g2 >>= (idx & 2) * 8;
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
+ return float16_t(v[idx & 1]);
+}
+#endif
+
+#if defined(DATA_A_IQ3_XXS)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
+ block_iq3_xxs block;
+};
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
+ block_iq3_xxs_packed16 block;
+};
+
+float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
+ uint idx = coordInBlock[1];
+
+ const uint iqs = (idx & 0xFC) >> 2; // 0..63
+ const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
+
+ const float d = float(bl.block.d);
+ const uint qs = bl.block.qs[iqs];
+ const uint signs = pack32(u16vec2(
+ bl16.block.qs[is/2+0],
+ bl16.block.qs[is/2+1]
+ ));
+ const float db = d * 0.5 * (0.5 + (signs >> 28));
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
+ const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
+ const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
+ return float16_t(v[idx & 1]);
+}
+#endif
+
+#if defined(DATA_A_IQ3_S)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
+ block_iq3_s block;
+};
+
+float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ uint idx = coordInBlock[1];
+
+ const uint iqs = (idx & 0xFC) >> 2; // 0..63
+ const uint iqh = (idx & 0xE0) >> 5;
+
+ const float d = float(bl.block.d);
+ const uint qs = bl.block.qs[iqs];
+ const uint qh = bl.block.qh[iqh];
+ const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
+ const uint scale = bl.block.scales[iqs / 16];
+ const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
+ const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
+
+ return float16_t(v[idx & 1]);
+}
+#endif
+
+#if defined(DATA_A_IQ4_XS)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
+ block_iq4_xs block;
+};
+
+float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+
+ const uint ib32 = (idx & 0xE0) >> 5; // 0..7
+
+ const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+ const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
+ const uint qshift = (idx & 16) >> 2;
+ const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
+
+ float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
+ return ret;
+}
+#endif
+
+#if defined(DATA_A_IQ4_NL)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
+ block_iq4_nl block;
+};
+
+float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float16_t d = bl.block.d;
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx & 0xF;
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = bl.block.qs[iqs];
+ qs >>= shift;
+ qs &= 0xF;
+ float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
+ return ret;
+}
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define dequantFuncA dequantFuncQ4_0
+#elif defined(DATA_A_Q4_1)
+#define dequantFuncA dequantFuncQ4_1
+#elif defined(DATA_A_Q5_0)
+#define dequantFuncA dequantFuncQ5_0
+#elif defined(DATA_A_Q5_1)
+#define dequantFuncA dequantFuncQ5_1
+#elif defined(DATA_A_Q8_0)
+#define dequantFuncA dequantFuncQ8_0
+#elif defined(DATA_A_Q2_K)
+#define dequantFuncA dequantFuncQ2_K
+#elif defined(DATA_A_Q3_K)
+#define dequantFuncA dequantFuncQ3_K
+#elif defined(DATA_A_Q4_K)
+#define dequantFuncA dequantFuncQ4_K
+#define fetch_scales fetch_scalesQ4_K
+#define store_scales store_scalesQ4_K
+#elif defined(DATA_A_Q5_K)
+#define dequantFuncA dequantFuncQ5_K
+#define fetch_scales fetch_scalesQ5_K
+#define store_scales store_scalesQ4_K
+#elif defined(DATA_A_Q6_K)
+#define dequantFuncA dequantFuncQ6_K
+#elif defined(DATA_A_IQ1_S)
+#define dequantFuncA dequantFuncIQ1_S
+#elif defined(DATA_A_IQ1_M)
+#define dequantFuncA dequantFuncIQ1_M
+#elif defined(DATA_A_IQ2_XXS)
+#define dequantFuncA dequantFuncIQ2_XXS
+#elif defined(DATA_A_IQ2_XS)
+#define dequantFuncA dequantFuncIQ2_XS
+#elif defined(DATA_A_IQ2_S)
+#define dequantFuncA dequantFuncIQ2_S
+#elif defined(DATA_A_IQ3_XXS)
+#define dequantFuncA dequantFuncIQ3_XXS
+#elif defined(DATA_A_IQ3_S)
+#define dequantFuncA dequantFuncIQ3_S
+#elif defined(DATA_A_IQ4_XS)
+#define dequantFuncA dequantFuncIQ4_XS
+#elif defined(DATA_A_IQ4_NL)
+#define dequantFuncA dequantFuncIQ4_NL
+#endif
diff --git a/ggml/src/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/vulkan-shaders/dequant_iq1_m.comp
new file mode 100644
index 00000000..b604c188
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq1_m.comp
@@ -0,0 +1,42 @@
+#version 450
+
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 subblock (32 values with 2 scales)
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint ib32 = gl_LocalInvocationID.x % 8;
+ const uint ib64 = ib32 / 2;
+ const uint b_idx = 256 * ib + 32 * ib32;
+
+ const uint16_t[4] scales = data_a[ib].scales;
+ const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
+ const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
+
+ const uint sc = data_a[ib].scales[ib64];
+ [[unroll]] for (int l = 0; l < 4; ++l) {
+ const uint ib16 = 2 * ib32 + l / 2;
+ const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
+ const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
+ const uint qs = data_a[ib].qs[4 * ib32 + l];
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
+ [[unroll]] for (int j = 0; j < 8; ++j) {
+ data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/vulkan-shaders/dequant_iq1_s.comp
new file mode 100644
index 00000000..fd1e4e30
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq1_s.comp
@@ -0,0 +1,35 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 subblock (32 values with 2 scales)
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint ib32 = gl_LocalInvocationID.x % 8;
+ const uint b_idx = 256 * ib + 32 * ib32;
+
+ uint qh = data_a[ib].qh[ib32];
+ const float d = float(data_a[ib].d);
+ const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ const uint qs = data_a[ib].qs[4 * ib32 + l];
+ const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
+ const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
+ [[unroll]] for (int j = 0; j < 8; ++j) {
+ data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/vulkan-shaders/dequant_iq2_s.comp
new file mode 100644
index 00000000..48f6b65b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq2_s.comp
@@ -0,0 +1,44 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 subblock (32 values with 2 scales)
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint ib32 = gl_LocalInvocationID.x % 8;
+ const uint b_idx = 256 * ib + 32 * ib32;
+
+ const float d = float(data_a[ib].d);
+ const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
+ const vec2 db = d * (0.5 + scale) * 0.25;
+
+ uint qh = data_a[ib].qh[ib32];
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ uint qs = data_a[ib].qs[4 * ib32 + l];
+ const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
+ qs |= (qh << (8 - 2 * l)) & 0x300;
+ const uvec2 grid = iq2s_grid[qs & 511];
+ const u8vec4 grid0 = unpack8(grid.x);
+ const u8vec4 grid1 = unpack8(grid.y);
+ data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/vulkan-shaders/dequant_iq2_xs.comp
new file mode 100644
index 00000000..a08331c4
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq2_xs.comp
@@ -0,0 +1,43 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 subblock (32 values with 2 scales)
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint ib32 = gl_LocalInvocationID.x % 8;
+ const uint b_idx = 256 * ib + 32 * ib32;
+
+ const float d = float(data_a[ib].d);
+ const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
+ const vec2 db = d * (0.5 + scale) * 0.25;
+
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ uint16_t qs = data_a[ib].qs[4 * ib32 + l];
+ const uint sign7 = qs >> 9;
+ const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
+ const uvec2 grid = iq2xs_grid[qs & 511];
+ const u8vec4 grid0 = unpack8(grid.x);
+ const u8vec4 grid1 = unpack8(grid.y);
+ data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/vulkan-shaders/dequant_iq2_xxs.comp
new file mode 100644
index 00000000..e370690b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq2_xxs.comp
@@ -0,0 +1,48 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 scale block (32 values)
+ // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint is = gl_LocalInvocationID.x % 8;
+ const uint b_idx = 256 * ib + 32 * is;
+
+ const float d = float(data_a[ib].d);
+ uint signscale = pack32(u8vec4(
+ data_a[ib].qs[8*is + 4],
+ data_a[ib].qs[8*is + 5],
+ data_a[ib].qs[8*is + 6],
+ data_a[ib].qs[8*is + 7]
+ ));
+ const float db = d * (0.5 + (signscale >> 28)) * 0.25;
+
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
+ const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
+ const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]];
+ const u8vec4 grid0 = unpack8(grid.x);
+ const u8vec4 grid1 = unpack8(grid.y);
+ data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/vulkan-shaders/dequant_iq3_s.comp
new file mode 100644
index 00000000..c3f4bca5
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq3_s.comp
@@ -0,0 +1,39 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 scale nibble.
+ // Each block contains 4 scale bytes (8 scales) for 256 output values.
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint is = gl_LocalInvocationID.x % 8;
+ const uint b_idx = 256 * ib + 32 * is;
+
+ const float d = float(data_a[ib].d);
+ const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
+
+ // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
+ uint qh = data_a[ib].qh[is];
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ uint qs = data_a[ib].qs[8 * is + l];
+ uint gidx = qs | ((qh << (8 - l)) & 256);
+ uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
+ u8vec4 grid = unpack8(iq3s_grid[gidx]);
+ data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/vulkan-shaders/dequant_iq3_xxs.comp
new file mode 100644
index 00000000..a92b8296
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq3_xxs.comp
@@ -0,0 +1,49 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 scale block (32 values)
+ // 8 threads handle 1 superblock
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint is = gl_LocalInvocationID.x % 8;
+ const uint b_idx = 256 * ib + 32 * is;
+ const uint s_idx = QUANT_K / 4 + 4 * is;
+
+ const float d = float(data_a[ib].d);
+ uint signscale = pack32(u8vec4(
+ data_a[ib].qs[s_idx + 0],
+ data_a[ib].qs[s_idx + 1],
+ data_a[ib].qs[s_idx + 2],
+ data_a[ib].qs[s_idx + 3]
+ ));
+ const float db = d * (0.5 + (signscale >> 28)) * 0.5;
+
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
+ // Restore parity bit.
+ const uint sign8 = sign7 | (bitCount(sign7) << 7);
+ const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
+ const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
+ data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
+ data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/vulkan-shaders/dequant_iq4_nl.comp
index 34ef3da3..46d9ad15 100644
--- a/ggml/src/vulkan-shaders/dequant_iq4_nl.comp
+++ b/ggml/src/vulkan-shaders/dequant_iq4_nl.comp
@@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+ init_iq_shmem(gl_WorkGroupSize);
+
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
diff --git a/ggml/src/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/vulkan-shaders/dequant_iq4_xs.comp
new file mode 100644
index 00000000..f930852a
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq4_xs.comp
@@ -0,0 +1,34 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ // Each thread handles 1 subblock (1 scale and 32 quantized values)
+ const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ if (ib >= p.nel / 256) {
+ return;
+ }
+
+ const uint ib32 = gl_LocalInvocationID.x % 8;
+
+ const float d = float(data_a[ib].d);
+ // Scales are 6 bits
+ const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
+ | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
+ const float dl = d * (int(scale) - 32);
+
+ const uint b_idx = 256 * ib + 32 * ib32;
+ const uint q_idx = 16 * ib32;
+ [[unroll]] for (uint l = 0; l < 16; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
+ data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_k.comp b/ggml/src/vulkan-shaders/dequant_q4_k.comp
index 92acb754..987f113a 100644
--- a/ggml/src/vulkan-shaders/dequant_q4_k.comp
+++ b/ggml/src/vulkan-shaders/dequant_q4_k.comp
@@ -9,8 +9,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
+ const uint ib = gl_WorkGroupID.x * 256 + wgy;
+ if (ib >= p.M * p.K / QUANT_K) {
return;
}
@@ -20,37 +20,49 @@ void main() {
const uint is = 2 * il;
const uint n = 4;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
- const uint y_idx = i * QUANT_K + 64 * il + n * ir;
+ const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
- uint8_t sc;
- uint8_t m;
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is] & 63);
- m = uint8_t(data_a[i].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
- }
+ uint scidx0 = (is < 4) ? is : (is + 4);
+ uint scidx1 = (is < 4) ? is : (is - 4);
+ uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ uint scidxshift1 = (is < 4) ? 0 : 2;
+ uint mbidx0 = is + 4;
+ uint mbidx1 = (is < 4) ? is + 4 : is;
+ uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+ uint mbidxshift0 = (is < 4) ? 0 : 4;
+ uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+ uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+ uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
const FLOAT_TYPE d1 = dall * sc;
- const FLOAT_TYPE m1 = dmin * m;
-
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is + 1] & 63);
- m = uint8_t(data_a[i].scales[is + 5] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
- }
+ const FLOAT_TYPE m1 = dmin * mbyte;
+
+ scidx0 = (is < 4) ? is + 1 : (is + 5);
+ scidx1 = (is < 4) ? is + 1 : (is - 3);
+ scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ scidxshift1 = (is < 4) ? 0 : 2;
+ mbidx0 = is + 5;
+ mbidx1 = (is < 4) ? is + 5 : is + 1;
+ mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+ mbidxshift0 = (is < 4) ? 0 : 4;
+ mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ mbidxshift1 = (is < 4) ? 0 : 2;
+
+ sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+ mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
const FLOAT_TYPE d2 = dall * sc;
- const FLOAT_TYPE m2 = dmin * m;
+ const FLOAT_TYPE m2 = dmin * mbyte;
[[unroll]] for (uint l = 0; l < n; ++l) {
- data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
- data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
+ data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1);
+ data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2);
}
}
}
diff --git a/ggml/src/vulkan-shaders/dequant_q5_k.comp b/ggml/src/vulkan-shaders/dequant_q5_k.comp
index f314a76d..6db5403b 100644
--- a/ggml/src/vulkan-shaders/dequant_q5_k.comp
+++ b/ggml/src/vulkan-shaders/dequant_q5_k.comp
@@ -9,8 +9,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
+ const uint ib = gl_WorkGroupID.x * 256 + wgy;
+ if (ib >= p.M * p.K / QUANT_K) {
return;
}
@@ -19,40 +19,52 @@ void main() {
const uint ir = tid % 16;
const uint is = 2 * il;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
- const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
+ const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
const uint qh_idx = 2 * ir;
- uint8_t sc;
- uint8_t m;
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is] & 63);
- m = uint8_t(data_a[i].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
- }
+ uint scidx0 = (is < 4) ? is : (is + 4);
+ uint scidx1 = (is < 4) ? is : (is - 4);
+ uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ uint scidxshift1 = (is < 4) ? 0 : 2;
+ uint mbidx0 = is + 4;
+ uint mbidx1 = (is < 4) ? is + 4 : is;
+ uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+ uint mbidxshift0 = (is < 4) ? 0 : 4;
+ uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+ uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+ uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
const FLOAT_TYPE d1 = dall * sc;
- const FLOAT_TYPE m1 = dmin * m;
-
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is + 1] & 63);
- m = uint8_t(data_a[i].scales[is + 5] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
- }
+ const FLOAT_TYPE m1 = dmin * mbyte;
+
+ scidx0 = (is < 4) ? is + 1 : (is + 5);
+ scidx1 = (is < 4) ? is + 1 : (is - 3);
+ scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ scidxshift1 = (is < 4) ? 0 : 2;
+ mbidx0 = is + 5;
+ mbidx1 = (is < 4) ? is + 5 : is + 1;
+ mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+ mbidxshift0 = (is < 4) ? 0 : 4;
+ mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ mbidxshift1 = (is < 4) ? 0 : 2;
+
+ sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+ mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
const FLOAT_TYPE d2 = dall * sc;
- const FLOAT_TYPE m2 = dmin * m;
+ const FLOAT_TYPE m2 = dmin * mbyte;
const uint8_t hm1 = uint8_t(1 << (2 * il ));
const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
- data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
- data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
- data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
- data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
+ data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
+ data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
+ data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
+ data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
}
}
diff --git a/ggml/src/vulkan-shaders/diag_mask_inf.comp b/ggml/src/vulkan-shaders/diag_mask_inf.comp
index 4e68742b..26d8bc22 100644
--- a/ggml/src/vulkan-shaders/diag_mask_inf.comp
+++ b/ggml/src/vulkan-shaders/diag_mask_inf.comp
@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
#include "types.comp"
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
diff --git a/ggml/src/vulkan-shaders/div.comp b/ggml/src/vulkan-shaders/div.comp
index 8cfce58b..9fb69c6c 100644
--- a/ggml/src/vulkan-shaders/div.comp
+++ b/ggml/src/vulkan-shaders/div.comp
@@ -3,12 +3,25 @@
#include "types.comp"
#include "generic_binary_head.comp"
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
}
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp
new file mode 100644
index 00000000..ce230a8f
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn.comp
@@ -0,0 +1,337 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.comp"
+#include "flash_attn_base.comp"
+
+const uint32_t D_per_thread = D / D_split;
+
+const uint32_t cols_per_iter = WorkGroupSize / D_split;
+const uint32_t cols_per_thread = Bc / cols_per_iter;
+
+
+layout (binding = 0) readonly buffer Q {float data_q[];};
+layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
+layout (binding = 1) readonly buffer K {float16_t data_k[];};
+layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
+layout (binding = 2) readonly buffer V {float16_t data_v[];};
+layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
+layout (binding = 3) readonly buffer M {float16_t data_m[];};
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ uint32_t offset = (iq2 + r) * D + c;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ return elem;
+}
+
+shared FLOAT_TYPE tmpsh[WorkGroupSize];
+shared vec4 tmpshv4[WorkGroupSize];
+
+shared float masksh[Bc][Br];
+shared vec4 Qf[Br][D / 4];
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ init_indices();
+
+ const uint32_t tid = gl_LocalInvocationIndex;
+ const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
+ const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+
+ uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+
+ [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (D / 4);
+ uint32_t r = (idx + tid) / (D / 4);
+ if (r < Br && d < D / 4 &&
+ i * Br + r < N) {
+ Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
+ }
+ }
+ barrier();
+
+ vec4 Of[Br][D_per_thread / 4];
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] = vec4(0.0);
+ }
+ }
+
+ float Lf[Br], Mf[Br];
+
+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
+
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Lf[r] = 0;
+ Mf[r] = NEG_FLT_MAX_OVER_2;
+ }
+
+ float slope[Br];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ slope[r] = 1.0;
+ }
+
+ // ALiBi
+ if (p.max_bias > 0.0f) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+ }
+ }
+
+#if BLOCK_SIZE > 1
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
+#else
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
+#endif
+
+ [[dont_unroll]]
+ for (uint32_t j = start_j; j < end_j; ++j) {
+
+ float Sf[Br][cols_per_thread];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ Sf[r][c] = 0.0;
+ }
+ }
+
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+ vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
+ }
+ }
+ }
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ // Compute sum across the D_split
+ [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
+ }
+ }
+ }
+
+ if (p.logit_softcap != 0.0f) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
+ }
+ }
+ }
+
+ if (p.mask != 0) {
+
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) % Bc;
+ uint32_t r = (idx + tid) / Bc;
+ if (idx + tid < Bc * Br) {
+ masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+ }
+ }
+ barrier();
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ float mvf = masksh[c * cols_per_iter + col_tid][r];
+
+ Sf[r][c] += slope[r]*mvf;
+ }
+ }
+ barrier();
+ }
+
+ float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ rowmaxf[r] = Sf[r][0];
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
+ }
+ Moldf[r] = Mf[r];
+
+ // M = max(rowmax, Mold)
+ // P = e^(S - M)
+ // eM = e^(Mold - M)
+ Mf[r] = max(rowmaxf[r], Moldf[r]);
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ Pf[r][c] = exp(Sf[r][c] - Mf[r]);
+ }
+ eMf[r] = exp(Moldf[r] - Mf[r]);
+
+ // Compute sum across row of P
+ rowsumf[r] = 0.0;
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ rowsumf[r] += Pf[r][c];
+ }
+
+ Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
+ }
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] = eMf[r] * Of[r][d];
+ }
+ }
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+ vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+#endif
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] += Pf[r][c] * Vf;
+ }
+ }
+ }
+
+ barrier();
+ }
+
+ // reduce across threads
+
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ float rowmaxf, eMf;
+
+ tmpsh[tid] = Mf[r];
+ // Compute max across the row
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+ if (tid < s) {
+ tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+ }
+ barrier();
+ }
+ rowmaxf = tmpsh[d_tid];
+ barrier();
+
+ float Moldf = Mf[r];
+
+ // M = max(rowmax, Mold)
+ // eM = e^(Mold - M)
+ Mf[r] = max(rowmaxf, Moldf);
+ eMf = exp(Moldf - Mf[r]);
+
+ Lf[r] = eMf*Lf[r];
+
+ tmpsh[tid] = Lf[r];
+
+ // Compute sum across the row
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+ if (tid < s) {
+ tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+ }
+ barrier();
+ }
+ Lf[r] = tmpsh[d_tid];
+ barrier();
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+
+ Of[r][d] = eMf * Of[r][d];
+ tmpshv4[tid] = Of[r][d];
+
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+ if (tid < s) {
+ Of[r][d] += tmpshv4[tid + s];
+ tmpshv4[tid] = Of[r][d];
+ }
+ barrier();
+ }
+ Of[r][d] = tmpshv4[d_tid];
+ barrier();
+ }
+ }
+
+
+ // If there is split_k, then the split_k resolve shader does the final
+ // division by L. Store the intermediate O value and per-row m and L values.
+ if (p.k_num > 1) {
+ uint32_t o_offset = D * p.ne1 * split_k_index;
+
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (r < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+ }
+ }
+ }
+ }
+
+ o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (r < N) {
+ perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+ perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+ }
+ }
+
+ return;
+ }
+
+ float Lfrcp[Br];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Lfrcp[r] = 1.0 / Lf[r];
+ }
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] *= Lfrcp[r];
+ }
+ }
+
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
+
+ if (p.gqa_ratio > 1) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (r < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+ }
+ }
+ }
+ }
+ } else {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (i * Br + r < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp
new file mode 100644
index 00000000..61d90e2d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn_base.comp
@@ -0,0 +1,162 @@
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
+layout (constant_id = 1) const uint32_t Br = 1;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+layout (constant_id = 4) const uint32_t Clamp = 0;
+layout (constant_id = 5) const uint32_t D_split = 16;
+
+
+layout (push_constant) uniform parameter {
+ uint32_t N;
+ uint32_t KV;
+
+ uint32_t ne1;
+ uint32_t ne2;
+ uint32_t ne3;
+
+ uint32_t neq2;
+ uint32_t neq3;
+ uint32_t nek2;
+ uint32_t nek3;
+ uint32_t nev2;
+ uint32_t nev3;
+ uint32_t nem1;
+
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t nb11;
+ uint32_t nb12;
+ uint32_t nb13;
+ uint32_t nb21;
+ uint32_t nb22;
+ uint32_t nb23;
+ uint32_t nb31;
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ uint32_t mask;
+ uint32_t n_head_log2;
+ float m0;
+ float m1;
+
+ uint32_t gqa_ratio;
+ uint32_t split_kv;
+ uint32_t k_num;
+} p;
+
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+ uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+ uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+ uint shift = (iqs & 0x10) >> 2;
+ vui_lo >>= shift;
+ vui_hi >>= shift;
+
+ return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+ const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+ const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+ return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+
+// Store column zero. This is used to save per-row m and L values for split_k.
+ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ if (r < N && c == 0) {
+ uint32_t offset = iq2 + r;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ }
+ return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+ const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+ return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
+uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
+ iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+ q_stride, k_stride, v_stride, m_stride;
+
+void init_indices()
+{
+ N = p.N;
+ KV = p.KV;
+
+ i = gl_WorkGroupID.x;
+ split_k_index = 0;
+
+ if (p.k_num > 1) {
+ i = 0;
+ split_k_index = gl_WorkGroupID.x;
+ }
+
+ Tr = CEIL_DIV(N, Br);
+
+ start_j = split_k_index * p.split_kv / Bc;
+ end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
+
+ // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+ // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+ iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+ iq3 = gl_WorkGroupID.z;
+
+ // broadcast factors
+ rk2 = p.neq2/p.nek2;
+ rk3 = p.neq3/p.nek3;
+
+ rv2 = p.neq2/p.nev2;
+ rv3 = p.neq3/p.nev3;
+
+ // k indices
+ ik3 = iq3 / rk3;
+ ik2 = iq2 / rk2;
+
+ // v indices
+ iv3 = iq3 / rv3;
+ iv2 = iq2 / rv2;
+
+ // nb?1 are already divided by the type size and are in units of elements.
+ // When using grouped query attention, Q is indexed by iq2, so the stride
+ // should be nb02 (which is in bytes).
+ q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+ k_stride = p.nb11;
+ v_stride = p.nb21;
+ // When using grouped query attention, all rows use the same mask (stride 0).
+ // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+ // that prevents the compiler from folding the "&" through the select
+ // and breaking the alignment detection.
+ m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+}
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
new file mode 100644
index 00000000..da478be2
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
@@ -0,0 +1,360 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_cooperative_matrix : enable
+
+#include "types.comp"
+#include "flash_attn_base.comp"
+
+const uint32_t D_per_thread = D / D_split;
+const uint32_t row_split = 4;
+const uint32_t rows_per_thread = Br / row_split;
+const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
+const uint32_t cols_per_thread = Bc / cols_per_iter;
+
+
+layout (binding = 0) readonly buffer Q {float data_q[];};
+layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
+layout (binding = 1) readonly buffer K {float16_t data_k[];};
+layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
+layout (binding = 2) readonly buffer V {float16_t data_v[];};
+layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
+layout (binding = 3) readonly buffer M {float16_t data_m[];};
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ uint32_t offset = (iq2 + r) * D + c;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ return elem;
+}
+
+// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
+const uint32_t MatBr = 16;
+const uint32_t MatBc = 16;
+
+shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
+shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
+
+const uint32_t qstride = D / 4 + 2; // in units of f16vec4
+shared f16vec4 Qf[Br * qstride];
+
+// Avoid padding for D==256 to make it fit in 48KB shmem.
+const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+shared ACC_TYPE sfsh[Bc * sfshstride];
+
+const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
+shared f16vec4 ksh[Bc * kshstride];
+
+shared float slope[Br];
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ init_indices();
+
+ const uint32_t tid = gl_LocalInvocationIndex;
+
+ const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+ const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
+ const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
+ const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+
+#define tile_row(r) (row_tid * rows_per_thread + (r))
+
+ uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+
+ [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (D / 4);
+ uint32_t r = (idx + tid) / (D / 4);
+ if (r < Br && d < D / 4 &&
+ i * Br + r < N) {
+ Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
+ }
+ }
+ barrier();
+
+ ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Of[r][d] = ACC_TYPEV4(0.0);
+ }
+ }
+
+ float Lf[rows_per_thread], Mf[rows_per_thread];
+
+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
+
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Lf[r] = 0;
+ Mf[r] = NEG_FLT_MAX_OVER_2;
+ }
+
+ // ALiBi
+ if (p.max_bias > 0.0f) {
+ if (tid < Br) {
+ uint r = tid;
+ slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+ }
+ barrier();
+ } else {
+ if (tid < Br) {
+ uint r = tid;
+ slope[r] = 1.0;
+ }
+ barrier();
+ }
+
+#if BLOCK_SIZE > 1
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
+#else
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
+#endif
+
+ [[dont_unroll]]
+ for (uint32_t j = start_j; j < end_j; ++j) {
+
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (D / 4);
+ uint32_t c = (idx + tid) / (D / 4);
+ if (c < Bc && d < D / 4) {
+#if BLOCK_SIZE > 1
+ uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+#else
+ f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+#endif
+
+ ksh[c * kshstride + d] = K_Tf;
+ }
+ }
+ barrier();
+
+ // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
+ // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
+ // This is written transposed in order to allow for N being 8 if implementations need it
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
+ coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
+ coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
+
+ for (uint32_t d = 0; d < D / 16; ++d) {
+ coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
+
+ uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
+ coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+
+ SfMat = coopMatMulAdd(KMat, QMat, SfMat);
+ }
+
+ uint coord = gl_SubgroupID * MatBc * sfshstride;
+ coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
+ barrier();
+
+ if (p.logit_softcap != 0.0f) {
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) / Br;
+ uint32_t r = (idx + tid) % Br;
+ if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
+ sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
+ }
+ }
+ barrier();
+ }
+
+ if (p.mask != 0) {
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) % Bc;
+ uint32_t r = (idx + tid) / Bc;
+ if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+ }
+ }
+ barrier();
+ }
+
+ float eMf[rows_per_thread];
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
+ }
+ float Moldf = Mf[r];
+
+ // M = max(rowmax, Mold)
+ // P = e^(S - M)
+ // eM = e^(Mold - M)
+ Mf[r] = max(rowmaxf, Moldf);
+ eMf[r] = exp(Moldf - Mf[r]);
+ }
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Of[r][d] = float16_t(eMf[r]) * Of[r][d];
+ }
+ }
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Lf[r] = eMf[r]*Lf[r];
+ }
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ float Pf[rows_per_thread];
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
+ Lf[r] += Pf[r];
+ }
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+ vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+#endif
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
+ }
+ }
+ }
+
+ barrier();
+ }
+
+ // reduce across threads
+
+ float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ FLOAT_TYPE M = Mf[r];
+ tmpsh[tid] = M;
+ // Compute max across the row
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
+ M = max(M, tmpsh[tid ^ s]);
+ barrier();
+ tmpsh[tid] = M;
+ barrier();
+ }
+ rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
+ barrier();
+ }
+
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Moldf[r] = Mf[r];
+
+ // M = max(rowmax, Mold)
+ // eM = e^(Mold - M)
+ Mf[r] = max(rowmaxf[r], Moldf[r]);
+ eMf[r] = exp(Moldf[r] - Mf[r]);
+
+ Lf[r] = eMf[r]*Lf[r];
+ }
+
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ FLOAT_TYPE L = Lf[r];
+ tmpsh[tid] = L;
+ // Compute sum across the row
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
+ L += tmpsh[tid ^ s];
+ barrier();
+ tmpsh[tid] = L;
+ barrier();
+ }
+ Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
+ barrier();
+ }
+
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+
+ Of[r][d] = float16_t(eMf[r]) * Of[r][d];
+ tmpshv4[tid] = Of[r][d];
+
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
+ Of[r][d] += tmpshv4[tid ^ s];
+ barrier();
+ tmpshv4[tid] = Of[r][d];
+ barrier();
+ }
+ Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
+ barrier();
+ }
+ }
+
+ // If there is split_k, then the split_k resolve shader does the final
+ // division by L. Store the intermediate O value and per-row m and L values.
+ if (p.k_num > 1) {
+ uint32_t o_offset = D * p.ne1 * split_k_index;
+
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ if (tile_row(r) < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
+ }
+ }
+ }
+ }
+
+ o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ if (tile_row(r) < N) {
+ perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+ perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+ }
+ }
+
+ return;
+ }
+
+ float Lfrcp[rows_per_thread];
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Lfrcp[r] = 1.0 / Lf[r];
+ }
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ Of[r][d] *= float16_t(Lfrcp[r]);
+ }
+ }
+
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
+
+ if (p.gqa_ratio > 1) {
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ if (tile_row(r) < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
+ }
+ }
+ }
+ }
+ } else {
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ if (i * Br + tile_row(r) < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
new file mode 100644
index 00000000..6acf67a0
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
@@ -0,0 +1,267 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_EXT_buffer_reference : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_vote : enable
+#extension GL_EXT_null_initializer : enable
+
+#include "types.comp"
+#include "dequant_funcs_cm2.comp"
+#include "flash_attn_base.comp"
+
+layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
+layout (binding = 1) readonly buffer K {uint8_t data_k[];};
+layout (binding = 2) readonly buffer V {uint8_t data_v[];};
+layout (binding = 3) readonly buffer M {uint8_t data_m[];};
+
+ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
+ return max(x, y);
+}
+
+ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
+ return x;
+}
+
+// Replace matrix elements >= numRows or numCols with 'replace'
+ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
+ if (row >= numRows || col >= numCols) {
+ return replace;
+ }
+ return elem;
+}
+
+ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
+{
+ return exp(elem);
+}
+
+ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
+{
+ return max(elem0, elem1);
+}
+
+#if defined(BLOCK_SIZE)
+#define DECODEFUNC , DEQUANTFUNC
+#else
+#define DECODEFUNC
+#endif
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ if (r < N && c < D) {
+ uint32_t offset = (iq2 + r) * D + c;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ }
+ return elem;
+}
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ init_indices();
+
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
+ tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
+
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
+
+#if defined(BLOCK_SIZE)
+ tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
+ tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
+#endif
+
+ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
+ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
+ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+
+ // hint to the compiler that strides are aligned for the aligned variant of the shader
+ if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
+ {
+ q_stride &= ~7;
+#if !defined(BLOCK_SIZE)
+ k_stride &= ~7;
+ v_stride &= ~7;
+#endif
+ m_stride &= ~7;
+ }
+ tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
+ tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
+ tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
+
+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
+
+ uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
+
+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
+ Qf16 *= float16_t(p.scale);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
+
+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
+
+ L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+ M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
+
+ // ALiBi
+ if (p.max_bias > 0.0f) {
+ coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
+ }
+
+ [[dont_unroll]]
+ for (uint32_t j = start_j; j < end_j; ++j) {
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
+
+ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
+ S = coopMatMulAdd(Qf16, K_T, S);
+
+ if (p.logit_softcap != 0.0f) {
+ [[unroll]]
+ for (int k = 0; k < S.length(); ++k) {
+ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
+ }
+ }
+
+ if (p.mask != 0) {
+ tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+
+ coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ }
+
+ // Clear padding elements to -inf, so they don't contribute to rowmax
+ if (Clamp != 0 &&
+ ((j + 1) * Bc > KV ||
+ (i + 1) * Br > N)) {
+
+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
+
+ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
+ }
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
+
+ coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
+
+ // M = max(rowmax, Mold)
+ // P = e^(S - M)
+ // eM = e^(Mold - M)
+ coopMatPerElementNV(M, rowmax, Max, Mold);
+ coopMatPerElementNV(P, S - M, Exp);
+ coopMatPerElementNV(eM, Mold - M, Exp);
+
+ // Clear padding elements to 0, so they don't contribute to rowsum
+ if (Clamp != 0 &&
+ ((j + 1) * Bc > KV ||
+ (i + 1) * Br > N)) {
+
+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
+
+ coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
+ }
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
+
+ // compute rowsum by multiplying by matrix of all ones.
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
+
+ rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
+ rowsum = coopMatMulAdd(P_A, One, rowsum);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
+ uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
+
+ L = eM*L + rowsum;
+
+ // This is the "diagonal" matrix in the paper, but since we do componentwise
+ // multiply rather than matrix multiply it has the diagonal element smeared
+ // across the row
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
+
+ // resize eM by using smear/reduce
+ coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+ // multiply with fp16 accumulation, then add to O.
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+ PV = coopMatMulAdd(P_A, V, PV);
+
+ O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
+ }
+
+ // If there is split_k, then the split_k resolve shader does the final
+ // division by L. Store the intermediate O value and per-row m and L values.
+ if (p.k_num > 1) {
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+
+ uint32_t o_offset = D * p.ne1 * split_k_index;
+ coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+
+ o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
+ coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+ return;
+ }
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
+
+ // resize L by using smear/reduce
+ coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+ [[unroll]]
+ for (int k = 0; k < Ldiag.length(); ++k) {
+ Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
+ }
+
+ O = Ldiag*O;
+
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
+
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+ if (p.gqa_ratio > 1) {
+ coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+ } else {
+ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+
+ // permute dimensions
+ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
+
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp
new file mode 100644
index 00000000..a7e39568
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -0,0 +1,59 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#define BLOCK_SIZE 32
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 1) writeonly buffer D {float data_d[];};
+
+layout (push_constant) uniform parameter {
+ uint D;
+ uint N;
+ uint k_num;
+} p;
+
+void main() {
+ // Each workgroup handles a row
+ const uint n = gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ uint D = p.D;
+ uint N = p.N;
+ uint k_num = p.k_num;
+
+ uint l_offset = D * N * k_num + n;
+ uint m_offset = D * N * k_num + N + n;
+ uint lm_stride = N * 2;
+
+ // Compute the max m value for the row
+ float m_max = -1.0/0.0;
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
+ float m = data_a[m_offset + k * lm_stride];
+ m_max = max(m_max, m);
+ }
+
+ // Compute L based on m_max
+ float L = 0;
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
+ float l = data_a[l_offset + k * lm_stride];
+ float m = data_a[m_offset + k * lm_stride];
+ L += exp(m - m_max) * l;
+ }
+
+ L = 1.0 / L;
+
+ // Scale and sum the O contributions based on m_max and store the result to memory
+ for (uint d = tid; d < D; d += BLOCK_SIZE) {
+ float O = 0.0;
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
+ uint o_offset = D * N * k + D * n + d;
+ float m = data_a[m_offset + k * lm_stride];
+ O += exp(m - m_max) * data_a[o_offset];
+ }
+ O *= L;
+ data_d[D * n + d] = O;
+ }
+}
diff --git a/ggml/src/vulkan-shaders/generic_binary_head.comp b/ggml/src/vulkan-shaders/generic_binary_head.comp
index b6beaff1..062e2a4c 100644
--- a/ggml/src/vulkan-shaders/generic_binary_head.comp
+++ b/ggml/src/vulkan-shaders/generic_binary_head.comp
@@ -1,4 +1,5 @@
#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
layout (push_constant) uniform parameter
{
@@ -6,47 +7,58 @@ layout (push_constant) uniform parameter
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
- uint d_offset;
+ uint misalign_offsets;
float param1; float param2; int param3;
} p;
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+// true if src0/src1 are the same shape and the indices can be reused without additional modulus
+layout(constant_id = 0) const bool norepeat = false;
+
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
-uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
- return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+uint get_aoffset() { return p.misalign_offsets >> 16; }
+uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
+uint get_doffset() { return p.misalign_offsets & 0xFF; }
+
+// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
+uint fastmod(uint a, uint b) {
+ if ((b & (b-1)) == 0) {
+ return a & (b-1);
+ }
+ return a % b;
}
-uint src1_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+uint fastdiv(uint a, uint b) {
+ return (a < b) ? 0 : (a / b);
+}
+
+void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
+ i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+}
+
+uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
- return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
+uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
+ if (norepeat) {
+ return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
+ } else {
+ return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
+ }
}
-uint dst_idx(uint idx) {
- const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
- const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
- const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
- const uint i22_offset = i22*p.ne21*p.ne20;
- const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
- const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
- return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
+uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
+ return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
}
diff --git a/ggml/src/vulkan-shaders/generic_unary_head.comp b/ggml/src/vulkan-shaders/generic_unary_head.comp
index eacdefc7..8dc9d360 100644
--- a/ggml/src/vulkan-shaders/generic_unary_head.comp
+++ b/ggml/src/vulkan-shaders/generic_unary_head.comp
@@ -1,15 +1,21 @@
#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
- uint d_offset;
+ uint misalign_offsets;
float param1; float param2;
-} p;
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+ uint ne0_012mp; uint ne0_012L;
+ uint ne0_01mp; uint ne0_01L;
+ uint ne0_0mp; uint ne0_0L;
+ uint ne1_012mp; uint ne1_012L;
+ uint ne1_01mp; uint ne1_01L;
+ uint ne1_0mp; uint ne1_0L;
+} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
@@ -18,22 +24,53 @@ uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
+uint get_aoffset() { return p.misalign_offsets >> 16; }
+uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
+
+// see init_fastdiv_values in ggml-vulkan.cpp
+uint fastdiv(uint n, uint mp, uint L) {
+ uint msbs, lsbs;
+ // msbs = mulhi(n, mp)
+ umulExtended(n, mp, msbs, lsbs);
+ return (msbs + n) >> L;
+}
+
uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
uint dst_idx(uint idx) {
- const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
+ const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
- const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
+ const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
- const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
+ const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
}
+
+uint src0_idx_quant(uint idx, uint qk) {
+ const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00;
+}
+
+uint dst_idx_quant(uint idx, uint qk) {
+ const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+ const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+ const uint i12_offset = i12*p.ne11*p.ne10;
+ const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+ return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;
+}
diff --git a/ggml/src/vulkan-shaders/get_rows.comp b/ggml/src/vulkan-shaders/get_rows.comp
index e9ff22ef..ee6b86a1 100644
--- a/ggml/src/vulkan-shaders/get_rows.comp
+++ b/ggml/src/vulkan-shaders/get_rows.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_binary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint i00 = gl_GlobalInvocationID.x;
const uint i10 = gl_GlobalInvocationID.y;
@@ -13,14 +15,19 @@ void main() {
return;
}
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+ const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+ const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+#if defined(DATA_A_BF16)
+ FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+#else
+ FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
+ data_d[d_offset + i00] = D_TYPE(v);
#else
- data_d[d_offset + i00] = data_a[a_offset + i00];
+ data_d[d_offset + i00] = D_TYPE(v);
#endif
}
diff --git a/ggml/src/vulkan-shaders/get_rows_quant.comp b/ggml/src/vulkan-shaders/get_rows_quant.comp
index 53a9a96f..cfd645a3 100644
--- a/ggml/src/vulkan-shaders/get_rows_quant.comp
+++ b/ggml/src/vulkan-shaders/get_rows_quant.comp
@@ -1,15 +1,23 @@
#version 450
+#extension GL_EXT_control_flow_attributes : enable
+
#include "types.comp"
#include "generic_binary_head.comp"
#include "dequant_funcs.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
const uint i10 = gl_GlobalInvocationID.y;
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
if (i00 >= p.ne00) {
return;
}
@@ -25,6 +33,8 @@ void main() {
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
vec2 v = dequantize(ib, iqs, 0);
+ const vec2 dm = get_dm(ib, 0);
+ v = v * dm.x + dm.y;
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
diff --git a/ggml/src/vulkan-shaders/group_norm.comp b/ggml/src/vulkan-shaders/group_norm.comp
index 5ad9b28d..b6a0d564 100644
--- a/ggml/src/vulkan-shaders/group_norm.comp
+++ b/ggml/src/vulkan-shaders/group_norm.comp
@@ -19,7 +19,7 @@ void main() {
const uint tid = gl_LocalInvocationID.x;
const uint start = gl_WorkGroupID.x * group_size + tid;
- const uint end = start + group_size;
+ const uint end = (gl_WorkGroupID.x + 1) * group_size;
tmp[tid] = 0.0f;
diff --git a/ggml/src/vulkan-shaders/im2col.comp b/ggml/src/vulkan-shaders/im2col.comp
index 4d48610a..09aa849e 100644
--- a/ggml/src/vulkan-shaders/im2col.comp
+++ b/ggml/src/vulkan-shaders/im2col.comp
@@ -1,6 +1,12 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_spirv_intrinsics: enable
+#extension GL_EXT_control_flow_attributes : require
+
+#if RTE16
+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
+#endif
layout (push_constant) uniform parameter
{
@@ -18,40 +24,77 @@ layout (push_constant) uniform parameter
#include "types.comp"
-#define BLOCK_SIZE 256
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+
+const uint NUM_ITER = 512 / BLOCK_SIZE;
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
- const uint i = gl_GlobalInvocationID.x;
- if (i >= p.pelements) {
- return;
- }
-
- const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
- const uint kx = i / ksize;
- const uint kd = kx * ksize;
- const uint ky = (i - kd) / p.OW;
- const uint ix = i % p.OW;
+ const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
- const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
- const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
+ const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
+ const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
+ const int oh_s1 = int(oh) * p.s1;
+ const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+
+ const uint base_linear_idx = gidx * NUM_ITER;
+
+ const uint max_ky = ksize / p.OW;
+
+ uint current_kx = base_linear_idx / ksize;
+ const uint rem = base_linear_idx - (current_kx * ksize);
+ uint current_ky = rem / p.OW;
+ uint current_ix = rem % p.OW;
+
+ A_TYPE values[NUM_ITER];
+ uint offset_dst[NUM_ITER];
+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+ values[idx] = A_TYPE(0);
+ }
+
+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+
+ const uint linear_idx = base_linear_idx + idx;
+
+ if (linear_idx >= p.pelements) {
+ continue;
+ }
- const uint offset_dst =
- ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
- (ic * (p.KW * p.KH) + ky * p.KW + kx);
+ const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
+ const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
- if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
- data_d[offset_dst] = D_TYPE(0.0f);
- } else {
- const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
- data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
+ offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
+
+ if ((iih < p.IH) && (iiw < p.IW)) {
+ values[idx] = data_a[src_base + iih * p.IW + iiw];
+ }
+
+ if (++current_ix == p.OW) {
+ current_ix = 0;
+ if (++current_ky == max_ky) {
+ current_ky = 0;
+ current_kx++;
+ }
+ }
}
+
+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+
+ const uint linear_idx = base_linear_idx + idx;
+
+ if (linear_idx >= p.pelements) {
+ continue;
+ }
+
+ data_d[offset_dst[idx]] = D_TYPE(values[idx]);
+ }
+
}
diff --git a/ggml/src/vulkan-shaders/mul.comp b/ggml/src/vulkan-shaders/mul.comp
index bfb61c92..43de19df 100644
--- a/ggml/src/vulkan-shaders/mul.comp
+++ b/ggml/src/vulkan-shaders/mul.comp
@@ -3,12 +3,25 @@
#include "types.comp"
#include "generic_binary_head.comp"
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp b/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp
index 825b9103..4c64fd47 100644
--- a/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp
@@ -5,7 +5,9 @@
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
layout (push_constant) uniform parameter {
uint ne;
@@ -13,17 +15,34 @@ layout (push_constant) uniform parameter {
} p;
void main() {
- const uint idx = gl_GlobalInvocationID.x;
+ // Each invocation handles four consecutive components
+ const uint idx = gl_GlobalInvocationID.x * 4;
if (idx >= p.ne) {
return;
}
- float result = 0.0f;
+ // Check if all four components are in bounds and aligned,
+ // then use vector loads
+ if (idx + 3 < p.ne && (p.ne % 4) == 0) {
+ vec4 result = vec4(0.0f);
- [[unroll]] for (uint i = 0; i < p.k_num; i++) {
- result += data_a[i * p.ne + idx];
- }
+ [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+ result += data_a4[(i * p.ne + idx) / 4];
+ }
+
+ data_d4[idx / 4] = result;
+ } else {
+ [[unroll]] for (uint j = 0; j < 4; ++j) {
+ if (idx + j < p.ne) {
+ float result = 0.0f;
- data_d[idx] = result;
+ [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+ result += data_a[i * p.ne + idx + j];
+ }
+
+ data_d[idx + j] = result;
+ }
+ }
+ }
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec.comp b/ggml/src/vulkan-shaders/mul_mat_vec.comp
index 46a6369b..bb429dd5 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec.comp
@@ -1,57 +1,169 @@
#version 450
-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
+#define K_PER_ITER 8
+#else
+#define K_PER_ITER 2
+#endif
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
- const uint tid = gl_LocalInvocationID.x;
+uint a_offset, b_offset, d_offset, y_offset;
+
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
+{
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
+ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = col - col%QUANT_K; // y block start index
+
+#if K_PER_ITER == 8
+#if QUANT_R == 2
+ const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+ const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
+ const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
+ const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
+#else
+ const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+ const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
+#endif
+#else
+ // Check if the second of the pair of elements is OOB, and don't fetch B or
+ // accumulate it. We still fetch a pair of elements for A, which is fine for
+ // quantized formats since they'll be within the same block. We should
+ // probably skip fetching the second element for F16/F32, but as of now we
+ // still do.
+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
+
+ FLOAT_TYPE b0 = 0, b1 = 0;
+ b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
+ if (!OOB) {
+ b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
+ }
+#endif
+ uint ibi = first_row*p.ncols;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib = (ibi + col)/QUANT_K; // block index
+ ibi += p.ncols;
+
+#if K_PER_ITER == 8
+ vec4 v = dequantize4(ib, iqs, a_offset);
+ vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+
+ const vec2 dm = get_dm(ib, a_offset);
+ if (dm.y != 0) { // quant has min component
+ v = v * dm.x + dm.y;
+ v2 = v2 * dm.x + dm.y;
+ }
+
+ // matrix multiplication
+ FLOAT_TYPE rowtmp = dot(bv0, v);
+ rowtmp += dot(bv1, v2);
+
+ if (dm.y == 0)
+ rowtmp *= dm.x;
+
+ temp[j][n] += rowtmp;
+#else
+ const vec2 v = dequantize(ib, iqs, a_offset);
- // There are not enough cols to use all threads
- if (tid >= p.ncols) {
- return;
+ // matrix multiplication
+ temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
+ if (!OOB) {
+ temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
+ }
+#endif
+ }
}
+}
- const uint block_size = min(p.ncols, BLOCK_SIZE);
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ const uint tid = gl_LocalInvocationID.x;
- uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
+ a_offset /= QUANT_K;
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+ y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
- tmp[tid] = FLOAT_TYPE(0.0f);
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
- const uint col = i*block_size + 2*tid;
- const uint ib = (row*p.ncols + col)/QUANT_K; // block index
- const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
- const uint iybs = col - col%QUANT_K; // y block start index
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
- vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
+ uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+ if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+ num_iters++;
+ }
+ int unroll_count = 4;
+ uint unrolled_iters = num_iters & ~(unroll_count - 1);
- // matrix multiplication
- tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
- FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
+#if K_PER_ITER == 2
+ // If the K dimension is odd, we need lastiter==true on the last iteration
+ // so OOB is computed correctly. Skip some unrolling to make that happen.
+ if ((p.ncols & 1) != 0 &&
+ unrolled_iters == num_iters &&
+ unrolled_iters > 0) {
+ unrolled_iters -= unroll_count;
}
+#endif
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
+ uint i = 0;
+ while (i < unrolled_iters) {
+ // Manually partially unroll the loop
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+ i++;
}
- barrier();
}
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
+
+ unroll_count = 2;
+ unrolled_iters = num_iters & ~(unroll_count - 1);
+
+#if K_PER_ITER == 2
+ if ((p.ncols & 1) != 0 &&
+ unrolled_iters == num_iters &&
+ unrolled_iters > 0) {
+ unrolled_iters -= unroll_count;
+ }
+#endif
+
+ while (i < unrolled_iters) {
+ // Manually partially unroll the loop
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+ i++;
+ }
+ }
+ while (i < num_iters) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
+ i++;
+ }
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
index 5920bc93..903753c7 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
@@ -2,8 +2,6 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
-#define K_QUANTS_PER_ITERATION 2
-
#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
#endif
@@ -12,6 +10,9 @@
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
@@ -49,13 +50,16 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#endif
#ifndef MUL_MAT_ID
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
+ uint batch_idx_a = 0;
+ if (batch_idx != 0) {
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
- const uint batch_idx_a = i03 * p.ne02 + i02;
+ batch_idx_a = i03 * p.ne02 + i02;
+ }
#else
const uint expert_id = data_ids[expert_idx];
#endif
@@ -79,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
batch_idx * p.batch_stride_d;
#endif
}
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
+layout (constant_id = 2) const uint NUM_COLS = 1;
+
+shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
+
+void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
+ // sum up partial sums and write back result
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ tmpsh[j][n][tid] = temp[j][n];
+ }
+ }
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
+ }
+ }
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
+ }
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq1_m.comp
new file mode 100644
index 00000000..e4acbd4f
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq1_m.comp
@@ -0,0 +1,82 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 32 * ib32;
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint16_t[4] scales = data_a[ibi].scales;
+ const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
+ const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
+
+ const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
+ const uint qs = data_a[ibi].qs[4 * ib32 + l];
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+ const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
+
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ [[unroll]] for (int k = 0; k < 4; ++k) {
+ sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
+ fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
+ }
+ temp[j][n] = fma(dl, sum, temp[j][n]);
+ }
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 8 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/8;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 8; // 0...7
+ const uint ix = tid / 8;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq1_s.comp
new file mode 100644
index 00000000..309da099
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq1_s.comp
@@ -0,0 +1,79 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 32 * ib32;
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const float d = float(data_a[ibi].d);
+ const uint qh = data_a[ibi].qh[ib32];
+ const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ const uint qs = data_a[ibi].qs[4 * ib32 + l];
+ const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
+ const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ [[unroll]] for (int k = 0; k < 4; ++k) {
+ sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
+ fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
+ }
+ temp[j][n] = fma(dl, sum, temp[j][n]);
+ }
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 8 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/8;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 8; // 0...7
+ const uint ix = tid / 8;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq2_s.comp
new file mode 100644
index 00000000..8d01536f
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq2_s.comp
@@ -0,0 +1,90 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 16 * itid;
+ const uint nibble_shift = 4 * (itid & 1);
+ const uint ib32 = itid / 2; // 0..7
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const float d = float(data_a[ibi].d);
+ const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
+ const float db = d * (0.5 + scale) * 0.25;
+
+ const uint qh = data_a[ibi].qh[ib32];
+ const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
+ const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
+ [[unroll]] for (uint l = 0; l < 2; ++l) {
+ const uint8_t sign = sign16[l];
+ const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);
+ const uvec2 grid = iq2s_grid[qs];
+ const vec4 grid0 = vec4(unpack8(grid.x));
+ const vec4 grid1 = vec4(unpack8(grid.y));
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ FLOAT_TYPE sum =
+ fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
+ fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
+ fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
+ fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
+ fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
+ fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
+ fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
+ fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
+ FLOAT_TYPE(0.0)))))))));
+ temp[j][n] = fma(db, sum, temp[j][n]);
+ }
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 16 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 16; // 0...15
+ const uint ix = tid / 16;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq2_xs.comp
new file mode 100644
index 00000000..c4960432
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq2_xs.comp
@@ -0,0 +1,87 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 16 * itid;
+ const uint nibble_shift = 4 * (itid & 1);
+ const uint ib32 = itid / 2; // 0..7
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const float d = float(data_a[ibi].d);
+ const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
+ const float db = d * (0.5 + scale) * 0.25;
+
+ [[unroll]] for (uint l = 0; l < 2; ++l) {
+ const uint qs = data_a[ibi].qs[2 * itid + l];
+ const uint sign = qs >> 9;
+ const uint sign7 = bitCount(sign);
+ const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
+ const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ FLOAT_TYPE sum =
+ fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
+ fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
+ fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
+ fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
+ fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
+ fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
+ fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
+ fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
+ FLOAT_TYPE(0.0)))))))));
+ temp[j][n] = fma(db, sum, temp[j][n]);
+ }
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 16 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 16; // 0...15
+ const uint ix = tid / 16;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
new file mode 100644
index 00000000..94d4b92e
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
@@ -0,0 +1,87 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 16 * itid;
+ const uint ib32 = itid / 2; // 0..7
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const float d = float(data_a[ibi].d);
+ const uint signscale = pack32(u16vec2(
+ data_a_packed16[ibi].qs[4 * ib32 + 2],
+ data_a_packed16[ibi].qs[4 * ib32 + 3]));
+ const float db = d * 0.25 * (0.5 + (signscale >> 28));
+ [[unroll]] for (uint l = 0; l < 2; ++l) {
+ const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l];
+ const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
+ const uint sign7 = bitCount(sign);
+ const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x));
+ const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y));
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ FLOAT_TYPE sum =
+ fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
+ fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
+ fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
+ fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
+ fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
+ fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
+ fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
+ fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
+ FLOAT_TYPE(0.0)))))))));
+ temp[j][n] = fma(db, sum, temp[j][n]);
+ }
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 16 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 16; // 0...15
+ const uint ix = tid / 16;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq3_s.comp
new file mode 100644
index 00000000..f021e404
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq3_s.comp
@@ -0,0 +1,90 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 32 * ib32;
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const float d = float(data_a[ibi].d);
+ const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+ const float dscale = d * (1 + 2 * scale);
+ const uint qh = data_a[ibi].qh[ib32];
+ FLOAT_TYPE sum[NUM_COLS];
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ sum[j] = 0.0;
+ }
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
+ const uint sign = data_a[ibi].signs[4 * ib32 + l];
+ const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
+ const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ sum[j] =
+ fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
+ fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
+ fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
+ fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
+ fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
+ fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
+ fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
+ fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
+ sum[j]))))))));
+ }
+ }
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ temp[j][n] = fma(dscale, sum[j], temp[j][n]);
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 8 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/8;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 8; // 0...7
+ const uint ix = tid / 8;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
new file mode 100644
index 00000000..3fe9dc3a
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
@@ -0,0 +1,88 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y_idx = i * QUANT_K + 16 * itid;
+ const uint ib32 = itid / 2; // 0..7
+
+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const float d = float(data_a[ibi].d);
+ const uint signscale = pack32(u16vec2(
+ data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32],
+ data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1]));
+ const float db = d * 0.5 * (0.5 + (signscale >> 28));
+ [[unroll]] for (uint l = 0; l < 2; ++l) {
+ const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l];
+ const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1];
+ const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
+ const uint sign7 = bitCount(sign);
+ const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0]));
+ const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1]));
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
+ const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
+
+ FLOAT_TYPE sum =
+ fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
+ fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
+ fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
+ fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
+ fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
+ fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
+ fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
+ fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
+ FLOAT_TYPE(0.0)))))))));
+ temp[j][n] = fma(db, sum, temp[j][n]);
+ }
+ }
+ ibi += num_blocks_per_row;
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+ // 16 threads are used to process each block
+ const uint blocks_per_wg = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid % 16; // 0...15
+ const uint ix = tid / 16;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
+ calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp
index cb3f3c0d..bc633369 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp
@@ -12,13 +12,18 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint row_stride_x;
uint channel_stride_x;
+ uint channel_stride_y;
uint channel_x_divisor;
+ uint ne12;
uint b_offset;
uint d_offset;
} p;
@@ -30,6 +35,7 @@ void main() {
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / p.channel_x_divisor;
+ const uint channel_y = channel % p.ne12;
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
@@ -37,25 +43,66 @@ void main() {
const uint idst = channel*nrows_dst + row_dst;
- tmp[tid] = 0.0f;
+ FLOAT_TYPE temp = 0.0f;
- for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
+ // Detect alignment for vector loads
+ bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
- if (col_x >= p.ncols_x) {
- break;
- }
+ for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
+
+ // Unroll 2x and do vec4 loads if aligned
+ const uint unroll_count = 2;
+ if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
+ [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
+ const uint col_x = col_x0 + 4*tid;
+
+ const uint row_y = col_x;
+
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel_y*p.channel_stride_y + row_y;
+
+ const vec4 av4 = vec4(data_a_v4[ix / 4]);
+ const vec4 bv4 = vec4(data_b_v4[iy / 4]);
+
+ temp += dot(av4, bv4);
+
+ col_x0 += 4*BLOCK_SIZE;
+ }
+ // do vec4 loads if aligned
+ } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+ const uint col_x = col_x0 + 4*tid;
- const uint row_y = col_x;
+ const uint row_y = col_x;
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
- const uint iy = channel*nrows_y + row_y;
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel_y*p.channel_stride_y + row_y;
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+ const vec4 av4 = vec4(data_a_v4[ix / 4]);
+ const vec4 bv4 = vec4(data_b_v4[iy / 4]);
- tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+ temp += dot(av4, bv4);
+
+ col_x0 += 4*BLOCK_SIZE;
+ } else {
+ const uint col_x = col_x0 + tid;
+ if (col_x >= p.ncols_x) {
+ break;
+ }
+
+ const uint row_y = col_x;
+
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel_y*p.channel_stride_y + row_y;
+
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+ temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
+ col_x0 += BLOCK_SIZE;
+ }
}
+ tmp[tid] = temp;
+
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp
index 4b1871ca..7aa070ee 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp
@@ -2,16 +2,25 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
-#define BLOCK_SIZE 32
#define FLOAT_TYPE float
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
+layout(constant_id = 0) const int BLOCK_SIZE = 32;
+// gqa_ratio is in the range [1,8]
+layout(constant_id = 1) const uint gqa_ratio = 1;
+
layout (push_constant) uniform parameter
{
uint ncols_x;
@@ -22,52 +31,124 @@ layout (push_constant) uniform parameter
uint d_offset;
} p;
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
+#if !USE_SUBGROUP_ADD
+shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
+#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
- const uint channel = gl_GlobalInvocationID.z;
- const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
+
+ uint channel, channel_x;
+
+ // When gqa_ratio > 1, each invocation does multiple rows.
+ // The row in the A matrix is starting from channel / gqa_ratio and the
+ // rows in the B matrix are [channel, channel+gqa_ratio).
+ // When gpa_ratio is 1, each invocation does one row.
+ if (gqa_ratio > 1) {
+ channel_x = gl_GlobalInvocationID.z;
+ channel = channel_x * gqa_ratio;
+ } else {
+ channel = gl_GlobalInvocationID.z;
+ channel_x = channel / (p.nchannels_y / p.nchannels_x);;
+ }
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
- tmp[tid] = FLOAT_TYPE(0.0f);
+ FLOAT_TYPE temp[8];
+ [[unroll]] for (uint i = 0; i < 8; ++i) {
+ temp[i] = FLOAT_TYPE(0.0f);
+ }
+
+ // Detect alignment for vector loads
+ bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
- if (col_x >= p.ncols_x) {
- break;
- }
+ // Use vec4 loads if aligned
+ if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
- // x is transposed and permuted
- const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+ uint col_x = col_x0 + 4*tid;
+ const uint row_y = col_x;
- const uint row_y = col_x;
+ // x is transposed and permuted
+ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+ const vec4 av4 = vec4(data_a_v4[ix / 4]);
- // y is not transposed but permuted
- const uint iy = channel*nrows_y + row_y;
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ // y is not transposed but permuted
+ const uint iy = (channel + c)*nrows_y + row_y;
- tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
- }
+ vec4 bv4 = data_b_v4[iy / 4];
+ temp[c] += dot(av4, bv4);
+ }
+
+ col_x0 += 3*BLOCK_SIZE;
+ } else {
+ const uint col_x = col_x0 + tid;
+
+ if (col_x >= p.ncols_x) {
+ break;
+ }
- // dst is not transposed and not permuted
- const uint idst = channel*nrows_dst + row_dst;
+ // x is transposed and permuted
+ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+ const uint row_y = col_x;
+
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ // y is not transposed but permuted
+ const uint iy = (channel + c)*nrows_y + row_y;
+
+ temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
+ }
+ }
+ }
+
+#if USE_SUBGROUP_ADD
+ // reduce vec4 at a time
+ vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
+ t = subgroupAdd(t);
+ temp[0] = t[0];
+ temp[1] = t[1];
+ temp[2] = t[2];
+ temp[3] = t[3];
+ if (gqa_ratio > 4) {
+ t = vec4(temp[4], temp[5], temp[6], temp[7]);
+ t = subgroupAdd(t);
+ temp[4] = t[0];
+ temp[5] = t[1];
+ temp[6] = t[2];
+ temp[7] = t[3];
+ }
+#else
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ tmp[c][tid] = temp[c];
+ }
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
- tmp[tid] += tmp[tid + s];
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ temp[c] += tmp[c][tid + s];
+ tmp[c][tid] = temp[c];
+ }
}
barrier();
}
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ temp[c] = tmp[c][tid];
+ }
+#endif
if (tid == 0) {
- dst[idst] = tmp[0];
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ // dst is not transposed and not permuted
+ const uint idst = (channel + c)*nrows_dst + row_dst;
+ dst[idst] = temp[c];
+ }
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
index 4cd97799..423ceb8a 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -1,73 +1,130 @@
#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-shared FLOAT_TYPE tmp[32];
+shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+ const uint y_idx = i * QUANT_K + y_offset;
+
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ csel ^= 1;
+
+ if (!all_threads) { // when we don't have enough blocks to use all threads
+ if (i < num_blocks_per_row) {
+ const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
+ sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
+ sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+ }
+ barrier();
+
+ if (i >= num_blocks_per_row)
+ continue;
+ } else {
+ const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
+ sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
+ sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+ barrier();
+ }
+
+ const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+ const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+ const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+ const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+ const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
+
+ vec2 d = vec2(data_a[ib0 + i].d);
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
+ vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
+ vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+ vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+ vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+ vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+ vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+ vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
+
+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+ [[unroll]] for (int l = 0; l < 2; ++l) {
+ sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
+ fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
+ fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
+ fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
+ fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
+ fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
+ fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
+ fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
+ sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
+ fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
+ fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
+ fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
+ fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
+ fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
+ fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
+ fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
+ }
+ temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+ }
+ }
+}
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ // 16 threads are used to process each block
+ const uint it_size = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid%16; // 0...15
+ const uint ix = tid/16;
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = itid - 8*v_im; // 0...7
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
- const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+ const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
- const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0;
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
- sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
- sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
}
- tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
}
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
+ const uint nbr_par_th = num_blocks_per_row%it_size;
+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+ uint i0 = 0;
+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+ calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+ calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
}
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
+ compute_outputs(first_row, p.stride_d - first_row);
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
index a6e430ea..e91724a2 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -1,66 +1,132 @@
#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-shared FLOAT_TYPE tmp[32];
+shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+ const uint y_idx = i * QUANT_K + y_offset;
+
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ csel ^= 1;
+
+ if (!all_threads) { // when we don't have enough blocks to use all threads
+ if (i < num_blocks_per_row)
+ sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+ barrier();
+
+ if (i >= num_blocks_per_row)
+ continue;
+ }
+
+ const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
+ const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
+ const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
+ const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
+ const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
+
+ // 0, 1, 16, 17
+ uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
+ qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
+ const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+ const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+ const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+ const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
+
+ if (all_threads) {
+ sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+ barrier();
+ }
+
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
+ vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
+ vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+ vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+ vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+ vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+ vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+ vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
+
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ [[unroll]] for (int l = 0; l < 2; ++l) {
+ sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
+ fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
+ fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
+ fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
+ fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
+ fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
+ fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
+ fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
+ }
+ temp[j][n] = fma(d, sum, temp[j][n]);
+ }
+ }
+}
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+ // 16 threads are used to process each block
+ const uint it_size = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid%16; // 0...15
+ const uint ix = tid/16;
+ const uint itid8 = itid%8;
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_im4 = v_im*4;
+ const uint v_in = itid - 8*v_im; // 0...7
- const uint8_t m = uint8_t(1 << (4 * v_im));
+ const uint32_t m = 0x01010101 << (4 * v_im);
+ uint32_t hm_m[4];
+ [[unroll]] for (uint j = 0; j < 4; ++j)
+ hm_m[j] = m << j;
- const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+ const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- const uint s_shift = 4 * v_im;
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
+ const uint s_shift = v_im4 + 2*(itid8/4);
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+ const uint nbr_par_th = num_blocks_per_row%it_size;
+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+ uint i0 = 0;
+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+ calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+ calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
- sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
- }
- tmp[16 * ix + tid] += d * sum;
- }
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
}
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
+ compute_outputs(first_row, p.stride_d - first_row);
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
index 75569363..f9cde064 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -1,28 +1,106 @@
#version 450
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-shared FLOAT_TYPE tmp[32];
+#include "mul_mat_vec_base.comp"
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y1_idx = i * QUANT_K + y_offset;
+ const uint y2_idx = y1_idx + 128;
+
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ vec2 d = vec2(data_a[ib0 + i].d);
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+ const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
+ const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+ const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+
+ const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
+ const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
+ const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
+ const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
+
+ const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
+ const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
+ const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
+ const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
+ const FLOAT_TYPE sc4 = scale8_f.x;
+ const FLOAT_TYPE sc5 = scale8_f.y;
+ const FLOAT_TYPE sc6 = scale8_f.z;
+ const FLOAT_TYPE sc7 = scale8_f.w;
+
+ const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
+ const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
+
+ const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
+ const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
+ const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
+ const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
+
+ const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
+ const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
+ const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
+ const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
+
+ const FLOAT_TYPE q4_0 = qs0_lo4.x;
+ const FLOAT_TYPE q4_1 = qs0_lo4.y;
+ const FLOAT_TYPE q4_2 = qs0_lo4.z;
+ const FLOAT_TYPE q4_3 = qs0_lo4.w;
+ const FLOAT_TYPE q4_4 = qs0_hi4.x;
+ const FLOAT_TYPE q4_5 = qs0_hi4.y;
+ const FLOAT_TYPE q4_6 = qs0_hi4.z;
+ const FLOAT_TYPE q4_7 = qs0_hi4.w;
+ const FLOAT_TYPE q4_8 = qs64_lo4.x;
+ const FLOAT_TYPE q4_9 = qs64_lo4.y;
+ const FLOAT_TYPE q4_10 = qs64_lo4.z;
+ const FLOAT_TYPE q4_11 = qs64_lo4.w;
+ const FLOAT_TYPE q4_12 = qs64_hi4.x;
+ const FLOAT_TYPE q4_13 = qs64_hi4.y;
+ const FLOAT_TYPE q4_14 = qs64_hi4.z;
+ const FLOAT_TYPE q4_15 = qs64_hi4.w;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
+ vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
+ vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
+ vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
+
+ const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
+ const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
+ const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
+ const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
+ const FLOAT_TYPE smin =
+ fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
+ fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
+ fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
+ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ }
+ }
+}
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ // 16 threads are used to process each block
+ const uint it_size = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid%16; // 0...15
+ const uint ix = tid/16;
- const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
-
- const uint il = tid/step; // 0...3
- const uint ir = tid - step*il; // 0...7 or 0...3
- const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+ const uint il = itid/4; // 0...3
+ const uint ir = itid - 4*il; // 0...3
+ const uint n = 4;
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2;
@@ -31,85 +109,28 @@ void main() {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y1_idx = i * QUANT_K + y_offset;
- const uint y2_idx = y1_idx + 128;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-#if K_QUANTS_PER_ITERATION == 2
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
-
- const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
- const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
- const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
- const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
- const FLOAT_TYPE smin = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
- );
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
-#else
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
-
- const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
- const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
- const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
- const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
- const FLOAT_TYPE smin = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
- );
-
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
-#endif
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
+ calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
}
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
+ compute_outputs(first_row, p.stride_d - first_row);
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp
index 9be3645b..6c84ef3c 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -1,25 +1,137 @@
#version 450
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-shared FLOAT_TYPE tmp[32];
+#include "mul_mat_vec_base.comp"
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+ const uint y1_idx = i * QUANT_K + y_offset;
+ const uint y2_idx = y1_idx + 128;
+
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ vec2 d = vec2(data_a[ib0 + i].d);
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+ const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
+ const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+ const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+
+ const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
+ const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
+ const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
+ const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
+
+ const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
+ const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
+ const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
+ const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
+ const FLOAT_TYPE sc4 = scale8_f.x;
+ const FLOAT_TYPE sc5 = scale8_f.y;
+ const FLOAT_TYPE sc6 = scale8_f.z;
+ const FLOAT_TYPE sc7 = scale8_f.w;
+
+ const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+ const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
+
+ uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
+ uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
+ uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
+ uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
+
+ const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
+
+ const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
+ const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
+ const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
+ const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
+
+ qs0_16_u32_lo4 += qs0_16_lo4_offset16;
+ qs0_16_u32_hi4 += qs0_16_hi4_offset16;
+ qs64_80_u32_lo4 += qs64_80_lo4_offset16;
+ qs64_80_u32_hi4 += qs64_80_hi4_offset16;
+
+ const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
+ const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
+ const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
+ const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
+
+ const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
+ const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
+ const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
+ const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
+ const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
+ const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
+ const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
+ const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
+ const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
+ const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
+ const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
+ const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
+ const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
+ const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
+ const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
+ const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
+ vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
+ vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
+ vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
+ vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
+ vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
+ vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
+ vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
+
+ const FLOAT_TYPE sx =
+ fma(FLOAT_TYPE(by10.x), q4_0,
+ fma(FLOAT_TYPE(by10.y), q4_1,
+ fma(FLOAT_TYPE(by116.x), q4_2,
+ FLOAT_TYPE(by116.y) * q4_3)));
+ const FLOAT_TYPE sy =
+ fma(FLOAT_TYPE(by132.x), q4_4,
+ fma(FLOAT_TYPE(by132.y), q4_5,
+ fma(FLOAT_TYPE(by148.x), q4_6,
+ FLOAT_TYPE(by148.y) * q4_7)));
+ const FLOAT_TYPE sz =
+ fma(FLOAT_TYPE(by20.x), q4_8,
+ fma(FLOAT_TYPE(by20.y), q4_9,
+ fma(FLOAT_TYPE(by216.x), q4_10,
+ FLOAT_TYPE(by216.y) * q4_11)));
+ const FLOAT_TYPE sw =
+ fma(FLOAT_TYPE(by232.x), q4_12,
+ fma(FLOAT_TYPE(by232.y), q4_13,
+ fma(FLOAT_TYPE(by248.x), q4_14,
+ FLOAT_TYPE(by248.y) * q4_15)));
+ const FLOAT_TYPE smin =
+ fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
+ fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
+ fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
+ (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ }
+ }
+}
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
- const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16
- const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1
+ // 16 threads are used to process each block
+ const uint it_size = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid%16; // 0...15
+ const uint ix = tid/16;
- const uint il = tid/4; // 0...3
- const uint ir = tid - 4*il; // 0...7 or 0...3
+ const uint il = itid/4; // 0...3
+ const uint ir = itid - 4*il; // 0...3
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2;
@@ -28,84 +140,28 @@ void main() {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
- const uint8_t hm1 = uint8_t(1 << (2*v_im));
- const uint8_t hm2 = uint8_t(hm1 << 4);
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
- const uint y1_idx = i * QUANT_K + y_offset;
- const uint y2_idx = y1_idx + 128;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
-
- const FLOAT_TYPE sx = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx ]) * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE sy = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE sz = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y2_idx ]) * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE sw = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE smin = FLOAT_TYPE(
- (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
- + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
- );
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
+ calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
}
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
+ compute_outputs(first_row, p.stride_d - first_row);
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
index d610cf03..d53d9ee0 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -1,79 +1,130 @@
#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
#include "mul_mat_vec_base.comp"
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-shared FLOAT_TYPE tmp[32];
+shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+ const uint y_idx = i * QUANT_K + y_offset;
+
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ csel ^= 1;
+
+ if (!all_threads) { // when we don't have enough blocks to use all threads
+ if (i < num_blocks_per_row)
+ sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+ barrier();
+
+ if (i >= num_blocks_per_row)
+ continue;
+ }
+
+ const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
+ const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
+
+ const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
+ const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
+ const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
+ const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
+
+ const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
+ const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
+ const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
+ const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
+ const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
+
+ const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
+ const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
+ const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
+ const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
+
+ const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
+ const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
+ const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
+ const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
+
+ if (all_threads) {
+ sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+ barrier();
+ }
+
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
+ vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
+ vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
+ vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
+
+ FLOAT_TYPE sum[4] = {0, 0, 0, 0};
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
+ sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
+ sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
+ sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
+ sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
+ }
+ temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
+ }
+ }
+}
+void compute_outputs(const uint first_row, const uint num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ // 16 threads are used to process each block
+ const uint it_size = gl_WorkGroupSize.x/16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint itid = tid%16; // 0...15
+ const uint ix = tid/16;
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = itid - 8*v_im; // 0...7
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
-#if K_QUANTS_PER_ITERATION == 1
- const uint l0 = v_in; // 0...15
- const uint is = 0;
-#else
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
const uint is = v_in / 4;
-#endif
const uint ql_offset = 64*v_im + l0;
const uint qh_offset = 32*v_im + l0;
const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0;
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
+ }
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
+ const uint nbr_par_th = num_blocks_per_row%it_size;
+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+ uint i0 = 0;
+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+ calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+ calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
-#if K_QUANTS_PER_ITERATION == 1
- FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
- tmp[16 * ix + tid] += sum;
-#else
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- [[unroll]] for (int l = 0; l < 4; ++l) {
- sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
}
- tmp[16 * ix + tid] += sum;
-#endif
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
+ compute_outputs(first_row, p.stride_d - first_row);
}
}
diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp
index 5fe9d524..26163b16 100644
--- a/ggml/src/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/vulkan-shaders/mul_mm.comp
@@ -6,6 +6,19 @@
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
+#if defined(DATA_A_IQ1_M)
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#endif
+
+#if defined(DATA_A_BF16) && defined(COOPMAT)
+#extension GL_EXT_bfloat16 : enable
+#endif
+
+#ifdef COOPMAT
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
@@ -20,9 +33,20 @@
#define LOAD_VEC_B 1
#endif
+#if !defined(TO_FLOAT_TYPE)
+#define TO_FLOAT_TYPE FLOAT_TYPE
+#endif
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
@@ -57,6 +81,7 @@ layout (push_constant) uniform parameter
#endif
} p;
+layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@@ -65,16 +90,33 @@ layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
-layout (constant_id = 9) const uint WARP = 32;
+layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
+layout (constant_id = 10) const uint WARP = 32;
-shared FLOAT_TYPE buf_a[BM * (BK+1)];
-shared FLOAT_TYPE buf_b[BN * (BK+1)];
+#ifdef COOPMAT
+#define SHMEM_STRIDE (BK + 8)
+#else
+#define SHMEM_STRIDE (BK + 1)
+#endif
+
+shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
+shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[3072];
+shared u16vec2 row_ids[4096];
+#endif // MUL_MAT_ID
+
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
+#ifdef COOPMAT
+shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
@@ -94,17 +136,32 @@ void main() {
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
- const uint warp_i = gl_LocalInvocationID.x / WARP;
- const uint warp_r = warp_i % (BM / WM);
- const uint warp_c = warp_i / (BM / WM);
-
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
+#ifdef COOPMAT
+ const uint warp_i = gl_SubgroupID;
+
+ const uint tiw = gl_SubgroupInvocationID;
+
+ const uint cms_per_row = WM / TM;
+ const uint cms_per_col = WN / TN;
+
+ const uint storestride = WARP / TM;
+ const uint store_r = tiw % TM;
+ const uint store_c = tiw / TM;
+#else
+ const uint warp_i = gl_LocalInvocationID.x / WARP;
+
const uint tiw = gl_LocalInvocationID.x % WARP;
+
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
+#endif
+
+ const uint warp_r = warp_i % (BM / WM);
+ const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@@ -152,21 +209,31 @@ void main() {
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif
- float sums[WMITER * TM * WNITER * TN];
+#ifdef COOPMAT
+ coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
+ coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
+
+ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
+ sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
+ }
+#else
+ ACC_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
- FLOAT_TYPE cache_b[WNITER * TN];
+ FLOAT_TYPE cache_b[TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
- sums[i] = 0.0f;
+ sums[i] = ACC_TYPE(0.0f);
}
+#endif
- [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
+ for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@@ -177,91 +244,132 @@ void main() {
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+ } else {
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
+ }
+#endif
+#elif defined(DATA_A_BF16)
+#if LOAD_VEC_A == 4
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+ buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
+ buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
+ buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
+ buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
+#else
+ if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+ const uint ib = idx / 4;
+ const uint iqs = idx & 0x03;
+
+ const float d = float(data_a_packed16[ib].d);
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+ const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
+ const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+ buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+ buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+ buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+ buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+ const uint ib = idx / 4;
+ const uint iqs = idx & 0x03;
+
+ const float d = float(data_a_packed16[ib].d);
+ const float m = float(data_a_packed16[ib].m);
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+ const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
+ const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+ buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+ buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+ buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+ buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+ const float d = float(data_a_packed16[ib].d);
+ const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+ const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint uint_qh = data_a[ib].qh;
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+ const float d = float(data_a_packed16[ib].d);
+ const float m = float(data_a_packed16[ib].m);
+ const uint uint_qh = data_a_packed16[ib].qh;
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+ const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 16;
- const uint iqs = (idx & 0xF) * 2;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+ const float d = float(data_a_packed16[ib].d);
+ const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
+ const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
+ const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -280,7 +388,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -294,17 +402,15 @@ void main() {
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
- const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
- is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
- is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
- (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
+ const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
+ | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[ib].d) * float(us - 32);
buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -316,23 +422,28 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d);
- uint8_t sc;
- uint8_t mbyte;
- if (is < 4) {
- sc = uint8_t(data_a[ib].scales[is ] & 63);
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
- }
+ const uint scidx0 = (is < 4) ? is : (is + 4);
+ const uint scidx1 = (is < 4) ? is : (is - 4);
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
+ const uint mbidx0 = is + 4;
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
const float d = loadd.x * sc;
- const float m = loadd.y * mbyte;
+ const float m = -loadd.y * mbyte;
- buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
- buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);
+ buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
+ buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -347,23 +458,28 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d);
- uint8_t sc;
- uint8_t mbyte;
- if (is < 4) {
- sc = uint8_t(data_a[ib].scales[is ] & 63);
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
- }
+ const uint scidx0 = (is < 4) ? is : (is + 4);
+ const uint scidx1 = (is < 4) ? is : (is - 4);
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
+ const uint mbidx0 = is + 4;
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
const float d = loadd.x * sc;
- const float m = loadd.y * mbyte;
+ const float m = -loadd.y * mbyte;
- buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
- buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);
+ buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
+ buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -380,19 +496,201 @@ void main() {
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
-#elif defined(DATA_A_IQ4_NL)
+#elif defined(DATA_A_IQ1_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
+ const uint ib = idx / 128; // 2 values per idx
+ const uint ib32 = (idx % 128) / 16; // 0..7
+ const uint ib8 = (idx % 128) / 4;
+ const int i8 = 2 * int(idx % 4);
const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
+ const uint qh = data_a[ib].qh[ib32];
+ const uint qs = data_a[ib].qs[ib8];
+ const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+ const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
+
+ const ivec2 gvec = ivec2(
+ bitfieldExtract(grid, 2 * (i8), 2),
+ bitfieldExtract(grid, 2 * (i8 + 1), 2)
+ );
+ const vec2 v = dl * (vec2(gvec) + delta);
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ1_M)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint ib8 = (idx % 128) / 4;
+ const uint ib16 = ib8 / 2;
+ const int i8 = 2 * int(idx % 4);
+
+ const uint16_t[4] scales = data_a[ib].scales;
+ const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
+ const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
+ const uint sc = scales[ib8 / 8];
+ const uint qs = data_a[ib].qs[ib8];
+ const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
+ const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
+ const ivec2 gvec = ivec2(
+ bitfieldExtract(grid, 2 * (i8), 2),
+ bitfieldExtract(grid, 2 * (i8 + 1), 2)
+ );
+ const vec2 v = dl * (vec2(gvec) + delta);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ2_XXS)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint ib32 = (idx % 128) / 16; // 0..7
+ const uint ib8 = (idx / 4) % 4;
+
+ const float d = float(data_a[ib].d);
+ const uint qs = data_a[ib].qs[8 * ib32 + ib8];
+ const uint signs = pack32(u8vec4(
+ data_a[ib].qs[8*ib32 + 4],
+ data_a[ib].qs[8*ib32 + 5],
+ data_a[ib].qs[8*ib32 + 6],
+ data_a[ib].qs[8*ib32 + 7]
+ ));
+ const float db = d * 0.25 * (0.5 + (signs >> 28));
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
+ const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ2_XS)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint ib32 = (idx % 128) / 16; // 0..7
+ const uint ib8 = (idx / 4) % 4; // 0..3
+
+ const float d = float(data_a[ib].d);
+ const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
+ const float db = d * 0.25 * (0.5 + scale);
+ const uint qs = data_a[ib].qs[4 * ib32 + ib8];
+ const uint sign7 = qs >> 9;
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
+ const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ2_S)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint ib8 = (idx % 128) / 4; // 0..31
+ const uint ib32 = ib8 / 4; // 0..7
+
+ const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
+ const uint qs = data_a[ib].qs[ib8];
+ const uint qh = data_a[ib].qh[ib32];
+ const uint qhshift = 2 * (ib8 % 4);
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
+
+ const float d = float(data_a[ib].d);
+ const float db = d * 0.25 * (0.5 + scale);
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
+ const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ3_XXS)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = (idx % 128) / 2; // 0..63
+ const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
+
+ const float d = float(data_a[ib].d);
+ const uint qs = data_a[ib].qs[iqs];
+ const uint signs = pack32(u8vec4(
+ data_a[ib].qs[is+0],
+ data_a[ib].qs[is+1],
+ data_a[ib].qs[is+2],
+ data_a[ib].qs[is+3]
+ ));
+ const float db = d * 0.5 * (0.5 + (signs >> 28));
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
+ const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ3_S)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = (idx % 128) / 2; // 0..63
+ const uint iqh = iqs / 8;
+
+ const float d = float(data_a[ib].d);
+ const uint qs = data_a[ib].qs[iqs];
+ const uint qh = data_a[ib].qh[iqh];
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
+ const uint scale = data_a[ib].scales[iqs / 16];
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
+ const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
+ const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ4_XS)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint ib32 = (idx % 128) / 16; // 0..7
+ const uint iq = 16 * ib32 + 2 * (idx % 8);
+
+ const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+ const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
+ const uint qshift = (idx & 8) >> 1;
+ u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
+ qs = (qs >> qshift) & uint8_t(0xF);
+
+ const float d = float(data_a[ib].d);
+ const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ4_NL)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
+
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
+
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
+ buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
+ buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
@@ -403,7 +701,7 @@ void main() {
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@@ -419,24 +717,24 @@ void main() {
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
+ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
+ buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
+ buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
+ buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
+ buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i];
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
#endif
}
@@ -446,29 +744,43 @@ void main() {
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
- for (uint i = 0; i < BK; i++) {
+#ifdef COOPMAT
+ [[unroll]] for (uint i = 0; i < BK; i += TK) {
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+ // Load from shared into cache
+ coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
+
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
+
+ sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
+ }
+ }
+ }
+#else
+ [[unroll]] for (uint i = 0; i < BK; i++) {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) {
- cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
+ cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) {
- cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
+ cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
}
- }
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
- sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
}
}
}
}
}
+#endif
barrier();
}
@@ -480,6 +792,54 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
+#ifdef COOPMAT
+#ifdef MUL_MAT_ID
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+ const uint row_i = dc + cm_col * TN + col + store_c;
+ if (row_i >= _ne1) break;
+
+ const u16vec2 row_idx = row_ids[row_i];
+
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+ }
+ }
+ }
+#else
+ const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
+
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+ const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
+
+ if (is_aligned && is_in_bounds) {
+ // Full coopMat is within bounds and stride_d is aligned with 16B
+ coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
+ coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+ } else if (is_in_bounds) {
+ // Full coopMat is within bounds, but stride_d is not aligned
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+ }
+ } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
+ // Partial coopMat is within bounds
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+ if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+ }
+ }
+ }
+ }
+ }
+#endif // MUL_MAT_ID
+#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -491,7 +851,7 @@ void main() {
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
-#endif
+#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@@ -499,9 +859,10 @@ void main() {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
-#endif
+#endif // MUL_MAT_ID
}
}
}
}
+#endif // COOPMAT
}
diff --git a/ggml/src/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/vulkan-shaders/mul_mm_cm2.comp
new file mode 100644
index 00000000..91846575
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mm_cm2.comp
@@ -0,0 +1,441 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_EXT_buffer_reference : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_vote : enable
+#ifdef DATA_A_BF16
+#extension GL_EXT_bfloat16 : enable
+#endif
+
+#include "types.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+#define IS_MUL_MM2 1
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 256;
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
+
+layout (constant_id = 4) const bool enable_smaller_matrices = false;
+const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
+const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint N;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint nei1;
+ uint nbi1;
+ uint ne11;
+#else
+ uint k_split;
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+ // N dimension for the B matrix can be >= p.N
+ uint padded_N;
+} p;
+
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#if QUANT_K > 1
+#define DECODEFUNCA , dequantFuncA
+
+#include "dequant_funcs_cm2.comp"
+
+#else
+#define DECODEFUNCA
+#endif
+
+#if !defined(fetch_scales)
+#define fetch_scales(a, b, c, d, e, f)
+#endif
+#if !defined(store_scales)
+#define store_scales(a)
+#endif
+
+#if defined(DATA_A_BF16)
+#define MAT_TYPE bfloat16_t
+#else
+#define MAT_TYPE FLOAT_TYPE
+#endif
+
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+
+shared u16vec4 row_ids[4096];
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
+ B_TYPE b[];
+};
+
+uint _ne1;
+shared uint _ne1_sh;
+
+B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const uint row_i = blockCoords[0];
+
+ if (row_i >= _ne1) {
+ return B_TYPE(0.0);
+ }
+
+ const u16vec4 row_idx = row_ids[row_i];
+ B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
+
+ return ret;
+}
+
+D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
+{
+ uint dr = ir * BM + r;
+ uint dc = ic * BN + c;
+
+ if (dr < p.M && dc < _ne1) {
+ uint row_i = dc;
+ const u16vec4 row_idx = row_ids[row_i];
+ data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
+ }
+ return elem;
+}
+
+#endif
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ const uint tid = gl_LocalInvocationIndex;
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+#else
+ const uint batch_idx = gl_GlobalInvocationID.z;
+
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
+ const uint blocks_m = (p.M + BM - 1) / BM;
+ const uint ir = gl_WorkGroupID.x % blocks_m;
+ const uint ik = gl_WorkGroupID.x / blocks_m;
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ // Spread the search across all elements in the first subgroup
+ if (gl_SubgroupID == 0) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+
+ for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
+ bool in_range = i < num_elements;
+ uint ii0 = i % p.nei0;
+ uint ii1 = i / p.nei0;
+ uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+ uint idx = subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx) {
+ row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
+ }
+ _ne1 += subgroupBallotBitCount(ballot);
+ }
+ _ne1_sh = _ne1;
+ }
+
+ barrier();
+
+ _ne1 = _ne1_sh;
+
+ // Workgroup has no work
+ if (ic * BN >= _ne1) return;
+#endif
+
+#ifdef MUL_MAT_ID
+ uint start_k = 0;
+ const uint end_k = p.K;
+#else
+ uint start_k = ik * p.k_split;
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
+#ifdef MUL_MAT_ID
+ uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
+ uint pos_b = 0;
+#else
+ uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
+ uint pos_b = batch_idx * p.batch_stride_b;
+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+#endif
+
+ uint stride_a = p.stride_a / QUANT_K;
+ uint stride_b = p.stride_b;
+
+ // Hint to the compiler that values are aligned (want 16B alignment).
+ // Quants are always block-aligned, no alignment needed.
+#if ALIGNED
+#if QUANT_K == 1
+ stride_a &= ~7;
+#endif
+ stride_b &= ~7;
+#endif
+
+ // Create layouts for both clamped and unclamped accesses
+ tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
+
+#if QUANT_K > 1
+ tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
+ tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
+#endif
+
+ // Use end_k rather than p.K as the dimension because that's what
+ // we need to bound check against when using split_k.
+ // Bounds check B against padded_N, but bounds check D against N.
+ tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
+ tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
+ tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
+ tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
+
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
+
+#if !defined(MUL_MAT_ID)
+
+ const uint START_ALIGN_K = 256;
+ // For Qi_K (block size 256), unroll whole 256 element tiles.
+ // For legacy quants (block size 32), unroll 8x.
+ const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
+ const uint unroll_count = UNROLL_K / BK;
+
+ // Detect a fast path where all loads are entirely in bounds and no clamping is required
+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
+#if QUANT_K == 1
+ (stride_a % 8) == 0 &&
+#endif
+ (stride_b % 8) == 0) {
+ // Hint to the compiler that values are aligned (want 16B alignment)
+ start_k &= ~(START_ALIGN_K-1);
+ stride_b &= ~7;
+#if QUANT_K == 1
+ stride_a &= ~7;
+#endif
+
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+ uint k_iters = (end_k - start_k) / UNROLL_K;
+ uint block_k = start_k;
+
+ // fetch scale values for a tile of quants. These will be copied into shared memory.
+ // The fetches and stores are pipelined to hide the latency.
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
+
+ if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
+ return;
+ } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
+ return;
+ } else {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+ return;
+ }
+ } else
+#endif // !defined(MUL_MAT_ID)
+ {
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+
+ tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
+
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+ tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
+ uint k_iters = (end_k - start_k + BK - 1) / BK;
+
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ store_scales(tid);
+ if (block_k + BK < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+#endif
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+#ifdef MUL_MAT_ID
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+#else
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+#endif
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mmq.comp b/ggml/src/vulkan-shaders/mul_mmq.comp
new file mode 100644
index 00000000..83de90eb
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mmq.comp
@@ -0,0 +1,442 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#extension GL_EXT_integer_dot_product : require
+
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#endif
+
+#ifdef COOPMAT
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
+
+#ifdef MUL_MAT_ID
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#endif
+
+#include "types.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+#endif
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint N;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint nei1;
+ uint nbi1;
+ uint ne11;
+#else
+ uint k_split;
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+} p;
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 64;
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+// layout (constant_id = 3) const uint BK = 32;
+layout (constant_id = 4) const uint WM = 32;
+layout (constant_id = 5) const uint WN = 32;
+layout (constant_id = 6) const uint WMITER = 2;
+layout (constant_id = 7) const uint TM = 4;
+layout (constant_id = 8) const uint TN = 2;
+layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
+layout (constant_id = 10) const uint WARP = 32;
+
+#define BK 32
+
+#ifdef COOPMAT
+#define SHMEM_STRIDE (BK / 4 + 4)
+#else
+#define SHMEM_STRIDE (BK / 4 + 1)
+#endif
+
+shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+
+#ifndef COOPMAT
+#if QUANT_AUXF == 1
+shared FLOAT_TYPE buf_a_dm[BM];
+#else
+shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
+#endif
+#endif
+
+shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
+#ifndef COOPMAT
+shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
+#endif
+
+#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_B 4
+
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[4096];
+#endif // MUL_MAT_ID
+
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
+#ifdef COOPMAT
+shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
+#endif
+
+#include "mul_mmq_funcs.comp"
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+#else
+ const uint batch_idx = gl_GlobalInvocationID.z;
+
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
+ const uint blocks_m = (p.M + BM - 1) / BM;
+ const uint ir = gl_WorkGroupID.x % blocks_m;
+ const uint ik = gl_WorkGroupID.x / blocks_m;
+ const uint ic = gl_WorkGroupID.y;
+
+ const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+ const uint WSUBM = WM / WMITER;
+ const uint WSUBN = WN / WNITER;
+
+#ifdef COOPMAT
+ const uint warp_i = gl_SubgroupID;
+
+ const uint tiw = gl_SubgroupInvocationID;
+
+ const uint cms_per_row = WM / TM;
+ const uint cms_per_col = WN / TN;
+
+ const uint storestride = WARP / TM;
+ const uint store_r = tiw % TM;
+ const uint store_c = tiw / TM;
+#else
+ const uint warp_i = gl_LocalInvocationID.x / WARP;
+
+ const uint tiw = gl_LocalInvocationID.x % WARP;
+
+ const uint tiwr = tiw % (WSUBM / TM);
+ const uint tiwc = tiw / (WSUBM / TM);
+#endif
+
+ const uint warp_r = warp_i % (BM / WM);
+ const uint warp_c = warp_i / (BM / WM);
+
+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+
+ const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
+ const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
+
+#ifdef MUL_MAT_ID
+ uint _ne1 = 0;
+ for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
+ row_ids[_ne1] = u16vec2(ii0, ii1);
+ _ne1++;
+ }
+ }
+ }
+
+ barrier();
+
+ // Workgroup has no work
+ if (ic * BN >= _ne1) return;
+#endif
+
+#ifdef MUL_MAT_ID
+ const uint start_k = 0;
+ const uint end_k = p.K;
+#else
+ const uint start_k = ik * p.k_split;
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
+ uint pos_a_ib = (
+#ifdef MUL_MAT_ID
+ expert_idx * p.batch_stride_a +
+#else
+ batch_idx_a * p.batch_stride_a +
+#endif
+ ir * BM * p.stride_a + start_k) / BK;
+#ifdef MUL_MAT_ID
+ uint pos_b_ib = 0;
+#else
+ uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
+#endif
+
+#ifdef COOPMAT
+ coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
+ coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
+ coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
+
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
+
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
+
+ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
+ sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
+ }
+#else
+ int32_t cache_a_qs[WMITER * TM * BK / 4];
+
+ int32_t cache_b_qs[TN * BK / 4];
+
+ ACC_TYPE sums[WMITER * TM * WNITER * TN];
+
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+ sums[i] = ACC_TYPE(0.0f);
+ }
+#endif
+
+#if QUANT_AUXF == 1
+ FLOAT_TYPE cache_a_dm[WMITER * TM];
+#else
+ FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
+#endif
+
+ FLOAT_TYPE_VEC2 cache_b_ds[TN];
+
+ for (uint block = start_k; block < end_k; block += BK) {
+ [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
+ const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
+ const uint iqs = loadr_a;
+ const uint buf_ib = loadc_a + l;
+
+ if (iqs == 0) {
+#if QUANT_AUXF == 1
+ buf_a_dm[buf_ib] = get_d(ib);
+#else
+ buf_a_dm[buf_ib] = get_dm(ib);
+#endif
+ }
+#if QUANT_R == 1
+ buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
+#else
+ const i32vec2 vals = repack(ib, iqs);
+ buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
+ buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
+#endif
+ }
+ [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+#ifdef MUL_MAT_ID
+ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x7;
+#else
+ const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
+ const uint iqs = loadr_b;
+#endif
+
+ const uint buf_ib = loadc_b + l;
+
+ if (iqs == 0) {
+ buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
+ }
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
+ }
+
+ barrier();
+
+ pos_a_ib += 1;
+ pos_b_ib += 1;
+
+#ifdef COOPMAT
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+ const uint ib_a = warp_r * WM + cm_row * TM;
+ // Load from shared into cache
+ coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
+
+ // TODO: only cache values that are actually needed
+ [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
+ cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
+ }
+
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+ const uint ib_b = warp_c * WN + cm_col * TN;
+ coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
+
+ // TODO: only cache values that are actually needed
+ [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
+ cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
+ }
+
+ cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
+ cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
+
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+ coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
+ }
+
+ coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+ sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
+ }
+ }
+#else
+ // Load from shared into cache
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+ cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
+ cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
+ }
+ }
+ }
+
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+ const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+ cache_b_ds[cc] = buf_b_ds[ib];
+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
+ cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+ }
+ }
+
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint cache_a_idx = wsir * TM + cr;
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+ int32_t q_sum = 0;
+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
+ q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
+ cache_b_qs[cc * (BK / 4) + idx_k]);
+ }
+
+ sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
+ }
+ }
+ }
+ }
+#endif
+
+ barrier();
+ }
+
+ const uint dr = ir * BM + warp_r * WM;
+ const uint dc = ic * BN + warp_c * WN;
+
+#ifndef MUL_MAT_ID
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+#endif
+
+#ifdef COOPMAT
+#ifdef MUL_MAT_ID
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+ [[unroll]] for (uint col = 0; col < BN; col += storestride) {
+ const uint row_i = dc + cm_col * TN + col + store_c;
+ if (row_i >= _ne1) break;
+
+ const u16vec2 row_idx = row_ids[row_i];
+
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+ }
+ }
+ }
+#else
+ const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
+
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+ const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
+
+ if (is_aligned && is_in_bounds) {
+ // Full coopMat is within bounds and stride_d is aligned with 16B
+ coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
+ coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+ } else if (is_in_bounds) {
+ // Full coopMat is within bounds, but stride_d is not aligned
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+ }
+ } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
+ // Partial coopMat is within bounds
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+ if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+ }
+ }
+ }
+ }
+ }
+#endif // MUL_MAT_ID
+#else
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+
+ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
+ const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+#ifdef MUL_MAT_ID
+ const uint row_i = dc_warp + cc;
+ if (row_i >= _ne1) break;
+
+ const u16vec2 row_idx = row_ids[row_i];
+#endif // MUL_MAT_ID
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+#ifdef MUL_MAT_ID
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+#else
+ if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ }
+#endif // MUL_MAT_ID
+ }
+ }
+ }
+ }
+#endif // COOPMAT
+}
diff --git a/ggml/src/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/vulkan-shaders/mul_mmq_funcs.comp
new file mode 100644
index 00000000..63b15471
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mmq_funcs.comp
@@ -0,0 +1,99 @@
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#include "types.comp"
+
+// Each iqs value maps to a 32-bit integer
+
+#if defined(DATA_A_Q4_0)
+i32vec2 repack(uint ib, uint iqs) {
+ // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
+ const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
+ data_a[ib].qs[iqs * 2 + 1]);
+ const uint32_t vui = pack32(quants);
+ return i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
+}
+#endif
+
+#if defined(DATA_A_Q4_1)
+i32vec2 repack(uint ib, uint iqs) {
+ // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
+ return i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
+}
+#endif
+
+#if defined(DATA_A_Q5_0)
+i32vec2 repack(uint ib, uint iqs) {
+ // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
+ const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
+ data_a[ib].qs[iqs * 2 + 1]);
+ const uint32_t vui = pack32(quants);
+ const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+ const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+
+ const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+ return i32vec2(v0, v1);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
+}
+#endif
+
+#if defined(DATA_A_Q5_1)
+i32vec2 repack(uint ib, uint iqs) {
+ // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
+ const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+ const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+
+ const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+ return i32vec2(v0, v1);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+int32_t repack(uint ib, uint iqs) {
+ // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
+ return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
+ data_a[ib].qs[iqs * 2 + 1]));
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
+ return ACC_TYPE(float(q_sum) * da * dsb.x);
+}
+#endif
+
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+FLOAT_TYPE get_d(uint ib) {
+ return FLOAT_TYPE(data_a[ib].d);
+}
+#endif
+
+#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+ return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+}
+#endif
diff --git a/ggml/src/vulkan-shaders/opt_step_adamw.comp b/ggml/src/vulkan-shaders/opt_step_adamw.comp
new file mode 100644
index 00000000..e0214fe7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/opt_step_adamw.comp
@@ -0,0 +1,42 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE x[];};
+layout (binding = 1) readonly buffer G {A_TYPE grad[];};
+layout (binding = 2) buffer GM {A_TYPE gradm[];};
+layout (binding = 3) buffer GV {A_TYPE gradv[];};
+layout (binding = 4) readonly buffer P {float params[7];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float alpha = params[0];
+ const float beta1 = params[1];
+ const float beta2 = params[2];
+ const float eps = params[3];
+ const float wd = params[4];
+ const float beta1h = params[5];
+ const float beta2h = params[6];
+
+ const float gi = grad[i];
+ const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
+ const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
+
+ gradm[i] = gmi;
+ gradv[i] = gvi;
+
+ const float mh = gmi*beta1h;
+ const float vh = sqrt(gvi*beta2h) + eps;
+
+ x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+}
diff --git a/ggml/src/vulkan-shaders/pad.comp b/ggml/src/vulkan-shaders/pad.comp
index a465cd52..450b67fc 100644
--- a/ggml/src/vulkan-shaders/pad.comp
+++ b/ggml/src/vulkan-shaders/pad.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_unary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
@@ -22,5 +24,5 @@ void main() {
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
- data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
+ data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}
diff --git a/ggml/src/vulkan-shaders/pool2d.comp b/ggml/src/vulkan-shaders/pool2d.comp
new file mode 100644
index 00000000..b6124411
--- /dev/null
+++ b/ggml/src/vulkan-shaders/pool2d.comp
@@ -0,0 +1,74 @@
+#version 450
+
+#include "types.comp"
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(push_constant) uniform parameter {
+ uint IW; uint IH;
+ uint OW; uint OH;
+ uint OC;
+ uint pelements;
+ uint op;
+ int k0; int k1;
+ int s0; int s1;
+ int p0; int p1;
+} p;
+
+#define BLOCK_SIZE 512
+#define FLT_MAX 3.402823466e+38F
+#define OP_POOL_MAX 0u
+#define OP_POOL_AVG 1u
+
+layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout(binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.x;
+ if (idx >= p.pelements) {
+ return;
+ }
+
+ const uint O_HW = p.OW * p.OH;
+
+ const uint nc = idx / O_HW;
+ const uint cur_oh = (idx % O_HW) / p.OW;
+ const uint cur_ow = (idx % O_HW) % p.OW;
+
+ const int start_h = int(cur_oh) * p.s0 - p.p0;
+ const uint bh = max(start_h, 0);
+ const uint eh = min(start_h + p.k0, p.IH);
+
+ const int start_w = int(cur_ow) * p.s1 - p.p1;
+ const uint bw = max(start_w, 0);
+ const uint ew = min(start_w + p.k1, p.IW);
+
+ const float scale = 1.0 / float(p.k0 * p.k1);
+ float res;
+
+ if (p.op == OP_POOL_AVG) {
+ res = 0.0;
+ } else if (p.op == OP_POOL_MAX) {
+ res = -FLT_MAX;
+ } else {
+ return;
+ }
+
+ #pragma unroll
+ for (uint i = bh; i < eh; i++) {
+ #pragma unroll
+ for (uint j = bw; j < ew; j++) {
+ const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);
+
+ if (p.op == OP_POOL_AVG) {
+ res += cur * scale;
+ } else if (p.op == OP_POOL_MAX) {
+ res = max(res, cur);
+ }
+ }
+ }
+
+ data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;
+}
diff --git a/ggml/src/vulkan-shaders/quantize_q8_1.comp b/ggml/src/vulkan-shaders/quantize_q8_1.comp
new file mode 100644
index 00000000..e2e020fe
--- /dev/null
+++ b/ggml/src/vulkan-shaders/quantize_q8_1.comp
@@ -0,0 +1,77 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint ne;
+} p;
+
+#include "types.comp"
+
+layout(constant_id = 0) const uint GROUP_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {vec4 data_a[];};
+layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
+
+shared float shmem[GROUP_SIZE];
+
+void quantize() {
+ const uint wgid = gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ // Each thread handles a vec4, so 8 threads handle a block
+ const uint blocks_per_group = GROUP_SIZE / 8;
+
+ const uint block_in_wg = tid / 8;
+
+ const uint ib = wgid * blocks_per_group + block_in_wg;
+ const uint iqs = tid % 8;
+
+ if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
+ return;
+ }
+
+ const uint a_idx = ib * 8 + iqs;
+
+ vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
+ const vec4 abs_vals = abs(vals);
+
+ // Find absolute max for each block
+ shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
+ barrier();
+ [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
+ if (iqs < s) {
+ shmem[tid] = max(shmem[tid], shmem[tid + s]);
+ }
+ barrier();
+ }
+
+ const float amax = shmem[block_in_wg * 8];
+ const float d = amax / 127.0;
+ const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
+ vals = round(vals * d_inv);
+ data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
+ barrier();
+
+ // Calculate the sum for each block
+ shmem[tid] = vals.x + vals.y + vals.z + vals.w;
+ barrier();
+ [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
+ if (iqs < s) {
+ shmem[tid] += shmem[tid + s];
+ }
+ barrier();
+ }
+ if (iqs == 0) {
+ const float sum = shmem[tid];
+
+ data_b[ib].ds = f16vec2(vec2(d, sum * d));
+ }
+}
+
+void main() {
+ quantize();
+}
diff --git a/ggml/src/vulkan-shaders/relu.comp b/ggml/src/vulkan-shaders/relu.comp
index 52a19b62..4f806270 100644
--- a/ggml/src/vulkan-shaders/relu.comp
+++ b/ggml/src/vulkan-shaders/relu.comp
@@ -17,5 +17,5 @@ void main() {
return;
}
- data_d[i] = max(float(data_a[i]), 0);
+ data_d[i] = D_TYPE(max(float(data_a[i]), 0));
}
diff --git a/ggml/src/vulkan-shaders/repeat.comp b/ggml/src/vulkan-shaders/repeat.comp
new file mode 100644
index 00000000..1568b141
--- /dev/null
+++ b/ggml/src/vulkan-shaders/repeat.comp
@@ -0,0 +1,26 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+uint src0_idx_mod(uint idx) {
+ const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+ const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
+ const uint i12_offset = i12*p.ne11*p.ne10;
+ const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+ return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;
+}
+
+void main() {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]);
+}
diff --git a/ggml/src/vulkan-shaders/repeat_back.comp b/ggml/src/vulkan-shaders/repeat_back.comp
new file mode 100644
index 00000000..d8627993
--- /dev/null
+++ b/ggml/src/vulkan-shaders/repeat_back.comp
@@ -0,0 +1,37 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ // Destination multi-index (inlined dst_idx)
+ const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+ const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+ const uint i12_offset = i12*p.ne11*p.ne10;
+ const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+ const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
+
+ // Accumulate from sources
+ A_TYPE acc = A_TYPE(0);
+ for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
+ for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
+ for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
+ for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
+ acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
+ }
+ }
+ }
+ }
+
+ data_d[get_doffset() + d_idx] = D_TYPE(acc);
+}
diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/vulkan-shaders/rms_norm.comp
index b554400b..deb8ee99 100644
--- a/ggml/src/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/vulkan-shaders/rms_norm.comp
@@ -1,6 +1,6 @@
#version 450
-#include "generic_head.comp"
+#include "generic_unary_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,29 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint tid = gl_LocalInvocationID.x;
+ const uint ncols = p.ne00;
+ const uint nrows = gl_NumWorkGroups.x;
+ const uint nchannels = gl_NumWorkGroups.y;
+
+ const uint row = gl_WorkGroupID.x;
+ const uint channel = gl_WorkGroupID.y;
+ const uint samp = gl_WorkGroupID.z;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint stride_row = p.nb01;
+ const uint stride_channel = p.nb02;
+ const uint stride_sample = p.nb03;
+
+ uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+ uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
sum[tid] += xi * xi;
}
@@ -33,10 +43,10 @@ void main() {
barrier();
}
- const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+ const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
diff --git a/ggml/src/vulkan-shaders/rms_norm_back.comp b/ggml/src/vulkan-shaders/rms_norm_back.comp
new file mode 100644
index 00000000..76009f3d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rms_norm_back.comp
@@ -0,0 +1,55 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer G {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum_xx[BLOCK_SIZE];
+shared FLOAT_TYPE sum_xg[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5
+
+ // partial sums for thread in warp
+ sum_xx[tid] = FLOAT_TYPE(0.0f);
+ sum_xg[tid] = FLOAT_TYPE(0.0f);
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);
+ sum_xx[tid] += xi * xi;
+ sum_xg[tid] += xi * gi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum_xx[tid] += sum_xx[tid + s];
+ sum_xg[tid] += sum_xg[tid + s];
+ }
+ barrier();
+ }
+
+ const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
+ const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);
+ const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
+ const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ data_d[row*p.KX + col] = D_TYPE(
+ scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +
+ scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/rope_head.comp b/ggml/src/vulkan-shaders/rope_head.comp
index ea895422..96c9c4cb 100644
--- a/ggml/src/vulkan-shaders/rope_head.comp
+++ b/ggml/src/vulkan-shaders/rope_head.comp
@@ -1,6 +1,11 @@
#include "types.comp"
#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_spirv_intrinsics: enable
+
+#if RTE16
+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
+#endif
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
@@ -20,6 +25,11 @@ layout (push_constant) uniform parameter {
float corr_dims[2];
float theta_scale;
uint has_ff;
+ uint ne02;
+ uint s1;
+ uint s2;
+ int sections[4];
+ uint is_back;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -39,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
}
+ // Backprogagation uses inverted rotation
+ if (p.is_back != 0) {
+ theta = -theta;
+ }
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}
diff --git a/ggml/src/vulkan-shaders/rope_multi.comp b/ggml/src/vulkan-shaders/rope_multi.comp
new file mode 100644
index 00000000..4f5b1a0e
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rope_multi.comp
@@ -0,0 +1,60 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+ const uint i0 = 2*gl_GlobalInvocationID.y;
+ uint ne0 = p.ncols;
+ uint ne1 = p.p_delta_rows;
+ uint ne2 = p.ne02;
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const uint row_dst = gl_GlobalInvocationID.x;
+
+ if (i0 >= p.n_dims) {
+ const uint i = row_dst*ne0 + i0;
+
+ data_d[i + 0] = data_a[i + 0];
+ data_d[i + 1] = data_a[i + 1];
+
+ return;
+ }
+
+ const uint row_x = row_dst % ne1;
+ const uint channel_x = row_dst / ne1;
+
+ const uint idst = row_dst*ne0 + i0/2;
+ const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
+
+ const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
+ const int sec_w = p.sections[1] + p.sections[0];
+ const uint sector = (i0 / 2) % sect_dims;
+
+ float theta_base = 0.0;
+ if (sector < p.sections[0]) {
+ theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+ }
+ else if (sector >= p.sections[0] && sector < sec_w) {
+ theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
+ theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w + p.sections[2]) {
+ theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ }
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
+
+ const float x0 = float(data_a[ix + 0]);
+ const float x1 = float(data_a[ix + p.n_dims/2]);
+
+ data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}
diff --git a/ggml/src/vulkan-shaders/rope_neox.comp b/ggml/src/vulkan-shaders/rope_neox.comp
index 83b46b69..db775c45 100644
--- a/ggml/src/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/vulkan-shaders/rope_neox.comp
@@ -3,15 +3,18 @@
#include "rope_head.comp"
void main() {
- const uint col = gl_GlobalInvocationID.y * 2;
- const uint row = gl_GlobalInvocationID.x;
+ const uint i0 = 2*gl_GlobalInvocationID.y;
+ uint ne0 = p.ncols;
+ uint ne1 = p.p_delta_rows;
- if (col >= p.ncols) {
+ if (i0 >= ne0) {
return;
}
- if (col >= p.n_dims) {
- const uint i = row*p.ncols + col;
+ const uint row_dst = gl_GlobalInvocationID.x;
+
+ if (i0 >= p.n_dims) {
+ const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@@ -19,19 +22,22 @@ void main() {
return;
}
- const uint i = row*p.ncols + col/2;
- const uint i2 = row/p.p_delta_rows;
+ const uint row_x = row_dst % ne1;
+ const uint channel_x = row_dst / ne1;
+
+ const uint idst = row_dst*ne0 + i0/2;
+ const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
- const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+ const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
- const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+ const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
- rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+ rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
- const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + p.n_dims/2]);
+ const float x0 = float(data_a[ix + 0]);
+ const float x1 = float(data_a[ix + p.n_dims/2]);
- data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+ data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
diff --git a/ggml/src/vulkan-shaders/rope_norm.comp b/ggml/src/vulkan-shaders/rope_norm.comp
index e416ad93..4ad35e54 100644
--- a/ggml/src/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/vulkan-shaders/rope_norm.comp
@@ -3,15 +3,18 @@
#include "rope_head.comp"
void main() {
- const uint col = gl_GlobalInvocationID.y * 2;
- const uint row = gl_GlobalInvocationID.x;
+ const uint i0 = 2*gl_GlobalInvocationID.y;
+ uint ne0 = p.ncols;
+ uint ne1 = p.p_delta_rows;
- if (col >= p.ncols) {
+ if (i0 >= ne0) {
return;
}
- if (col >= p.n_dims) {
- const uint i = row*p.ncols + col;
+ const uint row_dst = gl_GlobalInvocationID.x;
+
+ if (i0 >= p.n_dims) {
+ const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@@ -19,19 +22,22 @@ void main() {
return;
}
- const uint i = row*p.ncols + col;
- const uint i2 = row/p.p_delta_rows;
+ const uint row_x = row_dst % ne1;
+ const uint channel_x = row_dst / ne1;
+
+ const uint idst = row_dst*ne0 + i0;
+ const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
- const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+ const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
- const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+ const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
- rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+ rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
- const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + 1]);
+ const float x0 = float(data_a[ix + 0]);
+ const float x1 = float(data_a[ix + 1]);
- data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
+ data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
diff --git a/ggml/src/vulkan-shaders/rope_vision.comp b/ggml/src/vulkan-shaders/rope_vision.comp
new file mode 100644
index 00000000..cedacc4d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rope_vision.comp
@@ -0,0 +1,47 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+ const uint i0 = 2*gl_GlobalInvocationID.y;
+ uint ne0 = p.ncols;
+ uint ne1 = p.p_delta_rows;
+ uint ne2 = p.ne02;
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const uint row_dst = gl_GlobalInvocationID.x;
+
+ const uint row_x = row_dst % ne1;
+ const uint channel_x = row_dst / ne1;
+
+ const uint idst = row_dst*ne0 + i0/2;
+ const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
+
+ const int sect_dims = p.sections[0] + p.sections[1];
+ const int sec_w = p.sections[1] + p.sections[0];
+ const uint sector = (i0 / 2) % sect_dims;
+
+ float theta_base = 0.0;
+ if (sector < p.sections[0]) {
+ const uint p0 = sector;
+ theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
+ }
+ else if (sector >= p.sections[0] && sector < sec_w) {
+ const uint p0 = sector - p.sections[0];
+ theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
+ }
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
+
+ const float x0 = float(data_a[ix + 0]);
+ const float x1 = float(data_a[ix + p.n_dims]);
+
+ data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}
diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp
index 5cd2f668..4663428d 100644
--- a/ggml/src/vulkan-shaders/scale.comp
+++ b/ggml/src/vulkan-shaders/scale.comp
@@ -3,12 +3,22 @@
#include "types.comp"
#include "generic_unary_head.comp"
+const uint num_threads = 128;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 4;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+
+ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
+ idx += num_threads;
+ }
}
diff --git a/ggml/src/vulkan-shaders/sigmoid.comp b/ggml/src/vulkan-shaders/sigmoid.comp
new file mode 100644
index 00000000..5c9e5c35
--- /dev/null
+++ b/ggml/src/vulkan-shaders/sigmoid.comp
@@ -0,0 +1,20 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+ data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i]))));
+}
diff --git a/ggml/src/vulkan-shaders/silu_back.comp b/ggml/src/vulkan-shaders/silu_back.comp
new file mode 100644
index 00000000..f9afa9b1
--- /dev/null
+++ b/ggml/src/vulkan-shaders/silu_back.comp
@@ -0,0 +1,26 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2
+
+ const float xi = float(data_x[i]);
+ const float s = 1.0f / (1.0f + exp(-xi));
+ data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s)));
+}
diff --git a/ggml/src/vulkan-shaders/sin.comp b/ggml/src/vulkan-shaders/sin.comp
new file mode 100644
index 00000000..d7c15a16
--- /dev/null
+++ b/ggml/src/vulkan-shaders/sin.comp
@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
+}
diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp
index 0bd51eca..51fc2dc7 100644
--- a/ggml/src/vulkan-shaders/soft_max.comp
+++ b/ggml/src/vulkan-shaders/soft_max.comp
@@ -1,6 +1,6 @@
#version 450
-#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : enable
layout (push_constant) uniform parameter
{
@@ -11,14 +11,13 @@ layout (push_constant) uniform parameter
float m0;
float m1;
uint n_head_log2;
+ uint nrows_x;
} p;
#include "types.comp"
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
@@ -26,10 +25,17 @@ layout (binding = 2) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
-void main() {
+// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate
+// over all the columns. The main function tries to pass a constant here,
+// as if it were a template function, to allow unrolling.
+void soft_max(uint num_iters) {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint rowy = rowx % p.KY;
+ const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
+
+ if (rowx >= p.nrows_x) {
+ return;
+ }
float slope = 1.0f;
@@ -46,19 +52,39 @@ void main() {
// Find max
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ // Cache values while we compute the max, so we don't need to read them
+ // again when we're ready to compute exp(x-max).
+ const uint DATA_CACHE_SIZE = 16;
+ FLOAT_TYPE data_cache[DATA_CACHE_SIZE];
+
+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
- if (col >= p.KX) {
- break;
+ FLOAT_TYPE a = FLOAT_TYPE(0);
+ if (col < p.KX) {
+ a = data_a[rowx * p.KX + col];
+ }
+
+ FLOAT_TYPE b = FLOAT_TYPE(0);
+ if (p.KY > 0 && col < p.KX) {
+ b = data_b[rowy * p.KX + col];
}
- max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
+ FLOAT_TYPE v = a * p.scale + slope * b;
+
+ if (col < p.KX) {
+ max_val = max(max_val, v);
+ }
+
+ if (idx < DATA_CACHE_SIZE) {
+ data_cache[idx] = v;
+ }
}
- vals[tid] = max_val;
+ // reduce across the workgroup
+ vals[tid] = max_val;
barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(vals[tid], vals[tid + s]);
}
@@ -68,39 +94,80 @@ void main() {
max_val = vals[0];
barrier();
- // Sum up values
- vals[tid] = FLOAT_TYPE(0.0f);
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ // Compute sum{exp(x - max)}
+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
break;
}
+ // compute exp(a*scale+b*slope), add it to sum, and cache the new value
+ // in data_cache if possible.
const uint i = rowx * p.KX + col;
- const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
- vals[tid] += val;
- data_d[i] = D_TYPE(val);
+ FLOAT_TYPE val;
+ if (idx < DATA_CACHE_SIZE) {
+ val = exp(data_cache[idx] - max_val);
+ } else {
+ val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+ }
+ sum += val;
+ if (idx < DATA_CACHE_SIZE) {
+ data_cache[idx] = val;
+ } else {
+ data_d[i] = D_TYPE(val);
+ }
}
+ // reduce across the workgroup
+ vals[tid] = sum;
barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] += vals[tid + s];
}
barrier();
}
+ sum = vals[0];
- const D_TYPE divisor = D_TYPE(vals[0]);
+ FLOAT_TYPE rcpdivisor = 1.0/sum;
- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
- break;
+ continue;
+ }
+
+ if (idx < DATA_CACHE_SIZE) {
+ data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);
+ } else {
+ data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
}
+ }
+}
- data_d[rowx*p.KX + col] /= divisor;
+void main() {
+ // instantiate the soft_max function for several different
+ // dimensions, to allow loop unrolling
+ uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
+ if (num_blocks > 32) {
+ soft_max(num_blocks);
+ } else if (num_blocks > 16) {
+ soft_max(32);
+ } else if (num_blocks > 8) {
+ soft_max(16);
+ } else if (num_blocks > 4) {
+ soft_max(8);
+ } else if (num_blocks == 4) {
+ soft_max(4);
+ } else if (num_blocks == 3) {
+ soft_max(3);
+ } else if (num_blocks == 2) {
+ soft_max(2);
+ } else if (num_blocks == 1) {
+ soft_max(1);
}
}
diff --git a/ggml/src/vulkan-shaders/soft_max_back.comp b/ggml/src/vulkan-shaders/soft_max_back.comp
new file mode 100644
index 00000000..29bd77d7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/soft_max_back.comp
@@ -0,0 +1,50 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "generic_head.comp"
+#include "types.comp"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+// In this shader Y = softmax(X) and X is not provided as input.
+
+layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_y[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum_yg[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ FLOAT_TYPE scale = p.param1;
+
+ // partial sums for thread in warp
+ sum_yg[tid] = FLOAT_TYPE(0.0f);
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);
+ const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);
+ sum_yg[tid] += yi * gi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum_yg[tid] += sum_yg[tid + s];
+ }
+ barrier();
+ }
+
+ const FLOAT_TYPE dot_yg = sum_yg[0];
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ data_d[row*p.KX + col] = D_TYPE(scale
+ * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)
+ * FLOAT_TYPE(data_y[row*p.KX + col]));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/square.comp b/ggml/src/vulkan-shaders/square.comp
index 1fa118c9..ef43598b 100644
--- a/ggml/src/vulkan-shaders/square.comp
+++ b/ggml/src/vulkan-shaders/square.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_unary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = get_idx();
@@ -10,6 +12,6 @@ void main() {
return;
}
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}
diff --git a/ggml/src/vulkan-shaders/sub.comp b/ggml/src/vulkan-shaders/sub.comp
new file mode 100644
index 00000000..72353cc3
--- /dev/null
+++ b/ggml/src/vulkan-shaders/sub.comp
@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ uint idx = get_idx();
+
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
+
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
+}
diff --git a/ggml/src/vulkan-shaders/tanh.comp b/ggml/src/vulkan-shaders/tanh.comp
index 74630dc7..8a6f868f 100644
--- a/ggml/src/vulkan-shaders/tanh.comp
+++ b/ggml/src/vulkan-shaders/tanh.comp
@@ -16,6 +16,5 @@ void main() {
if (i >= p.KX) {
return;
}
-
- data_d[i] = D_TYPE(tanh(data_a[i]));
+ data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.));
}
diff --git a/ggml/src/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/vulkan-shaders/test_bfloat16_support.comp
new file mode 100644
index 00000000..fd0ba401
--- /dev/null
+++ b/ggml/src/vulkan-shaders/test_bfloat16_support.comp
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_EXT_bfloat16 : require
+
+void main()
+{
+}
diff --git a/ggml/src/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/vulkan-shaders/test_coopmat2_support.comp
new file mode 100644
index 00000000..28eb24e1
--- /dev/null
+++ b/ggml/src/vulkan-shaders/test_coopmat2_support.comp
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_NV_cooperative_matrix2 : require
+
+void main()
+{
+}
diff --git a/ggml/src/vulkan-shaders/test_coopmat_support.comp b/ggml/src/vulkan-shaders/test_coopmat_support.comp
new file mode 100644
index 00000000..8c5dd1bd
--- /dev/null
+++ b/ggml/src/vulkan-shaders/test_coopmat_support.comp
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_KHR_cooperative_matrix : require
+
+void main()
+{
+}
diff --git a/ggml/src/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/vulkan-shaders/test_integer_dot_support.comp
new file mode 100644
index 00000000..470e3074
--- /dev/null
+++ b/ggml/src/vulkan-shaders/test_integer_dot_support.comp
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_EXT_integer_dot_product : require
+
+void main()
+{
+}
diff --git a/ggml/src/vulkan-shaders/types.comp b/ggml/src/vulkan-shaders/types.comp
index 21dce72f..3bde7178 100644
--- a/ggml/src/vulkan-shaders/types.comp
+++ b/ggml/src/vulkan-shaders/types.comp
@@ -1,6 +1,11 @@
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#if !defined(GGML_TYPES_COMP)
+#define GGML_TYPES_COMP
+
+#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#endif
+#extension GL_EXT_shader_16bit_storage : require
#if defined(DATA_A_F32)
#define QUANT_K 1
@@ -28,24 +33,43 @@
#endif
#endif
-#if defined(DATA_A_Q4_0)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
+#if defined(DATA_A_BF16)
+#define QUANT_K 1
+#define QUANT_R 1
+
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
+#define A_TYPE uint16_t
+#elif LOAD_VEC_A == 4
+#define A_TYPE u16vec4
+#elif LOAD_VEC_A == 8
+#error unsupported
+#endif
+#endif
+
+#define QUANT_K_Q4_0 32
+#define QUANT_R_Q4_0 2
struct block_q4_0
{
float16_t d;
uint8_t qs[16];
};
+struct block_q4_0_packed16
+{
+ float16_t d;
+ uint16_t qs[16/2];
+};
+#if defined(DATA_A_Q4_0)
+#define QUANT_K QUANT_K_Q4_0
+#define QUANT_R QUANT_R_Q4_0
+#define QUANT_AUXF 1
#define A_TYPE block_q4_0
+#define A_TYPE_PACKED16 block_q4_0_packed16
#endif
-#if defined(DATA_A_Q4_1)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
+#define QUANT_K_Q4_1 32
+#define QUANT_R_Q4_1 2
struct block_q4_1
{
@@ -54,14 +78,30 @@ struct block_q4_1
uint8_t qs[16];
};
+struct block_q4_1_packed16
+{
+ float16_t d;
+ float16_t m;
+ uint16_t qs[16/2];
+};
+
+struct block_q4_1_packed32
+{
+ f16vec2 dm;
+ uint32_t qs[16/4];
+};
+
+#if defined(DATA_A_Q4_1)
+#define QUANT_K QUANT_K_Q4_1
+#define QUANT_R QUANT_R_Q4_1
+#define QUANT_AUXF 2
#define A_TYPE block_q4_1
+#define A_TYPE_PACKED16 block_q4_1_packed16
+#define A_TYPE_PACKED32 block_q4_1_packed32
#endif
-#if defined(DATA_A_Q5_0)
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
+#define QUANT_K_Q5_0 32
+#define QUANT_R_Q5_0 2
struct block_q5_0
{
@@ -70,14 +110,23 @@ struct block_q5_0
uint8_t qs[16];
};
+struct block_q5_0_packed16
+{
+ float16_t d;
+ uint16_t qh[2];
+ uint16_t qs[16/2];
+};
+
+#if defined(DATA_A_Q5_0)
+#define QUANT_K QUANT_K_Q5_0
+#define QUANT_R QUANT_R_Q5_0
+#define QUANT_AUXF 1
#define A_TYPE block_q5_0
+#define A_TYPE_PACKED16 block_q5_0_packed16
#endif
-#if defined(DATA_A_Q5_1)
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
+#define QUANT_K_Q5_1 32
+#define QUANT_R_Q5_1 2
struct block_q5_1
{
@@ -87,114 +136,1238 @@ struct block_q5_1
uint8_t qs[16];
};
+struct block_q5_1_packed16
+{
+ float16_t d;
+ float16_t m;
+ uint qh;
+ uint16_t qs[16/2];
+};
+
+struct block_q5_1_packed32
+{
+ f16vec2 dm;
+ uint qh;
+ uint32_t qs[16/4];
+};
+
+#if defined(DATA_A_Q5_1)
+#define QUANT_K QUANT_K_Q5_1
+#define QUANT_R QUANT_R_Q5_1
+#define QUANT_AUXF 2
#define A_TYPE block_q5_1
+#define A_TYPE_PACKED16 block_q5_1_packed16
+#define A_TYPE_PACKED32 block_q5_1_packed32
#endif
-#if defined(DATA_A_Q8_0)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 1
+#define QUANT_K_Q8_0 32
+#define QUANT_R_Q8_0 1
struct block_q8_0
{
float16_t d;
int8_t qs[32];
};
+struct block_q8_0_packed16
+{
+ float16_t d;
+ int16_t qs[32/2];
+};
+struct block_q8_0_packed32
+{
+ float16_t d;
+ int32_t qs[32/4];
+};
+#if defined(DATA_A_Q8_0)
+#define QUANT_K QUANT_K_Q8_0
+#define QUANT_R QUANT_R_Q8_0
+#define QUANT_AUXF 1
#define A_TYPE block_q8_0
+#define A_TYPE_PACKED16 block_q8_0_packed16
+#define A_TYPE_PACKED32 block_q8_0_packed32
#endif
+#define QUANT_K_Q8_1 32
+#define QUANT_R_Q8_1 1
+
+struct block_q8_1
+{
+ f16vec2 ds;
+ int8_t qs[32];
+};
+struct block_q8_1_packed16
+{
+ f16vec2 ds;
+ int16_t qs[16];
+};
+struct block_q8_1_packed32
+{
+ f16vec2 ds;
+ int32_t qs[8];
+};
+
// K-quants
-#if defined(DATA_A_Q2_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
+#define QUANT_K_Q2_K 256
struct block_q2_K
{
- uint8_t scales[QUANT_K/16];
- uint8_t qs[QUANT_K/4];
+ uint8_t scales[QUANT_K_Q2_K/16];
+ uint8_t qs[QUANT_K_Q2_K/4];
+ f16vec2 d;
+};
+
+struct block_q2_K_packed16
+{
+ uint16_t scales[QUANT_K_Q2_K/16/2];
+ uint16_t qs[QUANT_K_Q2_K/4/2];
f16vec2 d;
};
+struct block_q2_K_packed32
+{
+ uint32_t scales[QUANT_K_Q2_K/16/4];
+ uint32_t qs[QUANT_K_Q2_K/4/4];
+ f16vec2 d;
+};
+
+#if defined(DATA_A_Q2_K)
+#define QUANT_K QUANT_K_Q2_K
#define A_TYPE block_q2_K
+#define A_TYPE_PACKED16 block_q2_K_packed16
+#define A_TYPE_PACKED32 block_q2_K_packed32
#endif
-#if defined(DATA_A_Q3_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
+#define QUANT_K_Q3_K 256
struct block_q3_K
{
- uint8_t hmask[QUANT_K/8];
- uint8_t qs[QUANT_K/4];
+ uint8_t hmask[QUANT_K_Q3_K/8];
+ uint8_t qs[QUANT_K_Q3_K/4];
uint8_t scales[12];
float16_t d;
};
+struct block_q3_K_packed16
+{
+ uint16_t hmask[QUANT_K_Q3_K/8/2];
+ uint16_t qs[QUANT_K_Q3_K/4/2];
+ uint16_t scales[12/2];
+ float16_t d;
+};
+
+#if defined(DATA_A_Q3_K)
+#define QUANT_K QUANT_K_Q3_K
#define A_TYPE block_q3_K
+#define A_TYPE_PACKED16 block_q3_K_packed16
#endif
-#if defined(DATA_A_Q4_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
+#define QUANT_K_Q4_K 256
struct block_q4_K
{
f16vec2 d;
- uint8_t scales[3*QUANT_K/64];
- uint8_t qs[QUANT_K/2];
+ uint8_t scales[3*QUANT_K_Q4_K/64];
+ uint8_t qs[QUANT_K_Q4_K/2];
+};
+
+struct block_q4_K_packed16
+{
+ f16vec2 d;
+ uint16_t scales[3*QUANT_K_Q4_K/64/2];
+ uint16_t qs[QUANT_K_Q4_K/2/2];
+};
+
+struct block_q4_K_packed32
+{
+ f16vec2 d;
+ uint32_t scales[3*QUANT_K_Q4_K/64/4];
+ uint32_t qs[QUANT_K_Q4_K/2/4];
+};
+
+struct block_q4_K_packed128
+{
+ uvec4 q4k[9];
};
+#if defined(DATA_A_Q4_K)
+#define QUANT_K QUANT_K_Q4_K
#define A_TYPE block_q4_K
+#define A_TYPE_PACKED16 block_q4_K_packed16
+#define A_TYPE_PACKED32 block_q4_K_packed32
#endif
-#if defined(DATA_A_Q5_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
+#define QUANT_K_Q5_K 256
struct block_q5_K
{
f16vec2 d;
uint8_t scales[12];
- uint8_t qh[QUANT_K/8];
- uint8_t qs[QUANT_K/2];
+ uint8_t qh[QUANT_K_Q5_K/8];
+ uint8_t qs[QUANT_K_Q5_K/2];
+};
+
+struct block_q5_K_packed16
+{
+ f16vec2 d;
+ uint16_t scales[12/2];
+ uint16_t qh[QUANT_K_Q5_K/8/2];
+ uint16_t qs[QUANT_K_Q5_K/2/2];
};
+struct block_q5_K_packed128
+{
+ uvec4 q5k[11];
+};
+
+#if defined(DATA_A_Q5_K)
+#define QUANT_K QUANT_K_Q5_K
#define A_TYPE block_q5_K
+#define A_TYPE_PACKED16 block_q5_K_packed16
#endif
-#if defined(DATA_A_Q6_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
+#define QUANT_K_Q6_K 256
struct block_q6_K
{
- uint8_t ql[QUANT_K/2];
- uint8_t qh[QUANT_K/4];
- int8_t scales[QUANT_K/16];
+ uint8_t ql[QUANT_K_Q6_K/2];
+ uint8_t qh[QUANT_K_Q6_K/4];
+ int8_t scales[QUANT_K_Q6_K/16];
float16_t d;
};
+struct block_q6_K_packed16
+{
+ uint16_t ql[QUANT_K_Q6_K/2/2];
+ uint16_t qh[QUANT_K_Q6_K/4/2];
+ int8_t scales[QUANT_K_Q6_K/16];
+ float16_t d;
+};
+
+#if defined(DATA_A_Q6_K)
+#define QUANT_K QUANT_K_Q6_K
#define A_TYPE block_q6_K
+#define A_TYPE_PACKED16 block_q6_K_packed16
#endif
// IQuants
-#if defined(DATA_A_IQ4_NL)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
+#define QUANT_K_IQ1_S 256
+#define QUANT_R_IQ1_S 1
+
+struct block_iq1_s {
+ float16_t d;
+ uint8_t qs[QUANT_K_IQ1_S/8];
+ uint16_t qh[QUANT_K_IQ1_S/32];
+};
+
+#define QUANT_K_IQ1_M 256
+#define QUANT_R_IQ1_M 1
+
+struct block_iq1_m {
+ uint8_t qs[QUANT_K_IQ1_M/8];
+ uint8_t qh[QUANT_K_IQ1_M/16];
+ uint16_t scales[QUANT_K_IQ1_M/64];
+};
+
+struct block_iq1_m_packed64 {
+ uint64_t qs[QUANT_K_IQ1_M/8/8];
+ uint64_t qh[QUANT_K_IQ1_M/16/8];
+ uint64_t scales;
+};
+
+#if defined(DATA_A_IQ1_S)
+#define QUANT_K QUANT_K_IQ1_S
+#define QUANT_R QUANT_R_IQ1_S
+#define A_TYPE block_iq1_s
+#endif
+
+#if defined(DATA_A_IQ1_M)
+#define QUANT_K QUANT_K_IQ1_M
+#define QUANT_R QUANT_R_IQ1_M
+#define A_TYPE block_iq1_m
+#endif
+
+#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
+#define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
+
+// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate).
+const uint[1024] iq1s_grid_const = {
+ 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,
+ 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,
+ 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,
+ 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,
+ 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334,
+ 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,
+ 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,
+ 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,
+ 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,
+ 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,
+ 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,
+ 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,
+ 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,
+ 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,
+ 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,
+ 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,
+ 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,
+ 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,
+ 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,
+ 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,
+ 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,
+ 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,
+ 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,
+ 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,
+ 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,
+ 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,
+ 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,
+ 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,
+ 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,
+ 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,
+ 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,
+ 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,
+ 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,
+ 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,
+ 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,
+ 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,
+ 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,
+ 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,
+ 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,
+ 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,
+ 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,
+ 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,
+ 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,
+ 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,
+ 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,
+ 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,
+ 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,
+ 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,
+ 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,
+ 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,
+ 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,
+ 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404,
+ 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,
+ 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,
+ 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,
+ 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,
+ 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,
+ 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,
+ 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,
+ 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,
+ 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,
+ 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,
+ 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,
+ 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,
+ 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,
+ 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,
+ 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,
+ 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,
+ 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,
+ 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,
+ 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,
+ 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,
+ 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,
+ 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,
+ 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,
+ 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,
+ 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,
+ 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,
+ 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,
+ 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,
+ 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,
+ 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,
+ 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,
+ 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4,
+ 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,
+ 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,
+ 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,
+ 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,
+ 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,
+ 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,
+ 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,
+ 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,
+ 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,
+ 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,
+ 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,
+ 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,
+ 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,
+ 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,
+ 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,
+ 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,
+ 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,
+ 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,
+ 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,
+ 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,
+ 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,
+ 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,
+ 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,
+ 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,
+ 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,
+ 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,
+ 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,
+ 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,
+ 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,
+ 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,
+ 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,
+ 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,
+ 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,
+ 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,
+ 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,
+ 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,
+ 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,
+ 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,
+ 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,
+ 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,
+ 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,
+ 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,
+ 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,
+ 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
+};
+
+shared uint16_t iq1s_grid[2048];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) {
+ uint idx = i + gl_LocalInvocationIndex.x;
+ if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) {
+ u16vec2 g = unpack16(iq1s_grid_const[idx]);
+ iq1s_grid[2*idx+0] = g.x;
+ iq1s_grid[2*idx+1] = g.y;
+ }
+ }
+ barrier();
+}
+#endif
+
+#define QUANT_K_IQ2_XXS 256
+#define QUANT_R_IQ2_XXS 1
+
+struct block_iq2_xxs
+{
+ float16_t d;
+ uint8_t qs[QUANT_K_IQ2_XXS/4];
+};
+
+struct block_iq2_xxs_packed16
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ2_XXS/8];
+};
+
+#if defined(DATA_A_IQ2_XXS)
+
+const uvec2[256] iq2xxs_grid_const = {
+ uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
+ uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808),
+ uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808),
+ uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808),
+ uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808),
+ uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819),
+ uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819),
+ uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b),
+ uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908),
+ uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908),
+ uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908),
+ uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919),
+ uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919),
+ uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b),
+ uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08),
+ uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08),
+ uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08),
+ uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b),
+ uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808),
+ uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808),
+ uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819),
+ uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b),
+ uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908),
+ uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919),
+ uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08),
+ uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08),
+ uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b),
+ uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808),
+ uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819),
+ uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908),
+ uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908),
+ uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b),
+ uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b),
+ uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808),
+ uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808),
+ uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808),
+ uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819),
+ uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b),
+ uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908),
+ uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908),
+ uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08),
+ uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08),
+ uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19),
+ uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808),
+ uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808),
+ uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819),
+ uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919),
+ uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08),
+ uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b),
+ uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808),
+ uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b),
+ uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919),
+ uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808),
+ uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819),
+ uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908),
+ uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908),
+ uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919),
+ uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08),
+ uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808),
+ uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819),
+ uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908),
+ uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19),
+ uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819),
+ uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19)
+};
+
+shared uvec2 iq2xxs_grid[256];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) {
+ if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) {
+ iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x];
+ }
+ }
+ barrier();
+}
+
+#define QUANT_K QUANT_K_IQ2_XXS
+#define QUANT_R QUANT_R_IQ2_XXS
+#define A_TYPE block_iq2_xxs
+#define A_TYPE_PACKED16 block_iq2_xxs_packed16
+#endif
+
+#define QUANT_K_IQ2_XS 256
+#define QUANT_R_IQ2_XS 1
+
+struct block_iq2_xs
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ2_XS/8];
+ uint8_t scales[QUANT_K_IQ2_XS/32];
+};
+
+struct block_iq2_xs_packed16
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ2_XS/8];
+ uint16_t scales[QUANT_K_IQ2_XS/64];
+};
+
+#if defined(DATA_A_IQ2_XS)
+
+const uvec2 iq2xs_grid_const[512] = {
+ uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
+ uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),
+ uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),
+ uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),
+ uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),
+ uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808),
+ uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808),
+ uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819),
+ uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819),
+ uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819),
+ uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819),
+ uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819),
+ uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819),
+ uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b),
+ uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b),
+ uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),
+ uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908),
+ uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908),
+ uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908),
+ uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908),
+ uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908),
+ uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919),
+ uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919),
+ uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),
+ uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b),
+ uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b),
+ uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),
+ uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08),
+ uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08),
+ uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08),
+ uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19),
+ uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19),
+ uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b),
+ uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808),
+ uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808),
+ uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808),
+ uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808),
+ uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808),
+ uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819),
+ uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),
+ uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819),
+ uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b),
+ uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b),
+ uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908),
+ uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908),
+ uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),
+ uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919),
+ uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b),
+ uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08),
+ uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08),
+ uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b),
+ uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808),
+ uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808),
+ uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808),
+ uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819),
+ uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819),
+ uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b),
+ uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908),
+ uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919),
+ uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b),
+ uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08),
+ uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19),
+ uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b),
+ uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808),
+ uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808),
+ uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808),
+ uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808),
+ uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808),
+ uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808),
+ uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819),
+ uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819),
+ uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819),
+ uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b),
+ uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908),
+ uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908),
+ uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908),
+ uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908),
+ uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919),
+ uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b),
+ uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08),
+ uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08),
+ uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19),
+ uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808),
+ uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808),
+ uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808),
+ uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819),
+ uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819),
+ uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b),
+ uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908),
+ uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908),
+ uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919),
+ uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08),
+ uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08),
+ uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b),
+ uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808),
+ uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808),
+ uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b),
+ uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919),
+ uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08),
+ uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b),
+ uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808),
+ uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808),
+ uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808),
+ uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819),
+ uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819),
+ uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b),
+ uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b),
+ uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),
+ uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908),
+ uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b),
+ uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08),
+ uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08),
+ uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b),
+ uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808),
+ uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808),
+ uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b),
+ uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908),
+ uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919),
+ uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08),
+ uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808),
+ uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808),
+ uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819),
+ uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b),
+ uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908),
+ uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08),
+ uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08),
+ uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19),
+ uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b),
+};
+
+shared uvec2 iq2xs_grid[512];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) {
+ if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) {
+ iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x];
+ }
+ }
+ barrier();
+}
+
+#define QUANT_K QUANT_K_IQ2_XS
+#define QUANT_R QUANT_R_IQ2_XS
+#define A_TYPE block_iq2_xs
+#define A_TYPE_PACKED16 block_iq2_xs_packed16
+#endif
+
+#define QUANT_K_IQ2_S 256
+#define QUANT_R_IQ2_S 1
+
+struct block_iq2_s
+{
+ float16_t d;
+ uint8_t qs[QUANT_K_IQ2_S/4];
+ uint8_t qh[QUANT_K_IQ2_S/32];
+ uint8_t scales[QUANT_K_IQ2_S/32];
+};
+
+struct block_iq2_s_packed16
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ2_S/8];
+ uint16_t qh[QUANT_K_IQ2_S/64];
+ uint16_t scales[QUANT_K_IQ2_S/64];
+};
+
+#if defined(DATA_A_IQ2_S)
+
+const uvec2 iq2s_grid_const[1024] = {
+ uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
+ uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),
+ uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),
+ uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),
+ uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),
+ uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808),
+ uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808),
+ uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808),
+ uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819),
+ uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819),
+ uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819),
+ uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819),
+ uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819),
+ uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819),
+ uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819),
+ uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b),
+ uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b),
+ uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b),
+ uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),
+ uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b),
+ uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908),
+ uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908),
+ uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908),
+ uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908),
+ uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908),
+ uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908),
+ uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908),
+ uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908),
+ uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919),
+ uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919),
+ uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919),
+ uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),
+ uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919),
+ uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919),
+ uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919),
+ uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b),
+ uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b),
+ uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b),
+ uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b),
+ uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),
+ uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08),
+ uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08),
+ uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08),
+ uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08),
+ uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08),
+ uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19),
+ uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19),
+ uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19),
+ uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19),
+ uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b),
+ uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b),
+ uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808),
+ uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808),
+ uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808),
+ uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808),
+ uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808),
+ uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808),
+ uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808),
+ uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808),
+ uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819),
+ uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),
+ uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819),
+ uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819),
+ uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819),
+ uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819),
+ uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819),
+ uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b),
+ uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b),
+ uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b),
+ uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b),
+ uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908),
+ uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908),
+ uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908),
+ uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),
+ uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908),
+ uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908),
+ uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908),
+ uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919),
+ uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919),
+ uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919),
+ uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919),
+ uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919),
+ uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b),
+ uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b),
+ uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08),
+ uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08),
+ uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08),
+ uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08),
+ uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08),
+ uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19),
+ uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19),
+ uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19),
+ uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b),
+ uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808),
+ uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808),
+ uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808),
+ uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808),
+ uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808),
+ uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819),
+ uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819),
+ uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819),
+ uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819),
+ uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b),
+ uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b),
+ uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908),
+ uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908),
+ uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908),
+ uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908),
+ uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908),
+ uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919),
+ uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919),
+ uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919),
+ uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b),
+ uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08),
+ uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08),
+ uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19),
+ uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b),
+ uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808),
+ uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808),
+ uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808),
+ uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808),
+ uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808),
+ uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808),
+ uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808),
+ uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808),
+ uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819),
+ uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819),
+ uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819),
+ uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819),
+ uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819),
+ uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819),
+ uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819),
+ uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b),
+ uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b),
+ uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b),
+ uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b),
+ uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908),
+ uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908),
+ uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908),
+ uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908),
+ uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908),
+ uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908),
+ uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908),
+ uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919),
+ uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919),
+ uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919),
+ uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919),
+ uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919),
+ uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919),
+ uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b),
+ uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b),
+ uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b),
+ uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08),
+ uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08),
+ uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08),
+ uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08),
+ uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19),
+ uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19),
+ uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19),
+ uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b),
+ uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808),
+ uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808),
+ uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808),
+ uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808),
+ uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808),
+ uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808),
+ uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808),
+ uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819),
+ uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819),
+ uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819),
+ uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819),
+ uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819),
+ uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b),
+ uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b),
+ uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b),
+ uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908),
+ uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908),
+ uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908),
+ uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908),
+ uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908),
+ uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919),
+ uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919),
+ uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919),
+ uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b),
+ uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08),
+ uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08),
+ uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08),
+ uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19),
+ uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b),
+ uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808),
+ uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808),
+ uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808),
+ uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808),
+ uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819),
+ uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819),
+ uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819),
+ uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b),
+ uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908),
+ uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908),
+ uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908),
+ uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919),
+ uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919),
+ uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08),
+ uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19),
+ uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808),
+ uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808),
+ uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808),
+ uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808),
+ uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808),
+ uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819),
+ uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819),
+ uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819),
+ uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819),
+ uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819),
+ uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b),
+ uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b),
+ uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908),
+ uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),
+ uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908),
+ uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908),
+ uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908),
+ uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919),
+ uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919),
+ uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919),
+ uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b),
+ uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08),
+ uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08),
+ uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19),
+ uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b),
+ uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808),
+ uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808),
+ uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808),
+ uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808),
+ uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808),
+ uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819),
+ uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819),
+ uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b),
+ uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908),
+ uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908),
+ uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908),
+ uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919),
+ uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919),
+ uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08),
+ uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08),
+ uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19),
+ uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808),
+ uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808),
+ uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808),
+ uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b),
+ uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b),
+ uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908),
+ uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908),
+ uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b),
+ uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19),
+ uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b),
+ uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b)
+};
+
+shared uvec2 iq2s_grid[1024];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) {
+ if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) {
+ iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x];
+ }
+ }
+ barrier();
+}
+
+#define QUANT_K QUANT_K_IQ2_S
+#define QUANT_R QUANT_R_IQ2_S
+#define A_TYPE block_iq2_s
+#define A_TYPE_PACKED16 block_iq2_s_packed16
+#endif
+
+#define QUANT_K_IQ3_XXS 256
+#define QUANT_R_IQ3_XXS 1
+
+struct block_iq3_xxs
+{
+ float16_t d;
+ uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8];
+};
+
+struct block_iq3_xxs_packed16
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16];
+};
+
+#if defined(DATA_A_IQ3_XXS)
+
+const uint32_t iq3xxs_grid_const[256] = {
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
+shared uint32_t iq3xxs_grid[256];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) {
+ if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) {
+ iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x];
+ }
+ }
+ barrier();
+}
+
+#define QUANT_K QUANT_K_IQ3_XXS
+#define QUANT_R QUANT_R_IQ3_XXS
+#define A_TYPE block_iq3_xxs
+#define A_TYPE_PACKED16 block_iq3_xxs_packed16
+#endif
+
+#define QUANT_K_IQ3_S 256
+#define QUANT_R_IQ3_S 1
+
+struct block_iq3_s
+{
+ float16_t d;
+ uint8_t qs[QUANT_K_IQ3_S/4];
+ uint8_t qh[QUANT_K_IQ3_S/32];
+ uint8_t signs[QUANT_K_IQ3_S/8];
+ uint8_t scales[QUANT_K_IQ3_S/64];
+};
+
+struct block_iq3_s_packed16
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ3_S/4/2];
+ uint16_t qh[QUANT_K_IQ3_S/32/2];
+ uint16_t signs[QUANT_K_IQ3_S/8/2];
+ uint16_t scales[QUANT_K_IQ3_S/64/2];
+};
+
+#if defined(DATA_A_IQ3_S)
+
+const uint32_t iq3s_grid_const[512] = {
+ 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
+ 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
+ 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
+ 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
+ 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
+ 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
+ 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
+ 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
+ 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
+ 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
+ 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
+ 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
+ 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
+ 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
+ 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
+ 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
+ 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
+ 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
+ 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
+ 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
+ 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
+ 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
+ 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
+ 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
+ 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
+ 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
+ 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
+ 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
+ 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
+ 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
+ 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
+ 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
+ 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
+ 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
+ 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
+ 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
+ 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
+ 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
+ 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
+ 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
+ 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
+ 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
+ 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
+ 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
+ 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
+ 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
+ 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
+ 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
+ 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
+ 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
+ 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
+ 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
+ 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
+ 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
+ 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
+ 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
+ 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
+ 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
+ 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
+ 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
+ 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
+ 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
+ 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
+ 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
+};
+
+shared uint32_t iq3s_grid[512];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) {
+ if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) {
+ iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x];
+ }
+ }
+ barrier();
+}
+
+#define QUANT_K QUANT_K_IQ3_S
+#define QUANT_R QUANT_R_IQ3_S
+#define A_TYPE block_iq3_s
+#define A_TYPE_PACKED16 block_iq3_s_packed16
+#endif
+
+#define QUANT_K_IQ4_XS 256
+#define QUANT_R_IQ4_XS 1
+
+struct block_iq4_xs
+{
+ float16_t d;
+ uint16_t scales_h;
+ uint8_t scales_l[QUANT_K_IQ4_XS/64];
+ uint8_t qs[QUANT_K_IQ4_XS/2];
+};
+
+#if defined(DATA_A_IQ4_XS)
+#define QUANT_K QUANT_K_IQ4_XS
+#define QUANT_R QUANT_R_IQ4_XS
+#define A_TYPE block_iq4_xs
+#endif
+
+#define QUANT_K_IQ4_NL 32
+#define QUANT_R_IQ4_NL 2
struct block_iq4_nl
{
float16_t d;
- uint8_t qs[QUANT_K/2];
+ uint8_t qs[QUANT_K_IQ4_NL/2];
+};
+
+struct block_iq4_nl_packed16
+{
+ float16_t d;
+ uint16_t qs[QUANT_K_IQ4_NL/2/2];
};
+#if defined(DATA_A_IQ4_NL)
+#define QUANT_K QUANT_K_IQ4_NL
+#define QUANT_R QUANT_R_IQ4_NL
#define A_TYPE block_iq4_nl
+#define A_TYPE_PACKED16 block_iq4_nl_packed16
+#endif
-const int8_t kvalues_iq4nl[16] = {
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
+const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
};
+
+shared FLOAT_TYPE kvalues_iq4nl[16];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) {
+ kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]);
+ }
+ barrier();
+}
#endif
+
+// returns the bfloat value in the low 16b.
+// See ggml_compute_fp32_to_bf16
+uint32_t fp32_to_bf16(float f)
+{
+ uint32_t u = floatBitsToUint(f);
+ u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
+ return u;
+}
+
+float bf16_to_fp32(uint32_t u)
+{
+ return uintBitsToFloat(u << 16);
+}
+
+#endif // !defined(GGML_TYPES_COMP)
diff --git a/ggml/src/vulkan-shaders/upscale.comp b/ggml/src/vulkan-shaders/upscale.comp
index 511a086e..6f607380 100644
--- a/ggml/src/vulkan-shaders/upscale.comp
+++ b/ggml/src/vulkan-shaders/upscale.comp
@@ -2,7 +2,7 @@
layout (push_constant) uniform parameter
{
- uint ne; uint d_offset;
+ uint ne; uint a_offset; uint d_offset;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
@@ -32,5 +32,5 @@ void main() {
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
- data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
+ data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
}
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
index a792e203..0f244dea 100644
--- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -16,13 +16,14 @@
#include <cstdio>
#include <cstring>
#include <cstdlib>
+#include <cassert>
+#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
- #include <algorithm> // For std::replace on w64devkit
#else
#include <unistd.h>
#include <sys/wait.h>
@@ -54,9 +55,19 @@ const std::vector<std::string> type_names = {
"q4_k",
"q5_k",
"q6_k",
- "iq4_nl"
+ "iq1_s",
+ "iq1_m",
+ "iq2_xxs",
+ "iq2_xs",
+ "iq2_s",
+ "iq3_xxs",
+ "iq3_s",
+ "iq4_xs",
+ "iq4_nl",
+ "bf16",
};
+namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
@@ -74,7 +85,8 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
}
PROCESS_INFORMATION pi;
- STARTUPINFOA si = { sizeof(STARTUPINFOA) };
+ STARTUPINFOA si = {};
+ si.cb = sizeof(STARTUPINFOA);
si.dwFlags = STARTF_USESTDHANDLES;
si.hStdOutput = stdout_write;
si.hStdError = stderr_write;
@@ -92,11 +104,11 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
std::array<char, 128> buffer;
DWORD bytes_read;
- while (ReadFile(stdout_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
+ while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
stdout_str.append(buffer.data(), bytes_read);
}
- while (ReadFile(stderr_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
+ while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
stderr_str.append(buffer.data(), bytes_read);
}
@@ -173,6 +185,13 @@ std::string to_uppercase(const std::string& input) {
return result;
}
+bool string_starts_with(const std::string& str, const std::string& prefix) {
+ if (prefix.size() > str.size()) {
+ return false;
+ }
+ return std::equal(prefix.begin(), prefix.end(), str.begin());
+}
+
bool string_ends_with(const std::string& str, const std::string& suffix) {
if (suffix.size() > str.size()) {
return false;
@@ -190,16 +209,31 @@ std::string basename(const std::string &path) {
return path.substr(path.find_last_of("/\\") + 1);
}
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
- std::string name = _name + (fp16 ? "" : "_fp32");
+// variables to track number of compiles in progress
+static uint32_t compile_count = 0;
+static std::mutex compile_count_mutex;
+static std::condition_variable compile_count_cond;
+
+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname);
+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
+
+ // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
+ std::string opt_level = coopmat ? "" : "-O";
+
#ifdef _WIN32
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
#else
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
#endif
+
+ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
+ cmd.push_back("-g");
+ #endif
+
for (const auto& define : defines) {
cmd.push_back("-D" + define.first + "=" + define.second);
}
@@ -228,6 +262,12 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
}
+ {
+ std::lock_guard<std::mutex> guard(compile_count_mutex);
+ assert(compile_count > 0);
+ compile_count--;
+ }
+ compile_count_cond.notify_all();
}
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
@@ -236,12 +276,29 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
return result;
}
-void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id) {
- std::string load_vec = fp16 ? "8" : "4";
- std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
- std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
+static std::vector<std::future<void>> compiles;
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+ {
+ // wait until fewer than N compiles are in progress.
+ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
+ uint32_t N = 16;
+ std::unique_lock<std::mutex> guard(compile_count_mutex);
+ while (compile_count >= N) {
+ compile_count_cond.wait(guard);
+ }
+ compile_count++;
+ }
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
+}
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
+void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
+
+ std::map<std::string, std::string> base_dict = {
+ {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
+ };
std::string shader_name = "matmul";
if (matmul_id) {
@@ -253,225 +310,328 @@ void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmu
base_dict["FLOAT16"] = "1";
}
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+
+ if (coopmat) {
+ base_dict["COOPMAT"] = "1";
+ }
+
+ const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
+
+ auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
+ if (t == "bf16") {
+ // scalar path promotes to float
+ if (!coopmat && !coopmat2) {
+ return "float";
+ }
+ return "bfloat16_t";
+ }
+ if (coopmat2 || fp16) {
+ return "float16_t";
+ }
+ return "float";
+ };
+
// Shaders with f16 B_TYPE
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
- }));
-
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
- }));
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+
+ // bf16
+ {
+ std::string load_vec_a_unaligned = "1";
+ // For aligned matmul loads
+ std::string load_vec_a = coopmat2 ? "1" : "4";
+
+ // scalar path promotes to float
+ std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
+
+ // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
+#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+ if (!(coopmat || coopmat2))
+#endif
+ {
+ string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ }
+ }
for (const auto& tname : type_names) {
+ std::string load_vec_quant = "2";
+ if ((tname == "q4_0") || (tname == "q4_1"))
+ load_vec_quant = "8";
+ else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+ load_vec_quant = "4";
+
+ if (tname == "bf16") {
+ continue;
+ }
+
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
- std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
- std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
- }));
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
+
+ // don't generate f32 variants for coopmat2
+ if (!coopmat2) {
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ }
+
+ if (tname != "f16" && tname != "f32") {
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ }
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
+ string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+ }
+#endif
}
}
-void process_shaders(std::vector<std::future<void>>& tasks) {
+void process_shaders() {
std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
- for (const auto& fp16 : {false, true}) {
- matmul_shaders(tasks, fp16, false);
- matmul_shaders(tasks, fp16, true);
+ // matmul
+ for (const auto& matmul_id : {false, true}) {
+ // No coopmats
+ // fp32
+ matmul_shaders(false, matmul_id, false, false, false);
+
+ // fp16, fp32acc and fp16acc
+ matmul_shaders(true, matmul_id, false, false, false);
+ matmul_shaders(true, matmul_id, false, false, true);
+
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+ // Coopmat, fp32acc and fp16acc
+ matmul_shaders(true, matmul_id, true, false, false);
+ matmul_shaders(true, matmul_id, true, false, true);
+#endif
+
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ // Coopmat2, fp32acc and fp16acc
+ matmul_shaders(true, matmul_id, false, true, false);
+ matmul_shaders(true, matmul_id, false, true, true);
+#endif
+ }
+
+ // flash attention
+ for (const auto& f16acc : {false, true}) {
+ std::string acctype = f16acc ? "float16_t" : "float";
+ std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+
+ for (const auto& tname : type_names) {
+ if (tname == "f32") {
+ continue;
+ }
+ if (tname == "bf16") continue;
+
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ if (tname == "f16") {
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+ } else {
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+ }
+#endif
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+ if (tname == "f16") {
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+ } else if (tname == "q4_0" || tname == "q8_0") {
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+ }
+#endif
+ if (tname == "f16") {
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+ } else if (tname == "q4_0" || tname == "q8_0") {
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+ }
+ }
}
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
- std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
+ std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
- }));
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
// Dequant shaders
- if (tname != "f16") {
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
- }));
+ if (tname != "f16" && tname != "bf16") {
+ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
}
if (!string_ends_with(tname, "_k")) {
- shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
+ shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") {
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
- }));
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
} else {
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
- }));
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
}
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
- }));
+ string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
}
}
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
+ string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
// Norms
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- }));
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [] {
- string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
- }));
-
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
-
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
- }));
-
- tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- }));
+ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
+ string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
+
+ for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
+ string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+ string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }
+
+ auto get_type_str = [](bool f16) {
+ return f16 ? "float16_t" : "float";
+ };
+ auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+ std::string s;
+ s += std::string(src0_f16 ? "_f16" : "_f32");
+ s += std::string(src1_f16 ? "_f16" : "_f32");
+ s += std::string(dst_f16 ? "_f16" : "_f32");
+ return s;
+ };
+ for (std::string op : {"add", "sub", "mul", "div"}) {
+ for (auto src0_f16 : {false, true}) {
+ for (auto src1_f16 : {false, true}) {
+ for (auto dst_f16 : {false, true}) {
+ auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
+ string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+ }
+ }
+ }
+ }
+
+ string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
+ string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
+ string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
+
+ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
+
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+ string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+ string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+ string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
+
+ string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
+ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
+
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
+ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
+
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+ string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+
+ string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+ string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
+ string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
+ for (auto &c : compiles) {
+ c.wait();
+ }
}
void write_output_files() {
@@ -481,6 +641,7 @@ void write_output_files() {
fprintf(hdr, "#include <cstdint>\n\n");
fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
+ std::sort(shader_fnames.begin(), shader_fnames.end());
for (const auto& pair : shader_fnames) {
const std::string& name = pair.first;
#ifdef _WIN32
@@ -522,16 +683,28 @@ void write_output_files() {
std::remove(path.c_str());
}
}
-
+ for (const char *op : {"add", "sub", "mul", "div"}) {
+ fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
+ fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
+ fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
+ fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+ }
fclose(hdr);
fclose(src);
}
+}
int main(int argc, char** argv) {
std::map<std::string, std::string> args;
- for (int i = 1; i < argc; i += 2) {
- if (i + 1 < argc) {
- args[argv[i]] = argv[i + 1];
+ for (int i = 1; i < argc; ++i) {
+ std::string arg = argv[i];
+ if (arg.rfind("--", 0) == 0) {
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
+ args[arg] = argv[i + 1];
+ ++i;
+ } else {
+ args[arg] = "";
+ }
}
}
@@ -566,12 +739,7 @@ int main(int argc, char** argv) {
}
}
- std::vector<std::future<void>> tasks;
- process_shaders(tasks);
-
- for (auto& task : tasks) {
- task.get();
- }
+ process_shaders();
write_output_files();
diff --git a/ggml/src/vulkan-shaders/wkv6.comp b/ggml/src/vulkan-shaders/wkv6.comp
new file mode 100644
index 00000000..35cc6c45
--- /dev/null
+++ b/ggml/src/vulkan-shaders/wkv6.comp
@@ -0,0 +1,87 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+ uint B;
+ uint T;
+ uint C;
+ uint H;
+};
+
+layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
+layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
+layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];
+
+void main() {
+ const uint head_size = BLOCK_SIZE;
+ const uint batch_id = gl_WorkGroupID.x / H;
+ const uint head_id = gl_WorkGroupID.x % H;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint state_size = C * head_size;
+ const uint n_seq_tokens = T / B;
+
+ if (batch_id >= B || head_id >= H) {
+ return;
+ }
+
+ A_TYPE state[BLOCK_SIZE];
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ + i * head_size + tid];
+ }
+
+ barrier();
+ _tf[tid] = tf[head_id * head_size + tid];
+ barrier();
+
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+ for (uint t = start_t; t < end_t; t += C) {
+ barrier();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ barrier();
+
+ const A_TYPE v_val = v[t];
+ A_TYPE y = 0.0;
+
+ [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+ vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+ vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ vec4 kv = k_vec * v_val;
+
+ vec4 temp = tf_vec * kv + s_vec;
+ y += dot(r_vec, temp);
+
+ s_vec = s_vec * td_vec + kv;
+ state[j] = s_vec.x;
+ state[j+1] = s_vec.y;
+ state[j+2] = s_vec.z;
+ state[j+3] = s_vec.w;
+ }
+
+ dst[t] = y;
+ }
+
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ + i * head_size + tid] = state[i];
+ }
+}