diff options
Diffstat (limited to 'kompute-shaders')
-rw-r--r-- | kompute-shaders/op_rope_f16.comp | 2 | ||||
-rw-r--r-- | kompute-shaders/op_rope_f32.comp | 2 | ||||
-rw-r--r-- | kompute-shaders/rope_common.comp | 12 |
3 files changed, 8 insertions, 8 deletions
diff --git a/kompute-shaders/op_rope_f16.comp b/kompute-shaders/op_rope_f16.comp
index b4462258..1a4058b3 100644
--- a/kompute-shaders/op_rope_f16.comp
+++ b/kompute-shaders/op_rope_f16.comp
@@ -14,7 +14,7 @@ void main() {
     const bool is_neox = (pcs.mode & 2) != 0;
 
     float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 
     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 
diff --git a/kompute-shaders/op_rope_f32.comp b/kompute-shaders/op_rope_f32.comp
index 2c0235d7..65e03827 100644
--- a/kompute-shaders/op_rope_f32.comp
+++ b/kompute-shaders/op_rope_f32.comp
@@ -14,7 +14,7 @@ void main() {
     const bool is_neox = (pcs.mode & 2) != 0;
 
     float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 
     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 
diff --git a/kompute-shaders/rope_common.comp b/kompute-shaders/rope_common.comp
index 57ba6597..7b9394cb 100644
--- a/kompute-shaders/rope_common.comp
+++ b/kompute-shaders/rope_common.comp
@@ -9,7 +9,7 @@ layout (push_constant) uniform parameter {
     uint outOff;
     int n_dims;
     int mode;
-    int n_orig_ctx;
+    int n_ctx_orig;
     float freq_base;
     float freq_scale;
     float ext_factor;
@@ -54,14 +54,14 @@ void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
 }
 
 void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }