summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-07-07 07:23:12 +0200
committerGitHub <noreply@github.com>2025-07-07 07:23:12 +0200
commit4c0b66026619cf51f45249181bf2cc1de8cd6884 (patch)
tree93c1b5474296180dda5eaf302ffa4ff615e4d62f
parent6f3a3ba7e249cd689cb1ab0376e6504fb6cd49e7 (diff)
CUDA: small PP performance improvement for MoE models (#589)
* Trying to implement quantized fmoe - not working yet * This works, but is slower than the non-working version * quantize_mmq_q8_1_id * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-cuda.cu44
-rw-r--r--ggml/src/ggml-cuda/quantize.cu121
-rw-r--r--ggml/src/ggml-cuda/quantize.cuh4
3 files changed, 161 insertions, 8 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index e0035c7a..7fb67738 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2186,6 +2186,7 @@ struct mmid_row_mapping {
int32_t i2;
};
+template <typename data_t = float>
static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
const mmid_row_mapping * __restrict__ row_mapping,
int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
@@ -2194,8 +2195,8 @@ static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_or
const int32_t i11 = row_mapping[i].i1 % ne11;
const int32_t i12 = row_mapping[i].i2;
- float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
- const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
+ data_t * src_row_contiguous = (data_t *)(src_contiguous + i*nb11);
+ const data_t * src_row_original = (const data_t *)(src_original + i11*nb11 + i12*nb12);
for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
src_row_contiguous[j] = src_row_original[j];
@@ -2673,6 +2674,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
}
} else {
+ //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
+ ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+ bool use_quantized_src1 = false;
+ int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
+ if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+ src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+ src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
+ src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+ src1_quantized.alloc(src1_quantized_size);
+ use_quantized_src1 = true;
+ }
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@@ -2704,7 +2716,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
if (num_src1_rows == 0) continue;
size_t mapping_offset = cum_moe_counts[i02];
- {
+ if (use_quantized_src1) {
+ quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset),
+ src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+ src1_row.data = src1_quantized.get();
+ }
+ else {
dim3 block_dims(std::min((unsigned int)ne10, 768u));
dim3 grid_dims(num_src1_rows);
k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
@@ -2719,9 +2737,9 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
- src1_row.nb[1] = nb11;
- src1_row.nb[2] = num_src1_rows*nb11;
- src1_row.nb[3] = num_src1_rows*nb11;
+ src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
+ src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
+ src1_row.nb[3] = num_src1_rows*src1_row.nb[1];
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
@@ -2729,11 +2747,21 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
dst_row.nb[3] = num_src1_rows*nb1;
dst_row.data = dst_up_contiguous.get();
- ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+ if (use_quantized_src1) {
+ ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+ 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+ } else {
+ ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+ }
CUDA_CHECK(cudaGetLastError());
dst_row.data = dst_gate_contiguous.get();
- ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+ if (use_quantized_src1) {
+ ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+ 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+ } else {
+ ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+ }
CUDA_CHECK(cudaGetLastError());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index 953eb9d9..52d8787d 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -166,6 +166,98 @@ static __global__ void quantize_mmq_q8_1(
}
}
+struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+};
+
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1_id(
+ const float * __restrict__ x, void * __restrict__ vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+ constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+ constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+ const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+
+ if (ix0 >= kx0_padded) {
+ return;
+ }
+
+ const float4 * x4 = (const float4 *) x;
+
+ const mmid_row_mapping * mapping = (const mmid_row_mapping *)row_mapping;
+ const int64_t ii = mapping[blockIdx.y].i2;
+
+ block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+ const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+ const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
+ const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
+
+ // Load 4 floats per thread and calculate max. abs. value between them:
+ const float4 xi = ix0 < kx0 ? x4[(ii*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+ float amax = fabsf(xi.x);
+ amax = fmaxf(amax, fabsf(xi.y));
+ amax = fmaxf(amax, fabsf(xi.z));
+ amax = fmaxf(amax, fabsf(xi.w));
+
+ // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+ for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+ }
+
+ float sum;
+ if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+ sum = xi.x + xi.y + xi.z + xi.w;
+
+ // Exchange calculate sum across vals_per_sum/4 threads.
+#pragma unroll
+ for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+ }
+ }
+
+ const float d = amax/127.f;
+ const float d_inv = d > 0 ? 1/d : 0.f;
+ char4 q;
+ q.x = roundf(xi.x*d_inv);
+ q.y = roundf(xi.y*d_inv);
+ q.z = roundf(xi.z*d_inv);
+ q.w = roundf(xi.w*d_inv);
+
+ // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+ char4 * yqs4 = (char4 *) y[ib].qs;
+ yqs4[iqs/4] = q;
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+ if (iqs % 16 != 0 || iqs >= 96) {
+ return;
+ }
+
+ y[ib].d2s6[2 + iqs/16] = sum;
+
+ if (iqs % 64 != 0) {
+ return;
+ }
+
+ y[ib].d2s6[iqs/64] = d;
+
+ return;
+ }
+
+ if (iqs % 32 != 0) {
+ return;
+ }
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+ y[ib].ds4[iqs/32] = make_half2(d, sum);
+ } else {
+ y[ib].d4[iqs/32] = d;
+ }
+}
+
void quantize_row_q8_1_cuda(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
@@ -208,6 +300,35 @@ void quantize_mmq_q8_1_cuda(
}
}
+void quantize_mmq_q8_1_id_cuda(
+ const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream) {
+
+ GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+ const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+ const dim3 num_blocks(block_num_x, kx1, 1);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+ switch (mmq_get_q8_1_ds_layout(type_x)) {
+ case MMQ_Q8_1_DS_LAYOUT_D4:
+ quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D4>
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
+ break;
+ case MMQ_Q8_1_DS_LAYOUT_DS4:
+ quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_DS4>
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
+ break;
+ case MMQ_Q8_1_DS_LAYOUT_D2S6:
+ quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D2S6>
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+
void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) {
GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1);
GGML_ASSERT(src->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh
index 0be5bf0e..e1106164 100644
--- a/ggml/src/ggml-cuda/quantize.cuh
+++ b/ggml/src/ggml-cuda/quantize.cuh
@@ -30,5 +30,9 @@ void quantize_mmq_q8_1_cuda(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
const ggml_type type_x, cudaStream_t stream);
+void quantize_mmq_q8_1_id_cuda(
+ const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
+
// For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1
void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream);