diff options
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- | ggml_vk_generate_shaders.py | 66 |
1 files changed, 37 insertions, 29 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index a905f570..400a63f5 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -2400,7 +2400,7 @@ void main() { """ # ROPE -rope_src = """ +rope_norm_src = """ #version 450 #extension GL_EXT_shader_16bit_storage : require @@ -2408,17 +2408,21 @@ rope_src = """ layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 1) readonly buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint n_dims; float freq_scale; uint p_delta_rows; float freq_base; float ext_factor; float attn_factor; - float corr_dims[4]; + float corr_dims[2]; + float theta_scale; + uint has_ff; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { @@ -2450,14 +2454,24 @@ void main() { return; } + if (col >= p.n_dims) { + const uint i = row*p.ncols + col; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + const uint i = row*p.ncols + col; const uint i2 = row/p.p_delta_rows; - const int pos = data_b[i2]; - const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols); + const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); const float x0 = float(data_a[i + 0]); const float x1 = float(data_a[i + 1]); @@ -2475,22 +2489,21 @@ rope_neox_src = """ layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_b[];}; -layout (binding = 2) readonly buffer Z {float data_freq_factors[];}; +layout (binding = 1) readonly buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; layout (push_constant) uniform parameter { uint ncols; - uint ndims; + uint n_dims; float freq_scale; uint p_delta_rows; float freq_base; float ext_factor; float attn_factor; - float corr_dims[4]; + float corr_dims[2]; float theta_scale; - float inv_ndims; - uint has_freq_facs; + uint has_ff; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { @@ -2522,11 +2535,8 @@ void main() { return; } - const uint ib = col / p.ndims; - const uint ic = col % p.ndims; - - if (ib > 0) { - const uint i = row*p.ncols + ib*p.ndims + ic; + if (col >= p.n_dims) { + const uint i = row*p.ncols + col; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -2534,29 +2544,27 @@ void main() { return; } - const uint i = row*p.ncols + ib*p.ndims + ic/2; + const uint i = row*p.ncols + col/2; const uint i2 = row/p.p_delta_rows; - const int pos = data_b[i2]; - const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f; - const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor; + const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base, ic, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + p.ndims/2]); + const float x1 = float(data_a[i + p.n_dims/2]); data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); } """ argsort_src = """ #version 450 -#extension GL_EXT_shader_16bit_storage : require - #define BLOCK_SIZE 1024 #define ASC 0 @@ -3039,8 +3047,8 @@ async def main(): tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"})) - tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) + tasks.append(string_to_spv("rope_norm_f32", rope_norm_src, {"A_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("rope_norm_f16", rope_norm_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) |