summaryrefslogtreecommitdiff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r--ggml_vk_generate_shaders.py213
1 files changed, 88 insertions, 125 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index 67981a75..4abb0383 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -157,19 +157,10 @@ struct block_q6_K
# Dequant functions
shader_f16_dequant_func = """
-#define DEQUANT_FUNC f16vec2 v = f16vec2(data_a[ib + 0], data_a[ib + 1]);
-"""
-shader_f16_dequant_func_compat = """
#define DEQUANT_FUNC vec2 v = vec2(data_a[ib + 0], data_a[ib + 1]);
"""
shader_q4_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \
-v = (v - 8.0hf)*d;
-"""
-shader_q4_0_dequant_func_compat = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
const uint vui = uint(data_a[ib].qs[iqs]); \
vec2 v = vec2(vui & 0xF, vui >> 4); \
@@ -177,13 +168,6 @@ v = (v - 8.0f)*d;
"""
shader_q4_1_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const float16_t m = data_a[ib].m; \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \
-v = v*d + m;
-"""
-shader_q4_1_dequant_func_compat = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
const float m = float(data_a[ib].m); \
const uint vui = uint(data_a[ib].qs[iqs]); \
@@ -192,14 +176,6 @@ v = v*d + m;
"""
shader_q5_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
-const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = (v - 16.0hf) * d;
-"""
-shader_q5_0_dequant_func_compat = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
@@ -209,14 +185,6 @@ v = (v - 16.0f) * d;
"""
shader_q5_1_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-const float16_t m = data_a[ib].m; \
-const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
-const uint8_t vui = data_a[ib].qs[iqs]; \
-f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = v*d + m;
-"""
-shader_q5_1_dequant_func_compat = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
const float m = float(data_a[ib].m); \
const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
@@ -226,11 +194,6 @@ v = v*d + m;
"""
shader_q8_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \
-f16vec2 v = f16vec2(data_a[ib].qs[iqs], data_a[ib].qs[iqs + 1]); \
-v = v * d;
-"""
-shader_q8_0_dequant_func_compat = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \
v = v * d;
@@ -2110,7 +2073,7 @@ lock = asyncio.Lock()
shader_fnames = []
-async def string_to_spv(name, code, defines, fp16):
+async def string_to_spv(name, code, defines, fp16=True):
f = NamedTemporaryFile(mode="w", delete=False)
f.write(code)
f.flush()
@@ -2200,64 +2163,6 @@ async def main():
tasks.append(string_to_spv("matmul_f16_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
- # Build dequant shaders
- tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}, fp16))
-
- for i in range(0, VK_NUM_TYPES):
- stream.clear()
-
- stream.extend((dequant_head, shader_int8_ext, shader_float_type))
-
- if i == GGML_TYPE_F16:
- stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, dequant_body))
- elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat if not fp16 else shader_q4_0_dequant_func, dequant_body))
- elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat if not fp16 else shader_q4_1_dequant_func, dequant_body))
- elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat if not fp16 else shader_q5_0_dequant_func, dequant_body))
- elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat if not fp16 else shader_q5_1_dequant_func, dequant_body))
- elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat if not fp16 else shader_q8_0_dequant_func, dequant_body))
- elif i == GGML_TYPE_Q2_K:
- stream.extend((shader_q2_K_defines, dequant_q2_K_body))
- elif i == GGML_TYPE_Q3_K:
- stream.extend((shader_q3_K_defines, dequant_q3_K_body))
- elif i == GGML_TYPE_Q4_K:
- stream.extend((shader_q4_K_defines, dequant_q4_K_body))
- elif i == GGML_TYPE_Q5_K:
- stream.extend((shader_q5_K_defines, dequant_q5_K_body))
- elif i == GGML_TYPE_Q6_K:
- stream.extend((shader_q6_K_defines, dequant_q6_K_body))
- else:
- continue
-
- tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}, fp16))
-
- # get_rows
- for i in range(0, VK_NUM_TYPES):
- stream.clear()
- stream.extend((generic_head, shader_int8_ext, shader_float_type))
-
- if i == GGML_TYPE_F16:
- stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, get_rows_body))
- elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat if not fp16 else shader_q4_0_dequant_func, get_rows_body))
- elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat if not fp16 else shader_q4_1_dequant_func, get_rows_body))
- elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat if not fp16 else shader_q5_0_dequant_func, get_rows_body))
- elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat if not fp16 else shader_q5_1_dequant_func, get_rows_body))
- elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat if not fp16 else shader_q8_0_dequant_func, get_rows_body))
- else:
- continue
-
- tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}, fp16))
- tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}, fp16))
-
# Shaders where precision is needed, so no fp16 version
# mul mat vec
@@ -2266,17 +2171,17 @@ async def main():
stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
if i == GGML_TYPE_F16:
- stream.extend((shader_f16_defines, shader_f16_dequant_func_compat, mul_mat_vec_body))
+ stream.extend((shader_f16_defines, shader_f16_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat, mul_mat_vec_body))
+ stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat, mul_mat_vec_body))
+ stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat, mul_mat_vec_body))
+ stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat, mul_mat_vec_body))
+ stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat, mul_mat_vec_body))
+ stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q2_K:
stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
elif i == GGML_TYPE_Q3_K:
@@ -2290,43 +2195,101 @@ async def main():
else:
continue
- tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}, fp16))
+ tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
+
+ # Dequant shaders
+ for i in range(0, VK_NUM_TYPES):
+ stream.clear()
+
+ stream.extend((dequant_head, shader_int8_ext, shader_f32))
+
+ if i == GGML_TYPE_F16:
+ stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
+ elif i == GGML_TYPE_Q4_0:
+ stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body))
+ elif i == GGML_TYPE_Q4_1:
+ stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
+ elif i == GGML_TYPE_Q5_0:
+ stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body))
+ elif i == GGML_TYPE_Q5_1:
+ stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body))
+ elif i == GGML_TYPE_Q8_0:
+ stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body))
+ elif i == GGML_TYPE_Q2_K:
+ stream.extend((shader_q2_K_defines, dequant_q2_K_body))
+ elif i == GGML_TYPE_Q3_K:
+ stream.extend((shader_q3_K_defines, dequant_q3_K_body))
+ elif i == GGML_TYPE_Q4_K:
+ stream.extend((shader_q4_K_defines, dequant_q4_K_body))
+ elif i == GGML_TYPE_Q5_K:
+ stream.extend((shader_q5_K_defines, dequant_q5_K_body))
+ elif i == GGML_TYPE_Q6_K:
+ stream.extend((shader_q6_K_defines, dequant_q6_K_body))
+ else:
+ continue
+
+ tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
+
+ tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}))
+
+ # get_rows
+ for i in range(0, VK_NUM_TYPES):
+ stream.clear()
+ stream.extend((generic_head, shader_int8_ext, shader_f32))
+
+ if i == GGML_TYPE_F16:
+ stream.extend((shader_f16_defines, shader_f16_dequant_func, get_rows_body))
+ elif i == GGML_TYPE_Q4_0:
+ stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, get_rows_body))
+ elif i == GGML_TYPE_Q4_1:
+ stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, get_rows_body))
+ elif i == GGML_TYPE_Q5_0:
+ stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, get_rows_body))
+ elif i == GGML_TYPE_Q5_1:
+ stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, get_rows_body))
+ elif i == GGML_TYPE_Q8_0:
+ stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, get_rows_body))
+ else:
+ continue
+
+ tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}))
+ tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
# Norms
- tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}, True))
- tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}, True))
+ tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
+ tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
- tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}, True))
- tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
+ tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, True))
+ tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}, True))
+ tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
- tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}, True))
- tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}, True))
+ tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
await asyncio.gather(*tasks)