summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
authorEbey Abraham <ebey97@gmail.com>2023-12-18 17:27:47 +0000
committerGitHub <noreply@github.com>2023-12-18 19:27:47 +0200
commitb9e74f9bca5fdf7d0a22ed25e7a9626335fdfa48 (patch)
treeb150a0d4490627bfc9cdd758d08d026fc70b0882 /ggml-cuda.cu
parent3c04bf6da89eaf4c7d317e0518f0687dfcbf2de7 (diff)
llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (#4490)
* phi2 implementation * fix breaking change * phi-2 : various fixes * phi-2 : use layer norm eps * py : whitespaces * llama : fix meta KV override bug * convert : phi don't add BOS token * convert : revert "added_tokens_decoder" change * phi-2 : scale Q instead of KQ for better precision * ggml : fix NeoX rope to rotate just first n_dims * cuda : less diff in the rope_neox kernel * ggml : add ggml_mul_mat_set_prec ggml-ci * Update ggml-cuda.cu Co-authored-by: slaren <slarengh@gmail.com> * Update ggml-cuda.cu Co-authored-by: slaren <slarengh@gmail.com> * cuda : ggml_cuda_op_mul_mat_cublas support F32 precision * cuda : remove oboslete comment --------- Co-authored-by: Ebey Abraham <ebeyabraham@microsoft.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu117
1 files changed, 81 insertions, 36 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0a63c1ec..d0f3d803 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;
- const int i = row*ncols + ib*n_dims + ic/2;
+ if (ib > 0) {
+ const int i = row*ncols + ib*n_dims + ic;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
@@ -7057,6 +7066,7 @@ inline void ggml_cuda_op_upscale(
(void) src1;
(void) dst;
+ (void) src1_dd;
}
inline void ggml_cuda_op_pad(
@@ -7073,6 +7083,7 @@ inline void ggml_cuda_op_pad(
(void) src1;
(void) dst;
+ (void) src1_dd;
}
inline void ggml_cuda_op_rms_norm(
@@ -7376,7 +7387,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
const int compute_capability = g_compute_capabilities[id];
- if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+ if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
@@ -8300,27 +8311,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
}
static __global__ void k_compute_batched_ptrs(
- const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+ const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
- int ne12, int ne13,
- int ne23,
- int nb02, int nb03,
- int nb12, int nb13,
- int nb2, int nb3,
- int r2, int r3) {
- int i13 = blockIdx.x * blockDim.x + threadIdx.x;
- int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+ int64_t ne12, int64_t ne13,
+ int64_t ne23,
+ size_t nb02, size_t nb03,
+ size_t nb12, size_t nb13,
+ size_t nbd2, size_t nbd3,
+ int64_t r2, int64_t r3) {
+ int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+ int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
if (i13 >= ne13 || i12 >= ne12) {
return;
}
- int i03 = i13 / r3;
- int i02 = i12 / r2;
+ int64_t i03 = i13 / r3;
+ int64_t i02 = i12 / r2;
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8376,7 +8387,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
size_t dst_as = 0;
- half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+ half * dst_f16 = nullptr;
+ char * dst_t = nullptr;
+
+ cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+ cudaDataType_t cu_data_type = CUDA_R_16F;
+
+ // dst strides
+ size_t nbd2 = dst->nb[2];
+ size_t nbd3 = dst->nb[3];
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ const void * alpha = &alpha_f16;
+ const void * beta = &beta_f16;
+
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+ dst_t = (char *) dst_f16;
+
+ nbd2 /= sizeof(float) / sizeof(half);
+ nbd3 /= sizeof(float) / sizeof(half);
+ } else {
+ dst_t = (char *) dst_ddf;
+
+ cu_compute_type = CUBLAS_COMPUTE_32F;
+ cu_data_type = CUDA_R_32F;
+
+ alpha = &alpha_f32;
+ beta = &beta_f32;
+ }
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -8385,9 +8430,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
#if 0
// use cublasGemmEx
{
@@ -8397,12 +8439,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
int i02 = i12 / r2;
CUBLAS_CHECK(
- cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
- (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
- &beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
- CUBLAS_COMPUTE_16F,
+ alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
+ (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+ beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
+ cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
@@ -8414,11 +8456,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
- (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
- &beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
+ alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
+ (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+ beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
- CUBLAS_COMPUTE_16F,
+ cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
@@ -8435,24 +8477,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
- src0_as_f16, src1_as_f16, dst_f16,
+ src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
nb02, nb03,
nb12, nb13,
- dst->nb[2], dst->nb[3],
+ nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
- (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
- &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+ alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+ (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+ beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
ne23,
- CUBLAS_COMPUTE_16F,
+ cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
if (ptrs_src_s != 0) {
@@ -8464,11 +8506,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
#endif
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+ ggml_cuda_pool_free(dst_f16, dst_as);
+ }
ggml_cuda_pool_free(src1_as_f16, src1_as);
- ggml_cuda_pool_free(dst_f16, dst_as);
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {