summaryrefslogtreecommitdiff
path: root/kompute-shaders/rope_common.comp
diff options
context:
space:
mode:
Diffstat (limited to 'kompute-shaders/rope_common.comp')
-rw-r--r--kompute-shaders/rope_common.comp12
1 files changed, 6 insertions, 6 deletions
diff --git a/kompute-shaders/rope_common.comp b/kompute-shaders/rope_common.comp
index 57ba6597..7b9394cb 100644
--- a/kompute-shaders/rope_common.comp
+++ b/kompute-shaders/rope_common.comp
@@ -9,7 +9,7 @@ layout (push_constant) uniform parameter {
uint outOff;
int n_dims;
int mode;
- int n_orig_ctx;
+ int n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
@@ -54,14 +54,14 @@ void rope_yarn(
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
- return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
}
void rope_yarn_corr_dims(
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
// start and end correction dims
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}