diff options
Diffstat (limited to 'ggml/src/vulkan-shaders/copy_to_quant.comp')
-rw-r--r-- | ggml/src/vulkan-shaders/copy_to_quant.comp | 65 |
1 files changed, 58 insertions, 7 deletions
diff --git a/ggml/src/vulkan-shaders/copy_to_quant.comp b/ggml/src/vulkan-shaders/copy_to_quant.comp index 9c76437d..e06547e4 100644 --- a/ggml/src/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/vulkan-shaders/copy_to_quant.comp @@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi #endif // RTE16 #include "types.comp" -#include "generic_unary_head.comp" -#if defined(DATA_A_IQ4_NL) -// 16 invocations needed for init_iq4nl_shmem -layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#if defined(SET_ROWS) && QUANT_K == 1 +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 512; #else -layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 32; #endif layout (binding = 0) readonly buffer S {float data_s[];}; + +#if defined(SET_ROWS) +#include "generic_binary_head.comp" +layout (binding = 1) readonly buffer C {uvec2 data_i[];}; +layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; +#else +#include "generic_unary_head.comp" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; +#endif #if defined(DATA_A_Q4_0) void quantize(uint dst_idx, uint src_idx) @@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_F32) || defined(DATA_A_F16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(data_s[src_idx]); +} +#endif + +#if defined(DATA_A_BF16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx])); +} +#endif + +#if defined(SET_ROWS) + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); - if (gl_LocalInvocationIndex.x != 0) { +#endif + + const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { return; } + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + uint i12 = fastmod(i03, p.ne12); + uint i11 = fastmod(i02, p.ne11); + uint i10 = i01; + + uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x; + + uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset(); + uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); + + quantize(dst_idx, src0_idx); +} + +#else + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); #endif - const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K; if (idx >= p.ne) { return; @@ -240,3 +289,5 @@ void main() { quantize(dst_idx, src_idx); } + +#endif |