summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/get_rows_quant.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/get_rows_quant.comp')
-rw-r--r--ggml/src/vulkan-shaders/get_rows_quant.comp10
1 files changed, 10 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/get_rows_quant.comp b/ggml/src/vulkan-shaders/get_rows_quant.comp
index 53a9a96f..cfd645a3 100644
--- a/ggml/src/vulkan-shaders/get_rows_quant.comp
+++ b/ggml/src/vulkan-shaders/get_rows_quant.comp
@@ -1,15 +1,23 @@
#version 450
+#extension GL_EXT_control_flow_attributes : enable
+
#include "types.comp"
#include "generic_binary_head.comp"
#include "dequant_funcs.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
const uint i10 = gl_GlobalInvocationID.y;
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
if (i00 >= p.ne00) {
return;
}
@@ -25,6 +33,8 @@ void main() {
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
vec2 v = dequantize(ib, iqs, 0);
+ const vec2 dm = get_dm(ib, 0);
+ v = v * dm.x + dm.y;
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);