Diffstat (limited to 'ggml-cuda.cu'):
 ggml-cuda.cu | 75
 1 file changed, 56 insertions(+), 19 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c0fb9fb6..8ab29bb2 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3887,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {
- const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+ const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
return;
}
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col;
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3965,8 +3965,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
}
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
- const int col = blockDim.x*blockIdx.x + threadIdx.x;
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
+ const int col = blockDim.y*blockIdx.y + threadIdx.y;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
if (col >= ncols) {
return;
@@ -3982,9 +3982,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
// values are also not normalized to the maximum value by subtracting it in the exponential function
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
- const int block_size = blockDim.x;
- const int tid = threadIdx.x;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+ const int block_size = blockDim.y;
+ const int tid = threadIdx.y;
float tmp = 0.0;
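
Note on the comment above: the kernel deliberately skips the usual max-subtraction trick when exponentiating. As a point of reference (not part of the patch), the host-side sketch below contrasts the "plain" form the comment describes with the numerically safer variant that subtracts the row maximum first; all names in it are illustrative.

#include <math.h>

// Illustrative only: plain softmax, as described by the comment above.
// expf(x[i]) can overflow to inf for large inputs; for LLaMa it reportedly stays in range.
static void softmax_plain(const float * x, float * dst, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        dst[i] = expf(x[i]);
        sum += dst[i];
    }
    for (int i = 0; i < n; i++) {
        dst[i] /= sum;
    }
}

// Illustrative only: the conventional stable form, which subtracts the maximum
// so every exponent is <= 0, at the cost of one extra pass over the row.
static void softmax_stable(const float * x, float * dst, int n) {
    float max = -INFINITY;
    for (int i = 0; i < n; i++) {
        max = fmaxf(max, x[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        dst[i] = expf(x[i] - max);
        sum += dst[i];
    }
    for (int i = 0; i < n; i++) {
        dst[i] /= sum;
    }
}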
@@ -4776,9 +4776,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 2 == 0);
- const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+ const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
- const dim3 block_nums(num_blocks_x, nrows, 1);
+ const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
@@ -4800,15 +4800,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
- const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+ const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
- const dim3 block_nums(block_num_x, nrows_x, 1);
+ const dim3 block_nums(nrows_x, block_num_x, 1);
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- const dim3 block_nums(1, nrows_x, 1);
+ const dim3 block_dims(1, WARP_SIZE, 1);
+ const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
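
Taken together, the kernel and launch-configuration hunks above swap the roles of the x and y dimensions: rows are now indexed through blockIdx.x/threadIdx.x and columns through the y dimension of both block and grid. The diff does not state the motivation, but a plausible one is CUDA's grid limits: gridDim.x may hold up to 2^31-1 blocks, while gridDim.y and gridDim.z are capped at 65535, so keeping the potentially large row count in x avoids launch failures for tall tensors. The sketch below mirrors the new convention for a generic row-wise kernel; the kernel name and block size are illustrative, not part of the patch.

// Illustrative only: a generic row-wise kernel using the same x<->y convention
// as the patched rope_f32 / diag_mask_inf_f32 / soft_max_f32 kernels above.
#define SKETCH_BLOCK_COLS 256 // hypothetical column block size for this sketch

static __global__ void rowwise_sketch(const float * x, float * dst, const int ncols) {
    const int row = blockDim.x*blockIdx.x + threadIdx.x; // blockDim.x == 1, so row == blockIdx.x
    const int col = blockDim.y*blockIdx.y + threadIdx.y; // columns are tiled along y
    if (col >= ncols) {
        return;
    }
    const int i = row*ncols + col;
    dst[i] = x[i]; // placeholder element-wise work
}

static void rowwise_sketch_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(1, SKETCH_BLOCK_COLS, 1);                   // threads cover columns
    const int  num_blocks_y = (ncols + SKETCH_BLOCK_COLS - 1) / SKETCH_BLOCK_COLS;
    const dim3 block_nums(nrows, num_blocks_y, 1);                    // rows go in grid x, which allows up to 2^31-1 blocks
    rowwise_sketch<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}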
@@ -6313,7 +6313,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
return extra;
}
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
if (scratch && g_scratch_size == 0) {
return;
}
@@ -6322,14 +6322,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src[0]->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
- ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+ ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
}
}
if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
- ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+ ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
}
tensor->backend = GGML_BACKEND_GPU;
+
+ if (scratch && no_alloc) {
+ return;
+ }
+
struct ggml_tensor_extra_gpu * extra;
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6386,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
tensor->extra = extra;
}
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+ if (g_scratch_size == 0) {
+ return;
+ }
+ if (g_scratch_buffer == nullptr) {
+ CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+ }
+
+ struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+ const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+ tensor->op == GGML_OP_VIEW;
+
+ if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+ char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+ size_t view_offset = 0;
+ if (tensor->op == GGML_OP_VIEW) {
+ memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+ }
+ extra->data_device[g_main_device] = src0_ddc + view_offset;
+ } else {
+ extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+ }
+
+ tensor->extra = extra;
+}
+
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
- ggml_cuda_assign_buffers_impl(tensor, true, false);
+ ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+ ggml_cuda_assign_buffers_impl(tensor, true, false, true);
}
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
- ggml_cuda_assign_buffers_impl(tensor, false, false);
+ ggml_cuda_assign_buffers_impl(tensor, false, false, false);
}
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
- ggml_cuda_assign_buffers_impl(tensor, false, true);
+ ggml_cuda_assign_buffers_impl(tensor, false, true, false);
}
void ggml_cuda_set_main_device(int main_device) {
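
The second half of the patch splits scratch-buffer assignment into two steps: ggml_cuda_assign_buffers_no_alloc() marks a tensor as GPU-resident but, thanks to the new no_alloc early return, reserves no device memory for it, while ggml_cuda_assign_scratch_offset() later attaches the tensor to the shared scratch buffer at a caller-chosen offset (or reuses the source tensor's pointer for views and other in-place ops). Below is a sketch of how a caller might combine the two, assuming the offset comes from some external graph allocator; the two function names are real, everything around them is hypothetical.

// Hypothetical caller-side sketch (not from this patch): defer the device
// allocation until an external allocator has decided where the tensor lives.
struct ggml_tensor * cur = /* some node of the graph being built */ NULL;

// Step 1, at graph-build time: mark the tensor as GPU-resident only.
// With no_alloc == true the implementation returns right after setting
// cur->backend = GGML_BACKEND_GPU, so no scratch memory is reserved yet.
ggml_cuda_assign_buffers_no_alloc(cur);

// Step 2, once an offset into the scratch buffer is known: attach the tensor
// to g_scratch_buffer + offset (views instead reuse their source's pointer).
const size_t offset = 0; // placeholder; a real caller would use the allocator's result
ggml_cuda_assign_scratch_offset(cur, offset);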