summary | refs | log | tree | commit | diff
path: root/kompute-shaders
diff options
context:
space:
mode:
Diffstat (limited to 'kompute-shaders')
-rw-r--r-- kompute-shaders/op_rope_f16.comp | 2
-rw-r--r-- kompute-shaders/op_rope_f32.comp | 2
-rw-r--r-- kompute-shaders/rope_common.comp | 12
3 files changed, 8 insertions, 8 deletions
diff --git a/kompute-shaders/op_rope_f16.comp b/kompute-shaders/op_rope_f16.comp
index b4462258..1a4058b3 100644
--- a/kompute-shaders/op_rope_f16.comp
+++ b/kompute-shaders/op_rope_f16.comp
@@ -14,7 +14,7 @@ void main() {
const bool is_neox = (pcs.mode & 2) != 0;
float corr_dims[2];
- rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+ rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
diff --git a/kompute-shaders/op_rope_f32.comp b/kompute-shaders/op_rope_f32.comp
index 2c0235d7..65e03827 100644
--- a/kompute-shaders/op_rope_f32.comp
+++ b/kompute-shaders/op_rope_f32.comp
@@ -14,7 +14,7 @@ void main() {
const bool is_neox = (pcs.mode & 2) != 0;
float corr_dims[2];
- rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+ rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
diff --git a/kompute-shaders/rope_common.comp b/kompute-shaders/rope_common.comp
index 57ba6597..7b9394cb 100644
--- a/kompute-shaders/rope_common.comp
+++ b/kompute-shaders/rope_common.comp
@@ -9,7 +9,7 @@ layout (push_constant) uniform parameter {
uint outOff;
int n_dims;
int mode;
- int n_orig_ctx;
+ int n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
@@ -54,14 +54,14 @@ void rope_yarn(
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
- return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
}
void rope_yarn_corr_dims(
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
// start and end correction dims
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}