Diffstat (limited to 'ggml-cuda/softmax.cu')
 ggml-cuda/softmax.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu
index 9bda18e5..fa8f987c 100644
--- a/ggml-cuda/softmax.cu
+++ b/ggml-cuda/softmax.cu
@@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
- float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
+ float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
float max_val = -INFINITY;
@@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
break;
}
- const int ix = rowx*ncols + col;
- const int iy = rowy*ncols + col;
+ const int64_t ix = (int64_t)rowx*ncols + col;
+ const int64_t iy = (int64_t)rowy*ncols + col;
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
@@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
return;
}
- const int idst = rowx*ncols + col;
+ const int64_t idst = (int64_t)rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
}
}
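The change itself is small: the offsets rowx*ncols, rowy*ncols and the destination index were previously computed in 32-bit int arithmetic, which goes wrong once a tensor's element offset exceeds INT_MAX; casting one operand to int64_t promotes the whole expression to 64-bit before the multiply. A minimal host-side sketch (hypothetical values, not part of the patch) of the failure mode the cast avoids:

    // Sketch only: shows why (int64_t)rowx*ncols differs from rowx*ncols
    // when the product no longer fits in a 32-bit int. Signed overflow is
    // undefined behaviour in C++; in practice it typically wraps negative.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int rowx  = 70000;   // hypothetical row index
        const int ncols = 32768;   // hypothetical row length; 70000*32768 > INT_MAX

        const int     idx32 = rowx*ncols;            // 32-bit multiply: overflows
        const int64_t idx64 = (int64_t)rowx*ncols;   // 64-bit multiply: correct offset

        printf("32-bit index: %d\n",   idx32);
        printf("64-bit index: %lld\n", (long long)idx64);
        return 0;
    }

In the kernel, such a wrapped index would be used to address x, mask and dst, reading or writing out of bounds for sufficiently large rowx*ncols; keeping ix, iy and idst as int64_t sidesteps that.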