summaryrefslogtreecommitdiff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
author0cc4m <picard12@live.de>2024-06-11 21:20:29 +0200
committerGitHub <noreply@github.com>2024-06-11 21:20:29 +0200
commitef52d1d16afc695d798396cdd13594ea5e45a9dd (patch)
treed6ef7a296ce1df22eb803650da442a1c5ca55340 /ggml_vk_generate_shaders.py
parent14f83526cd27f638c856ea6eff08110b9860eb2a (diff)
Update Vulkan RoPE implementation (#7818)
* Update Vulkan RoPE implementation * Return nullptr on alloc_buffer when allocation fails, instead of throwing an exception Minor fixes * Fix segfault when running out of VRAM Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r--ggml_vk_generate_shaders.py66
1 files changed, 37 insertions, 29 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index a905f570..400a63f5 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -2400,7 +2400,7 @@ void main() {
"""
# ROPE
-rope_src = """
+rope_norm_src = """
#version 450
#extension GL_EXT_shader_16bit_storage : require
@@ -2408,17 +2408,21 @@ rope_src = """
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 1) readonly buffer Y {int data_pos[];};
+layout (binding = 2) readonly buffer Z {float data_ff[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
+ uint n_dims;
float freq_scale;
uint p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
- float corr_dims[4];
+ float corr_dims[2];
+ float theta_scale;
+ uint has_ff;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -2450,14 +2454,24 @@ void main() {
return;
}
+ if (col >= p.n_dims) {
+ const uint i = row*p.ncols + col;
+
+ data_d[i + 0] = data_a[i + 0];
+ data_d[i + 1] = data_a[i + 1];
+
+ return;
+ }
+
const uint i = row*p.ncols + col;
const uint i2 = row/p.p_delta_rows;
- const int pos = data_b[i2];
- const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols);
+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
float cos_theta, sin_theta;
- rope_yarn(theta_base, col, cos_theta, sin_theta);
+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + 1]);
@@ -2475,22 +2489,21 @@ rope_neox_src = """
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) readonly buffer Z {float data_freq_factors[];};
+layout (binding = 1) readonly buffer Y {int data_pos[];};
+layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
- uint ndims;
+ uint n_dims;
float freq_scale;
uint p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
- float corr_dims[4];
+ float corr_dims[2];
float theta_scale;
- float inv_ndims;
- uint has_freq_facs;
+ uint has_ff;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -2522,11 +2535,8 @@ void main() {
return;
}
- const uint ib = col / p.ndims;
- const uint ic = col % p.ndims;
-
- if (ib > 0) {
- const uint i = row*p.ncols + ib*p.ndims + ic;
+ if (col >= p.n_dims) {
+ const uint i = row*p.ncols + col;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@@ -2534,29 +2544,27 @@ void main() {
return;
}
- const uint i = row*p.ncols + ib*p.ndims + ic/2;
+ const uint i = row*p.ncols + col/2;
const uint i2 = row/p.p_delta_rows;
- const int pos = data_b[i2];
- const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
- const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
float cos_theta, sin_theta;
- rope_yarn(theta_base, ic, cos_theta, sin_theta);
+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + p.ndims/2]);
+ const float x1 = float(data_a[i + p.n_dims/2]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+ data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
"""
argsort_src = """
#version 450
-#extension GL_EXT_shader_16bit_storage : require
-
#define BLOCK_SIZE 1024
#define ASC 0
@@ -3039,8 +3047,8 @@ async def main():
tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"}))
- tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
+ tasks.append(string_to_spv("rope_norm_f32", rope_norm_src, {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("rope_norm_f16", rope_norm_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))