Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- | ggml_vk_generate_shaders.py | 281
1 file changed, 235 insertions, 46 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index 4a6f5e32..5dd70096 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -18,6 +18,12 @@ shader_int8_ext = """
 """
 
 # Type-specific defines
+shader_f32_defines = """
+#define QUANT_K 1
+#define QUANT_R 1
+
+#define A_TYPE float
+"""
 shader_f16_defines = """
 #define QUANT_K 1
 #define QUANT_R 1
@@ -157,8 +163,8 @@ struct block_q6_K
 """
 
 # Dequant functions
-shader_f16_dequant_func = """
-#define DEQUANT_FUNC vec2 v = vec2(data_a[ib + 0], data_a[ib + 1]);
+shader_float_dequant_func = """
+#define DEQUANT_FUNC vec2 v = vec2(ib, ib); // data_a[ib], data_a[ib + 1]);
 """
 
 shader_q4_0_dequant_func = """
@@ -410,6 +416,133 @@ mulmat_load_q8_0 = """
         buf_a[buf_idx ] = FLOAT_TYPE(v.x);
         buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
+
+mulmat_load_q2_K = """
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+        const uint ib = idx / 128; // 2 values per idx
+        const uint iqs = idx % 128; // 0..127
+
+        const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+        const uint scalesi = iqs / 8; // 0..15
+        const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+
+        const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+        const uint scales = data_a[ib].scales[scalesi];
+        const vec2 d = vec2(data_a[ib].d);
+
+        const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+
+        buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
+
+mulmat_load_q3_K = """
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+        const uint ib = idx / 128; // 2 values per idx
+        const uint iqs = idx % 128; // 0..127
+
+        const uint n = iqs / 64; // 0,1
+        const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+        const uint hmi = (iqs % 16) * 2; // 0,2,4..30
+        const uint j = (iqs % 64) / 4; // 0..3
+        const uint is = iqs / 8; // 0..15
+        const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
+        const uint qsshift = halfsplit * 2; // 0,2,4,6
+        const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
+
+        const int8_t us = int8_t(is <  4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
+                                 is <  8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
+                                 is < 12 ? (data_a[ib].scales[is-8] >>  4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
+                                           (data_a[ib].scales[is-8] >>  4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
+        const float dl = float(data_a[ib].d) * float(us - 32);
+
+        buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
+        buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
+
+mulmat_load_q4_K = """
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+        const uint ib = idx / 128; // 2 values per idx
+        const uint iqs = idx % 128; // 0..127
+
+        const uint n = iqs / 32; // 0,1,2,3
+        const uint b = (iqs % 32) / 16; // 0,1
+        const uint is = 2 * n + b; // 0..7
+        const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
+
+        const vec2 loadd = vec2(data_a[ib].d);
+
+        uint8_t sc;
+        uint8_t mbyte;
+        if (is < 4) {
+            sc    = uint8_t(data_a[ib].scales[is ] & 63);
+            mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
+        } else {
+            sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
+            mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
+        }
+        const float d = loadd.x * sc;
+        const float m = loadd.y * mbyte;
+
+        buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);"""
+
+mulmat_load_q5_K = """
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+        const uint ib = idx / 128; // 2 values per idx
+        const uint iqs = idx % 128; // 0..127
+
+        const uint n = iqs / 32; // 0,1,2,3
+        const uint b = (iqs % 32) / 16; // 0,1
+        const uint is = 2 * n + b; // 0..7
+        const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
+        const uint qhi = (iqs % 16) * 2; // 0,2,4..30
+
+        const uint8_t hm = uint8_t(1 << (iqs / 16));
+
+        const vec2 loadd = vec2(data_a[ib].d);
+
+        uint8_t sc;
+        uint8_t mbyte;
+        if (is < 4) {
+            sc    = uint8_t(data_a[ib].scales[is ] & 63);
+            mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
+        } else {
+            sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
+            mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
+        }
+        const float d = loadd.x * sc;
+        const float m = loadd.y * mbyte;
+
+        buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);"""
+
+mulmat_load_q6_K = """
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+        const uint ib = idx / 128; // 2 values per idx
+        const uint iqs = idx % 128; // 0..127
+
+        const uint n = iqs / 64; // 0,1
+        const uint b = (iqs % 64) / 32; // 0,1
+        const uint is_b = (iqs % 16) / 8; // 0,1
+        const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+        const uint is = 8 * n + qhshift + is_b; // 0..15
+        const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
+        const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+
+        const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
+
+        buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
+        buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));"""
+
 
 mulmat_body2 = """
     }
     [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
@@ -1611,8 +1744,9 @@ layout (push_constant) uniform parameter
     uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
     uint d_offset;
     float param1; float param2;
-} p;
+} p;"""
+generic_unary_op_funcs = """
 
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -1636,14 +1770,17 @@ uint dst_idx(uint idx) {
     const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
     const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
     return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
-}
+}"""
 
+generic_unary_op_main = """
 void main() {
     if (gl_GlobalInvocationID.x >= p.ne) {
         return;
     }
 """
 
+generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_funcs}\n{generic_unary_op_main}"
+
 generic_binary_op_head = """#version 450
 #extension GL_EXT_shader_16bit_storage : require
@@ -1655,13 +1792,14 @@ layout (push_constant) uniform parameter
     uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
     uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
     uint d_offset;
-    uint param1; uint param2;
-} p;
+    float param1; float param2;
+} p;"""
+generic_binary_op_funcs = """
 
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 uint src0_idx(uint idx) {
@@ -1693,14 +1831,17 @@ uint dst_idx(uint idx) {
     const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
     const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
     return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
-}
+}"""
 
+generic_binary_op_main = """
 void main() {
     if (gl_GlobalInvocationID.x >= p.ne) {
         return;
    }
 """
 
+generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
+
 # MUL F32
 mul_body = """
     data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
@@ -1745,39 +1886,55 @@ cpy_f16_f16_end = """
 """
 
 # GET_ROWS
-get_rows_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_8bit_storage : require
+get_rows_float_body = """
+void main() {
+    const uint i00 = gl_GlobalInvocationID.x;
+    const uint i10 = gl_GlobalInvocationID.y;
+    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
+    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+    if (i00 >= p.ne00) {
+        return;
+    }
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+
+    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
 
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+    data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
+#else
+    data_d[d_offset + i00] = data_a[a_offset + i00];
+#endif
+}
+"""
+
+get_rows_body = """
 void main() {
-    const uint col = int(gl_GlobalInvocationID.x) * 2;
-    const uint row = int(gl_GlobalInvocationID.y);
+    const uint i00 = (gl_GlobalInvocationID.x)*2;
+    const uint i10 = gl_GlobalInvocationID.y;
+    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
+    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
 
-    if (col >= p.KY) {
+    if (i00 >= p.ne00) {
         return;
     }
 
-    const uint r = uint(data_b[row]);
+    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
 
-    // copy data_a[r*p.KY + col] to dst[row*p.KX + col]
-    const uint xi = r*p.KY + col;
-    const uint di = row*p.KY + col;
+    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
 
-    const uint ib = xi/QUANT_K; // block index
-    const uint iqs = (xi%QUANT_K)/QUANT_R; // quant index
-    const uint iybs = di - di%QUANT_K; // y block start index
+    const uint ib = a_offset + i00/QUANT_K; // block index
+    const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
+    const uint iybs = i00 - i00%QUANT_K; // dst block start index
     const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
     DEQUANT_FUNC
 
-    dst[iybs + iqs + 0] = D_TYPE(v.x);
-    dst[iybs + iqs + y_offset] = D_TYPE(v.y);
+    data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
+    data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
 }
 """
@@ -2418,6 +2575,31 @@ async def main():
         tasks.append(string_to_spv("matmul_q8_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
         tasks.append(string_to_spv("matmul_q8_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
 
+        stream.clear()
+        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2))
+        tasks.append(string_to_spv("matmul_q2_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q2_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+        stream.clear()
+        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2))
+        tasks.append(string_to_spv("matmul_q3_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q3_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+        stream.clear()
+        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2))
+        tasks.append(string_to_spv("matmul_q4_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q4_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+        stream.clear()
+        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2))
+        tasks.append(string_to_spv("matmul_q5_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q5_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+        stream.clear()
+        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2))
+        tasks.append(string_to_spv("matmul_q6_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q6_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
     # Shaders where precision is needed, so no fp16 version
 
     # mul mat vec
@@ -2426,7 +2608,7 @@ async def main():
         stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
 
         if i == GGML_TYPE_F16:
-            stream.extend((shader_f16_defines, shader_f16_dequant_func, mul_mat_vec_body))
+            stream.extend((shader_f16_defines, shader_float_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q4_0:
             stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body))
         elif i == GGML_TYPE_Q4_1:
@@ -2488,25 +2670,32 @@ async def main():
     # get_rows
     for i in range(0, VK_NUM_TYPES):
         stream.clear()
-        stream.extend((generic_head, shader_int8_ext, shader_f32))
+        stream.extend((generic_binary_op_head, shader_int8_ext, shader_f32))
+        optimization_workaround = False
 
-        if i == GGML_TYPE_F16:
-            stream.extend((shader_f16_defines, shader_f16_dequant_func, get_rows_body))
+        if i == GGML_TYPE_F32:
+            stream.extend((shader_f32_defines, generic_binary_op_funcs, get_rows_float_body))
+        elif i == GGML_TYPE_F16:
+            stream.extend((shader_f16_defines, generic_binary_op_funcs, get_rows_float_body))
+            optimization_workaround = True
         elif i == GGML_TYPE_Q4_0:
-            stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, get_rows_body))
+            stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body))
         elif i == GGML_TYPE_Q4_1:
-            stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, get_rows_body))
+            stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body))
         elif i == GGML_TYPE_Q5_0:
-            stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, get_rows_body))
+            stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body))
         elif i == GGML_TYPE_Q5_1:
-            stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, get_rows_body))
+            stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body))
         elif i == GGML_TYPE_Q8_0:
-            stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, get_rows_body))
+            stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body))
         else:
            continue
 
-        tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}))
-        tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}))
+        if optimization_workaround:
+            tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
+        else:
+            tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t"}))
+        tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float"}))
 
     tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
     tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
@@ -2515,20 +2704,20 @@ async def main():
     tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
     tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
 
-    tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_head}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_head}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
-    tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_head}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
+    tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
+    tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
+    tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_combined}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
 
-    tasks.append(string_to_spv("add_f32", f"{generic_binary_op_head}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
+    tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
 
     tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
 
-    tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_head}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
+    tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
 
-    tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_head}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
+    tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
 
-    tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_head}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
+    tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
 
-    tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_head}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
+    tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_combined}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
 
     tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))