diff options
Diffstat (limited to 'kompute-shaders/rope_common.comp')
-rw-r--r-- | kompute-shaders/rope_common.comp | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/kompute-shaders/rope_common.comp b/kompute-shaders/rope_common.comp index 57ba6597..7b9394cb 100644 --- a/kompute-shaders/rope_common.comp +++ b/kompute-shaders/rope_common.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter { uint outOff; int n_dims; int mode; - int n_orig_ctx; + int n_ctx_orig; float freq_base; float freq_scale; float ext_factor; @@ -54,14 +54,14 @@ void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base)); +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base)); } void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2] ) { // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); } |