diff options
Diffstat (limited to 'ggml/src/vulkan-shaders')
29 files changed, 900 insertions, 165 deletions
diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/vulkan-shaders/CMakeLists.txt index a22ea817..e1f613fb 100644 --- a/ggml/src/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/vulkan-shaders/CMakeLists.txt @@ -27,5 +27,5 @@ endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) diff --git a/ggml/src/vulkan-shaders/conv2d_dw.comp b/ggml/src/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 00000000..938c74da --- /dev/null +++ b/ggml/src/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ggml/src/vulkan-shaders/copy_to_quant.comp b/ggml/src/vulkan-shaders/copy_to_quant.comp index 9c76437d..e06547e4 100644 --- a/ggml/src/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/vulkan-shaders/copy_to_quant.comp @@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi #endif // RTE16 #include "types.comp" -#include "generic_unary_head.comp" -#if defined(DATA_A_IQ4_NL) -// 16 invocations needed for init_iq4nl_shmem -layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#if defined(SET_ROWS) && QUANT_K == 1 +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 512; #else -layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 32; #endif layout (binding = 0) readonly buffer S {float data_s[];}; + +#if defined(SET_ROWS) +#include "generic_binary_head.comp" +layout (binding = 1) readonly buffer C {uvec2 data_i[];}; +layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; +#else +#include "generic_unary_head.comp" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; +#endif #if defined(DATA_A_Q4_0) void quantize(uint dst_idx, uint src_idx) @@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_F32) || defined(DATA_A_F16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(data_s[src_idx]); +} +#endif + +#if defined(DATA_A_BF16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx])); +} +#endif + +#if defined(SET_ROWS) + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); - if (gl_LocalInvocationIndex.x != 0) { +#endif + + const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { return; } + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + uint i12 = fastmod(i03, p.ne12); + uint i11 = fastmod(i02, p.ne11); + uint i10 = i01; + + uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x; + + uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset(); + uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); + + quantize(dst_idx, src0_idx); +} + +#else + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); #endif - const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K; if (idx >= p.ne) { return; @@ -240,3 +289,5 @@ void main() { quantize(dst_idx, src_idx); } + +#endif diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp index 454b3411..45c6e773 100644 --- a/ggml/src/vulkan-shaders/flash_attn.comp +++ b/ggml/src/vulkan-shaders/flash_attn.comp @@ -100,6 +100,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -145,13 +149,13 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp index 1d3e6387..7defe72b 100644 --- a/ggml/src/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/vulkan-shaders/flash_attn_base.comp @@ -24,6 +24,8 @@ layout (push_constant) uniform parameter { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -34,14 +36,12 @@ layout (push_constant) uniform parameter { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -50,6 +50,9 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) @@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i { const uint32_t h = iq2 + (r % p.gqa_ratio); - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); return ACC_TYPE(pow(base, ACC_TYPE(exph))); } diff --git a/ggml/src/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/vulkan-shaders/flash_attn_cm1.comp index ad7594fe..486735fe 100644 --- a/ggml/src/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/vulkan-shaders/flash_attn_cm1.comp @@ -125,6 +125,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -178,12 +182,12 @@ void main() { barrier(); } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } barrier(); diff --git a/ggml/src/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/vulkan-shaders/flash_attn_cm2.comp index 91caa184..274f48fc 100644 --- a/ggml/src/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/vulkan-shaders/flash_attn_cm2.comp @@ -130,6 +130,11 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -148,14 +153,14 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; - coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); } diff --git a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp index a7e39568..0a17a9df 100644 --- a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 32 +layout(constant_id = 0) const uint BLOCK_SIZE = 32; -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 1) writeonly buffer D {float data_d[];}; @@ -12,48 +12,80 @@ layout (binding = 1) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; uint N; + uint ne3; uint k_num; } p; +shared float tmpsh[BLOCK_SIZE]; + void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint iq3 = gl_WorkGroupID.z; uint D = p.D; uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * k_num + n; - uint m_offset = D * N * k_num + N + n; + uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; uint lm_stride = N * 2; // Compute the max m value for the row float m_max = -1.0/0.0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; m_max = max(m_max, m); } + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + // Compute L based on m_max float L = 0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float l = data_a[l_offset + k * lm_stride]; - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; L += exp(m - m_max) * l; } + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + L = 1.0 / L; + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; // Scale and sum the O contributions based on m_max and store the result to memory - for (uint d = tid; d < D; d += BLOCK_SIZE) { + if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * k + D * n + d; + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } O *= L; - data_d[D * n + d] = O; + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ggml/src/vulkan-shaders/geglu.comp b/ggml/src/vulkan-shaders/geglu.comp new file mode 100644 index 00000000..f4268ed2 --- /dev/null +++ b/ggml/src/vulkan-shaders/geglu.comp @@ -0,0 +1,13 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/vulkan-shaders/geglu_erf.comp b/ggml/src/vulkan-shaders/geglu_erf.comp new file mode 100644 index 00000000..cbd4cb36 --- /dev/null +++ b/ggml/src/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "glu_head.comp" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; +const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/vulkan-shaders/geglu_quick.comp b/ggml/src/vulkan-shaders/geglu_quick.comp new file mode 100644 index 00000000..3a2a6897 --- /dev/null +++ b/ggml/src/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/vulkan-shaders/gelu_erf.comp b/ggml/src/vulkan-shaders/gelu_erf.comp new file mode 100644 index 00000000..5fd5a5e7 --- /dev/null +++ b/ggml/src/vulkan-shaders/gelu_erf.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation + // ref: https://www.johndcook.com/blog/python_erf/ + const float p_erf = 0.3275911f; + const float a1_erf = 0.254829592f; + const float a2_erf = -0.284496736f; + const float a3_erf = 1.421413741f; + const float a4_erf = -1.453152027f; + const float a5_erf = 1.061405429f; + + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float a = float(data_a[i]); + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx)); +} diff --git a/ggml/src/vulkan-shaders/glu_head.comp b/ggml/src/vulkan-shaders/glu_head.comp new file mode 100644 index 00000000..41a29889 --- /dev/null +++ b/ggml/src/vulkan-shaders/glu_head.comp @@ -0,0 +1,15 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter +{ + uint N; + uint ne00; + uint ne20; + uint mode; +} p; diff --git a/ggml/src/vulkan-shaders/glu_main.comp b/ggml/src/vulkan-shaders/glu_main.comp new file mode 100644 index 00000000..85cf65a9 --- /dev/null +++ b/ggml/src/vulkan-shaders/glu_main.comp @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} diff --git a/ggml/src/vulkan-shaders/l2_norm.comp b/ggml/src/vulkan-shaders/l2_norm.comp new file mode 100644 index 00000000..deba8c39 --- /dev/null +++ b/ggml/src/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp index 26163b16..f4815499 100644 --- a/ggml/src/vulkan-shaders/mul_mm.comp +++ b/ggml/src/vulkan-shaders/mul_mm.comp @@ -18,6 +18,7 @@ #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; +uint _ne1; +#ifdef COOPMAT +shared uint _ne1_sh; +#endif #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -172,7 +177,47 @@ void main() { const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; +#ifdef COOPMAT + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec2(ii0, ii1); + } + _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; +#else + _ne1 = 0; for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii0 = 0; ii0 < p.nei0; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { @@ -183,6 +228,7 @@ void main() { } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; @@ -500,10 +546,9 @@ void main() { const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -512,22 +557,16 @@ void main() { const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -538,21 +577,17 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -562,63 +597,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -631,33 +684,36 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/vulkan-shaders/mul_mm_cm2.comp index 91846575..29e4b5c9 100644 --- a/ggml/src/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/vulkan-shaders/mul_mm_cm2.comp @@ -162,17 +162,32 @@ void main() { _ne1 = 0; uint num_elements = p.nei1 * p.nei0; - for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; bool in_range = i < num_elements; - uint ii0 = i % p.nei0; uint ii1 = i / p.nei0; - uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); uint idx = subgroupBallotExclusiveBitCount(ballot); if (in_range && id == expert_idx) { row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); } _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; } _ne1_sh = _ne1; } @@ -414,17 +429,31 @@ void main() { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; - coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; + coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif - sum = coopMatMulAdd(mat_a, mat_b, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; + coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } } // Convert from ACC_TYPE to D_TYPE diff --git a/ggml/src/vulkan-shaders/reglu.comp b/ggml/src/vulkan-shaders/reglu.comp new file mode 100644 index 00000000..0073d8f7 --- /dev/null +++ b/ggml/src/vulkan-shaders/reglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return max(a, 0.0f) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/vulkan-shaders/rms_norm.comp index deb8ee99..6428ca7b 100644 --- a/ggml/src/vulkan-shaders/rms_norm.comp +++ b/ggml/src/vulkan-shaders/rms_norm.comp @@ -1,11 +1,13 @@ #version 450 -#include "generic_unary_head.comp" +#include "generic_binary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 +layout (constant_id = 1) const bool do_multiply = false; + layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sum[BLOCK_SIZE]; @@ -25,6 +27,7 @@ void main() { const uint stride_sample = p.nb03; uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp @@ -46,7 +49,13 @@ void main() { const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + if (do_multiply) { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } } } diff --git a/ggml/src/vulkan-shaders/roll.comp b/ggml/src/vulkan-shaders/roll.comp new file mode 100644 index 00000000..b9abe8de --- /dev/null +++ b/ggml/src/vulkan-shaders/roll.comp @@ -0,0 +1,46 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint wrap_idx(int i, uint ne) { + if (i < 0) { + return i + ne; + } else if (i >= ne) { + return i - ne; + } + return i; +} + +void main() { + const uint idx = get_idx(); + if (idx >= p.ne) { + return; + } + + const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L); + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint p1 = floatBitsToUint(p.param1); + const uint p2 = floatBitsToUint(p.param2); + const int s0 = int(p1 >> 16) - 0x8000; + const int s1 = int(p1 & 0xFFFF) - 0x8000; + const int s2 = int(p2 >> 16) - 0x8000; + const int s3 = int(p2 & 0xFFFF) - 0x8000; + + const uint i00 = wrap_idx(int(i0) - s0, p.ne10); + const uint i01 = wrap_idx(int(i1) - s1, p.ne11); + const uint i02 = wrap_idx(int(i2) - s2, p.ne12); + const uint i03 = wrap_idx(int(i3) - s3, p.ne13); + + const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; + const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10; + + data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]); +} diff --git a/ggml/src/vulkan-shaders/rope_multi.comp b/ggml/src/vulkan-shaders/rope_multi.comp index 4f5b1a0e..5808710c 100644 --- a/ggml/src/vulkan-shaders/rope_multi.comp +++ b/ggml/src/vulkan-shaders/rope_multi.comp @@ -14,21 +14,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; const int sec_w = p.sections[1] + p.sections[0]; const uint sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/vulkan-shaders/rope_neox.comp b/ggml/src/vulkan-shaders/rope_neox.comp index db775c45..366a7b1c 100644 --- a/ggml/src/vulkan-shaders/rope_neox.comp +++ b/ggml/src/vulkan-shaders/rope_neox.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ggml/src/vulkan-shaders/rope_norm.comp b/ggml/src/vulkan-shaders/rope_norm.comp index 4ad35e54..9643bca9 100644 --- a/ggml/src/vulkan-shaders/rope_norm.comp +++ b/ggml/src/vulkan-shaders/rope_norm.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + if (i0 >= p.n_dims) { + data_d[idst + 0] = data_a[ix + 0]; + data_d[idst + 1] = data_a[ix + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp index 4663428d..f10b0a02 100644 --- a/ggml/src/vulkan-shaders/scale.comp +++ b/ggml/src/vulkan-shaders/scale.comp @@ -18,7 +18,7 @@ void main() { continue; } - data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); idx += num_threads; } } diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp index 51fc2dc7..5bcd3b1e 100644 --- a/ggml/src/vulkan-shaders/soft_max.comp +++ b/ggml/src/vulkan-shaders/soft_max.comp @@ -6,6 +6,14 @@ layout (push_constant) uniform parameter { uint KX; uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; float scale; float max_bias; float m0; @@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE]; void soft_max(uint num_iters) { const uint tid = gl_LocalInvocationID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } if (rowx >= p.nrows_x) { return; @@ -41,7 +57,7 @@ void soft_max(uint num_iters) { // ALiBi if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index + const uint h = (rowx / p.ne01) % p.ne02; // head index const float base = h < p.n_head_log2 ? p.m0 : p.m1; const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; @@ -67,7 +83,7 @@ void soft_max(uint num_iters) { FLOAT_TYPE b = FLOAT_TYPE(0); if (p.KY > 0 && col < p.KX) { - b = data_b[rowy * p.KX + col]; + b = data_b[rowy_start + col]; } FLOAT_TYPE v = a * p.scale + slope * b; @@ -111,7 +127,7 @@ void soft_max(uint num_iters) { if (idx < DATA_CACHE_SIZE) { val = exp(data_cache[idx] - max_val); } else { - val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); } sum += val; if (idx < DATA_CACHE_SIZE) { diff --git a/ggml/src/vulkan-shaders/swiglu.comp b/ggml/src/vulkan-shaders/swiglu.comp new file mode 100644 index 00000000..a28e7c6c --- /dev/null +++ b/ggml/src/vulkan-shaders/swiglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/vulkan-shaders/upscale.comp b/ggml/src/vulkan-shaders/upscale.comp index 6f607380..74771def 100644 --- a/ggml/src/vulkan-shaders/upscale.comp +++ b/ggml/src/vulkan-shaders/upscale.comp @@ -3,6 +3,7 @@ layout (push_constant) uniform parameter { uint ne; uint a_offset; uint d_offset; + uint ne00; uint ne01; uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; @@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag +#define NEAREST 0 +#define BILINEAR 1 +#define ALIGN_CORNERS (1 << 8) + +layout (constant_id = 0) const uint scale_mode = 0; + +float fetch_nearest(uint i10, uint i11, uint i12, uint i13) { + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]; +} + +float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02; + + const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00]; + const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00]; + const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00]; + const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00]; + + return + v00 * (1.0-d.x) * (1.0-d.y) + + v01 * d.x * (1.0-d.y) + + v10 * (1.0-d.x) * d.y + + v11 * d.x * d.y; +} + +float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { + const ivec2 ne0 = ivec2(p.ne00, p.ne01); + + const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = max(ivec2(c0f), 0); + const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1); + + return fetch_bilinear(c0, c1, d, i12, i13); +} + +float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { + const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = ivec2(c0f); + const ivec2 c1 = c0 + 1; + + return fetch_bilinear(c0, c1, d, i12, i13); +} + void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -27,10 +83,18 @@ void main() { const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; - const uint i00 = uint(i10 / p.sf0); - const uint i01 = uint(i11 / p.sf1); - const uint i02 = uint(i12 / p.sf2); - const uint i03 = uint(i13 / p.sf3); + float result; + switch (scale_mode) { + case NEAREST: + result = fetch_nearest(i10, i11, i12, i13); + break; + case BILINEAR: + result = interpolate_bilinear(i10, i11, i12, i13); + break; + case BILINEAR | ALIGN_CORNERS: + result = interpolate_bilinear_align_corners(i10, i11, i12, i13); + break; + } - data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); + data_d[p.d_offset + idx] = D_TYPE(result); } diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index 65dd82de..293aa644 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) load_vec_quant = "4"; if (tname == "bf16") { @@ -497,9 +497,9 @@ void process_shaders() { // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); @@ -518,6 +518,11 @@ void process_shaders() { string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } + for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + } + auto get_type_str = [](bool f16) { return f16 ? "float16_t" : "float"; }; @@ -572,17 +577,10 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -594,6 +592,17 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -637,8 +646,30 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + + string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + // ============================== ik_llama.cpp + // + string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + // + // ============================== end ik_llama.cpp + for (auto &c : compiles) { c.wait(); } diff --git a/ggml/src/vulkan-shaders/wkv7.comp b/ggml/src/vulkan-shaders/wkv7.comp new file mode 100644 index 00000000..88c1c02b --- /dev/null +++ b/ggml/src/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} |