diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-07-07 07:23:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-07 07:23:12 +0200 |
commit | 4c0b66026619cf51f45249181bf2cc1de8cd6884 (patch) | |
tree | 93c1b5474296180dda5eaf302ffa4ff615e4d62f | |
parent | 6f3a3ba7e249cd689cb1ab0376e6504fb6cd49e7 (diff) |
CUDA: small PP performance improvement for MoE models (#589)
* Trying to implement quantized fmoe - not working yet
* This works, but is slower than the non-working version
* quantize_mmq_q8_1_id
* Minor
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda.cu | 44 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/quantize.cu | 121 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/quantize.cuh | 4 |
3 files changed, 161 insertions, 8 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e0035c7a..7fb67738 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2186,6 +2186,7 @@ struct mmid_row_mapping { int32_t i2; }; +template <typename data_t = float> static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous, const mmid_row_mapping * __restrict__ row_mapping, int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) { @@ -2194,8 +2195,8 @@ static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_or const int32_t i11 = row_mapping[i].i1 % ne11; const int32_t i12 = row_mapping[i].i2; - float * src_row_contiguous = (float *)(src_contiguous + i*nb11); - const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12); + data_t * src_row_contiguous = (data_t *)(src_contiguous + i*nb11); + const data_t * src_row_original = (const data_t *)(src_original + i11*nb11 + i12*nb12); for (int j = threadIdx.x; j < ne10; j += blockDim.x) { src_row_contiguous[j] = src_row_original[j]; @@ -2673,6 +2674,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } } } else { + //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]); + ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool()); + bool use_quantized_src1 = false; + int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; + if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { + src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); + src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + src1_quantized.alloc(src1_quantized_size); + use_quantized_src1 = true; + } ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); @@ -2704,7 +2716,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor if (num_src1_rows == 0) continue; size_t mapping_offset = cum_moe_counts[i02]; - { + if (use_quantized_src1) { + quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), + src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_quantized.get(); + } + else { dim3 block_dims(std::min((unsigned int)ne10, 768u)); dim3 grid_dims(num_src1_rows); k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>( @@ -2719,9 +2737,9 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(nb1 == sizeof(float)*ne0); src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = nb11; - src1_row.nb[2] = num_src1_rows*nb11; - src1_row.nb[3] = num_src1_rows*nb11; + src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; + src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; + src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; @@ -2729,11 +2747,21 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dst_row.nb[3] = num_src1_rows*nb1; dst_row.data = dst_up_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + } CUDA_CHECK(cudaGetLastError()); dst_row.data = dst_gate_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + } CUDA_CHECK(cudaGetLastError()); ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 953eb9d9..52d8787d 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -166,6 +166,98 @@ static __global__ void quantize_mmq_q8_1( } } +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +template <mmq_q8_1_ds_layout ds_layout> +static __global__ void quantize_mmq_q8_1_id( + const float * __restrict__ x, void * __restrict__ vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + + if (ix0 >= kx0_padded) { + return; + } + + const float4 * x4 = (const float4 *) x; + + const mmid_row_mapping * mapping = (const mmid_row_mapping *)row_mapping; + const int64_t ii = mapping[blockIdx.y].i2; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = ix0 < kx0 ? x4[(ii*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); + + // Exchange max. abs. value between vals_per_scale/4 threads. +#pragma unroll + for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + } + + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; + + // Exchange calculate sum across vals_per_sum/4 threads. +#pragma unroll + for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); + } + } + + const float d = amax/127.f; + const float d_inv = d > 0 ? 1/d : 0.f; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + y[ib].d2s6[iqs/64] = d; + + return; + } + + if (iqs % 32 != 0) { + return; + } + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); + } else { + y[ib].d4[iqs/32] = d; + } +} + void quantize_row_q8_1_cuda( const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { @@ -208,6 +300,35 @@ void quantize_mmq_q8_1_cuda( } } +void quantize_mmq_q8_1_id_cuda( + const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); + + const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(block_num_x, kx1, 1); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_x)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D4> + <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_DS4> + <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D2S6> + <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + + void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) { GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1); GGML_ASSERT(src->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 0be5bf0e..e1106164 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -30,5 +30,9 @@ void quantize_mmq_q8_1_cuda( const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream); +void quantize_mmq_q8_1_id_cuda( + const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + // For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1 void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream); |