summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/copy_to_quant.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/copy_to_quant.comp')
-rw-r--r--ggml/src/vulkan-shaders/copy_to_quant.comp65
1 files changed, 58 insertions, 7 deletions
diff --git a/ggml/src/vulkan-shaders/copy_to_quant.comp b/ggml/src/vulkan-shaders/copy_to_quant.comp
index 9c76437d..e06547e4 100644
--- a/ggml/src/vulkan-shaders/copy_to_quant.comp
+++ b/ggml/src/vulkan-shaders/copy_to_quant.comp
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
#endif // RTE16
#include "types.comp"
-#include "generic_unary_head.comp"
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
-layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#if defined(SET_ROWS) && QUANT_K == 1
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 512;
#else
-layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 32;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
+
+#if defined(SET_ROWS)
+#include "generic_binary_head.comp"
+layout (binding = 1) readonly buffer C {uvec2 data_i[];};
+layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
+#else
+#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
+#endif
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+void quantize(uint dst_idx, uint src_idx)
+{
+ data_q[dst_idx] = A_TYPE(data_s[src_idx]);
+}
+#endif
+
+#if defined(DATA_A_BF16)
+void quantize(uint dst_idx, uint src_idx)
+{
+ data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
+}
+#endif
+
+#if defined(SET_ROWS)
+
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
- if (gl_LocalInvocationIndex.x != 0) {
+#endif
+
+ const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
+
+ if (idx >= p.ne) {
return;
}
+
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ uint i12 = fastmod(i03, p.ne12);
+ uint i11 = fastmod(i02, p.ne11);
+ uint i10 = i01;
+
+ uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
+
+ uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
+ uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
+
+ quantize(dst_idx, src0_idx);
+}
+
+#else
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
#endif
- const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+ const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
@@ -240,3 +289,5 @@ void main() {
quantize(dst_idx, src_idx);
}
+
+#endif