From e73ae1f6d31074f774741a592382ec62a9de6dbf Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 19 Jun 2024 19:51:39 +0300
Subject: bitnet(scale in a separate tensor): mul -> scale on CUDA

On CUDA we do not have access to the tensor data until we hit the kernel;
hence this hack. In any case, iq2_bn goes back up to 228 t/s, which is
close to the 234 t/s we have without the extra scale operation. PP is
9400 t/s, down from 9600 t/s, but better than the 9200 t/s we get without
making the mul -> scale replacement.
---
 ggml-cuda/binbcast.cu | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/ggml-cuda/binbcast.cu b/ggml-cuda/binbcast.cu
index 19b08b74..76cc01b2 100644
--- a/ggml-cuda/binbcast.cu
+++ b/ggml-cuda/binbcast.cu
@@ -271,7 +271,43 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+static __global__ void scale_f32_l(const float * x, float * dst, const void * data, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // The scale lives in device memory, so it is only dereferenced here, on the device.
+    const float * scale = (const float *)data;
+    dst[i] = scale[0] * x[i];
+}
+
+static void scale_f32_cuda_l(const float * x, float * dst, const void * data, const int k, cudaStream_t stream) {
+    constexpr int CUDA_SCALE_BLOCK_SIZE = 512;
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32_l<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, data, k);
+}
+
+void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    // Pass the device pointer to the scale through to the kernel; the host cannot read it.
+    scale_f32_cuda_l(src0_d, dst_d, dst->src[1]->data, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // A 1-element f32 src1 is a per-tensor scale: use the dedicated kernel instead of bin_bcast.
+    if (ggml_nelements(dst->src[1]) == 1 && dst->src[1]->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32) {
+        ggml_cuda_op_scale_tensor(ctx, dst);
+        return;
+    }
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
--
cgit v1.2.3
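
The hack works because, on the CUDA backend, a tensor's data pointer refers to
device memory: the host cannot simply dereference it, but a kernel can. Below is
a minimal standalone sketch of the same pattern (illustrative only; the names and
the main() harness are assumptions, not part of the patch or of ggml):

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Same idea as scale_f32_l above: the scale arrives as a raw device pointer
// and is only dereferenced on the device.
static __global__ void scale_by_device_value(const float * x, float * dst, const void * data, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    const float * scale = (const float *)data;
    dst[i] = scale[0] * x[i];
}

int main() {
    const int k = 1024;
    std::vector<float> x_h(k, 2.0f), dst_h(k);
    const float scale_h = 0.25f;

    float *x_d, *dst_d, *scale_d;
    cudaMalloc(&x_d,     k*sizeof(float));
    cudaMalloc(&dst_d,   k*sizeof(float));
    cudaMalloc(&scale_d,   sizeof(float));
    cudaMemcpy(x_d,     x_h.data(), k*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(scale_d, &scale_h,     sizeof(float), cudaMemcpyHostToDevice);

    // Dereferencing scale_d here on the host would be invalid; the pointer is
    // instead passed through and read inside the kernel.
    const int block_size = 512;
    const int num_blocks = (k + block_size - 1) / block_size;
    scale_by_device_value<<<num_blocks, block_size>>>(x_d, dst_d, scale_d, k);

    cudaMemcpy(dst_h.data(), dst_d, k*sizeof(float), cudaMemcpyDeviceToHost);
    printf("dst[0] = %g (expected 0.5)\n", dst_h[0]);

    cudaFree(x_d); cudaFree(dst_d); cudaFree(scale_d);
    return 0;
}

The alternative, copying the one float back to the host before every mul, would
require an extra device-to-host transfer per call; passing the raw pointer avoids
that, at the cost of one extra global-memory read per thread in the kernel.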