summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r--ggml/src/ggml-cuda.cu44
1 files changed, 36 insertions, 8 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index e0035c7a..7fb67738 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2186,6 +2186,7 @@ struct mmid_row_mapping {
int32_t i2;
};
+template <typename data_t = float>
static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
const mmid_row_mapping * __restrict__ row_mapping,
int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
@@ -2194,8 +2195,8 @@ static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_or
const int32_t i11 = row_mapping[i].i1 % ne11;
const int32_t i12 = row_mapping[i].i2;
- float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
- const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
+ data_t * src_row_contiguous = (data_t *)(src_contiguous + i*nb11);
+ const data_t * src_row_original = (const data_t *)(src_original + i11*nb11 + i12*nb12);
for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
src_row_contiguous[j] = src_row_original[j];
@@ -2673,6 +2674,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
}
} else {
+ //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
+ ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+ bool use_quantized_src1 = false;
+ int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
+ if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+ src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+ src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
+ src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+ src1_quantized.alloc(src1_quantized_size);
+ use_quantized_src1 = true;
+ }
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@@ -2704,7 +2716,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
if (num_src1_rows == 0) continue;
size_t mapping_offset = cum_moe_counts[i02];
- {
+ if (use_quantized_src1) {
+ quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset),
+ src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+ src1_row.data = src1_quantized.get();
+ }
+ else {
dim3 block_dims(std::min((unsigned int)ne10, 768u));
dim3 grid_dims(num_src1_rows);
k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
@@ -2719,9 +2737,9 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
- src1_row.nb[1] = nb11;
- src1_row.nb[2] = num_src1_rows*nb11;
- src1_row.nb[3] = num_src1_rows*nb11;
+ src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
+ src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
+ src1_row.nb[3] = num_src1_rows*src1_row.nb[1];
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
@@ -2729,11 +2747,21 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
dst_row.nb[3] = num_src1_rows*nb1;
dst_row.data = dst_up_contiguous.get();
- ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+ if (use_quantized_src1) {
+ ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+ 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+ } else {
+ ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+ }
CUDA_CHECK(cudaGetLastError());
dst_row.data = dst_gate_contiguous.get();
- ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+ if (use_quantized_src1) {
+ ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+ 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+ } else {
+ ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+ }
CUDA_CHECK(cudaGetLastError());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),