author    Kawrakow <iwankawrakow@gmail.com>    2025-03-18 07:44:43 +0100
committer GitHub <noreply@github.com>          2025-03-18 07:44:43 +0100
commit    f4ebf13b6a63ac1367bc392e24566d71c0b4c5b9 (patch)
tree      9bf4e7211c0b7bb119c84d9156425e6e756230f8
parent    bdcae905c4cb0de1025a45a2bd6c2e646cc22be7 (diff)
Fix #261 (#262)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-cuda.cu | 67
1 file changed, 34 insertions(+), 33 deletions(-)
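In ggml_cuda_op_mul_mat_cublas, the GGML_CUDA_IQK_FORCE_BF16 branch used to assert that a BF16 converter exists for src0 (GGML_ASSERT(to_bf16_cuda != nullptr)), aborting the process for any quantized type without one. The patch turns that assert into an if (to_bf16_cuda) guard around the whole BF16 GEMM body, so such types now fall through to the regular path instead of crashing; a minimal sketch of the pattern follows the diff.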
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index a8f7383a..01f98594 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -1267,40 +1267,41 @@ static void ggml_cuda_op_mul_mat_cublas(
#ifdef GGML_CUDA_IQK_FORCE_BF16
if (ggml_is_quantized(src0->type) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type);
- GGML_ASSERT(to_bf16_cuda != nullptr);
- size_t ne = row_diff*ne00;
- ggml_cuda_pool_alloc<nv_bfloat16> src0_as_bf16(ctx.pool(id), ne);
- to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream);
-
- ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
- if (src1->type != GGML_TYPE_BF16) {
- size_t ne = src1_ncols*ne10;
- src1_as_bf16.alloc(ne);
- to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
- GGML_ASSERT(to_bf16_cuda != nullptr);
- to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream);
+ if (to_bf16_cuda) {
+ size_t ne = row_diff*ne00;
+ ggml_cuda_pool_alloc<nv_bfloat16> src0_as_bf16(ctx.pool(id), ne);
+ to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream);
+
+ ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+ if (src1->type != GGML_TYPE_BF16) {
+ size_t ne = src1_ncols*ne10;
+ src1_as_bf16.alloc(ne);
+ to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+ GGML_ASSERT(to_bf16_cuda != nullptr);
+ to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream);
+ }
+ const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
+ const nv_bfloat16 * src0_ptr = src0_as_bf16.get();
+
+ ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+ CUBLAS_CHECK(
+ cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
+ src1_ptr, CUDA_R_16BF, ne10,
+ &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
+ CUBLAS_COMPUTE_32F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+ to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream);
+ return;
}
- const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
- const nv_bfloat16 * src0_ptr = src0_as_bf16.get();
-
- ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
-
- const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
-
- CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
- CUBLAS_CHECK(
- cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
- row_diff, src1_ncols, ne10,
- &alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
- src1_ptr, CUDA_R_16BF, ne10,
- &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
- CUBLAS_COMPUTE_32F,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
- to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream);
- return;
}
#endif
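As a hedged illustration of the guard-and-fallback pattern the patch adopts, here is a self-contained C++ sketch: look up a converter that may be missing, take the BF16 path only when it exists, and otherwise fall through instead of asserting. All names here (mul_mat, get_to_bf16, fp32_to_bf16, the printf stand-ins) are hypothetical placeholders for the ggml-cuda internals, not the real API.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Converter signature, standing in for to_bf16_cuda_t (assumption).
using to_bf16_fn = void (*)(const float * src, uint16_t * dst, size_t n);

// Truncating fp32 -> bf16 conversion: keep the top 16 bits of the float.
static void fp32_to_bf16(const float * src, uint16_t * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        uint32_t bits;
        std::memcpy(&bits, &src[i], sizeof bits);
        dst[i] = (uint16_t)(bits >> 16);
    }
}

// Converter lookup that can fail, mirroring ggml_get_to_bf16_cuda():
// only "type 0" has a converter registered in this toy table.
static to_bf16_fn get_to_bf16(int type) {
    return type == 0 ? fp32_to_bf16 : nullptr;
}

static void mul_mat(int src0_type) {
    if (to_bf16_fn conv = get_to_bf16(src0_type)) {
        // BF16 GEMM path: convert the inputs, run the bf16 GEMM, return.
        float    x[2] = {1.0f, -2.5f};
        uint16_t y[2];
        conv(x, y, 2);
        std::printf("type %d: BF16 path (bf16 bits: 0x%04x 0x%04x)\n",
                    src0_type, (unsigned) y[0], (unsigned) y[1]);
        return;
    }
    // Before the fix this point was unreachable: GGML_ASSERT(conv != nullptr)
    // aborted first. Now types without a converter reach the generic path.
    std::printf("type %d: generic cuBLAS path\n", src0_type);
}

int main() {
    mul_mat(0); // converter found -> BF16 path
    mul_mat(7); // no converter   -> fallback instead of an assertion failure
}

The design point is the same as in the diff: a failed converter lookup is an expected condition to route around, not an invariant violation, so the hard assert on src0's converter becomes a runtime check while the assert on src1's converter (inside the guarded block, where a BF16 path was already committed to) is kept.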