summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/copy_to_quant.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/copy_to_quant.comp')
-rw-r--r--ggml/src/vulkan-shaders/copy_to_quant.comp242
1 files changed, 242 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/copy_to_quant.comp b/ggml/src/vulkan-shaders/copy_to_quant.comp
new file mode 100644
index 00000000..9c76437d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/copy_to_quant.comp
@@ -0,0 +1,242 @@
+#version 450
+
+#if RTE16
+#extension GL_EXT_spirv_intrinsics : enable
+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
+#endif // RTE16
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+#if defined(DATA_A_IQ4_NL)
+// 16 invocations needed for init_iq4nl_shmem
+layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#else
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+#endif
+
+layout (binding = 0) readonly buffer S {float data_s[];};
+layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
+
+#if defined(DATA_A_Q4_0)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0;
+ float vmax = 0.0;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
+ const float v = data_s[src_idx + j];
+ if (amax < abs(v)) {
+ amax = abs(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -8;
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
+ const float x0 = data_s[src_idx + 0 + j]*id;
+ const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
+
+ const uint xi0 = min(15, int(x0 + 8.5));
+ const uint xi1 = min(15, int(x1 + 8.5));
+
+ data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
+ }
+}
+#endif
+
+#if defined(DATA_A_Q4_1)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float vmin = 1.0/0.0;
+ float vmax = -vmin;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
+ const float v = data_s[src_idx + j];
+
+ if (v < vmin) vmin = v;
+ if (v > vmax) vmax = v;
+ }
+
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+ data_q[dst_idx].m = float16_t(vmin);
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
+ const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
+ const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
+
+ const uint xi0 = min(15, int(x0 + 0.5));
+ const uint xi1 = min(15, int(x1 + 0.5));
+
+ data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
+ }
+}
+#endif
+
+#if defined(DATA_A_Q5_0)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0;
+ float vmax = 0.0;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
+ const float v = data_s[src_idx + j];
+ if (amax < abs(v)) {
+ amax = abs(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -16;
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+
+ uint32_t qh = 0;
+ [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
+ const float x0 = data_s[src_idx + 0 + j]*id;
+ const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
+
+ const uint xi0 = min(31, int(x0 + 16.5));
+ const uint xi1 = min(31, int(x1 + 16.5));
+
+ data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
+ }
+ data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
+ data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
+}
+#endif
+
+#if defined(DATA_A_Q5_1)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float min = data_s[src_idx + 0];
+ float max = min;
+
+ [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
+ const float v = data_s[src_idx + j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = (d != 0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+ data_q[dst_idx].m = float16_t(min);
+
+ uint32_t qh = 0;
+ [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
+ const float x0 = (data_s[src_idx + 0 + j] - min)*id;
+ const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
+
+ const uint xi0 = uint(x0 + 0.5);
+ const uint xi1 = uint(x1 + 0.5);
+
+ data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
+ }
+ data_q[dst_idx].qh = qh;
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0; // absolute max
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
+ const float v = data_s[src_idx + j];
+ amax = max(amax, abs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ data_q[dst_idx].d = float16_t(d);
+
+ [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
+ const float x0 = data_s[src_idx + j]*id;
+
+ data_q[dst_idx].qs[j] = int8_t(round(x0));
+ }
+}
+#endif
+
+#if defined(DATA_A_IQ4_NL)
+uint best_index(float x) {
+ if (x <= kvalues_iq4nl[0]) return 0;
+ if (x >= kvalues_iq4nl[15]) return 15;
+ int ml = 0, mu = 15;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
+ }
+ return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
+}
+
+void quantize(uint dst_idx, uint src_idx)
+{
+ float amax = 0.0;
+ float vmax = 0.0;
+
+ [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
+ const float v = data_s[src_idx + j];
+ if (amax < abs(v)) {
+ amax = abs(v);
+ vmax = v;
+ }
+ }
+
+ float d = vmax / kvalues_iq4nl[0];
+ const float id = (d != 0.0) ? 1.0/d : 0.0;
+
+ float sumqx = 0, sumq2 = 0;
+ [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
+ const float x0 = data_s[src_idx + 0 + j]*id;
+ const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
+ const uint xi0 = best_index(x0);
+ const uint xi1 = best_index(x1);
+ data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
+ const float v0 = kvalues_iq4nl[xi0];
+ const float v1 = kvalues_iq4nl[xi1];
+ const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
+ const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
+ sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+ }
+
+ data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
+
+}
+#endif
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+ if (gl_LocalInvocationIndex.x != 0) {
+ return;
+ }
+#endif
+
+ const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ uint dst_idx = dst_idx_quant(idx, QUANT_K);
+ uint src_idx = get_aoffset() + src0_idx(idx);
+
+ quantize(dst_idx, src_idx);
+}