summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcebtenzzre <cebtenzzre@gmail.com>2023-11-02 01:49:44 -0400
committerGitHub <noreply@github.com>2023-11-02 07:49:44 +0200
commit2fffa0d61fa10e4b466e78cabcc6a4e16717b580 (patch)
tree1f9f1445335cdbb548fcd7569035f43e3eb68cd4
parent0eb332a10f3f14a3746c391bf80ff5e7bdf29d5d (diff)
cuda : fix RoPE after #2268 (#3897)
-rw-r--r--ggml-cuda.cu6
1 files changed, 3 insertions, 3 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 61cd1747..57a528ed 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4539,7 +4539,7 @@ static __global__ void rope(
const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0;
- const float theta_base = p*powf(freq_base, -col/ncols);
+ const float theta_base = p*powf(freq_base, -float(col)/ncols);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -4566,8 +4566,8 @@ static __global__ void rope_neox(
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;
- // simplified from `(row * ncols + col) * (-1 / ncols)`
- const float cur_rot = -col/ncols - row;
+ // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
+ const float cur_rot = -float(col)/ncols;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(freq_base, cur_rot);