diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 257 |
1 files changed, 55 insertions, 202 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b35fcb7f..5fd8a87e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5956,148 +5956,30 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } -template <bool vals_smem, int ncols_template, int block_size_template, bool need_check> -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX - const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; - const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; +template <bool vals_smem, int ncols_template, int block_size_template> +static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { + const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; const int rowx = blockIdx.x; - const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension + const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - extern __shared__ half data_soft_max_f16[]; - half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication - // (shared memory) buffer to cache values between iterations: - half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data); - // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead - // in that case col_smem == col_data must be enforced to avoid race conditions - - half2 max_val = make_half2(-INFINITY, -INFINITY); - -#pragma unroll - for (int col0 = 0; col0 < ncols_smem; col0 += block_size) { - const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id; - const int col_smem = vals_smem ? col0 + tid : col_data; - - const int ix = rowx*ncols_data + col_data; - const int iy = rowy*ncols_data + col_data; - - half2 val; - if (need_check && col_data + 0 >= ncols_data) { - val.x = -INFINITY; - } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); - } - if (need_check && col_data + WARP_SIZE >= ncols_data) { - val.y = -INFINITY; - } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); - } - if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { - vals[col_smem] = val; - } - max_val = __hmax2(max_val, val); - } - - // find the max value in the block - max_val = warp_reduce_max(max_val); - if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = -INFINITY; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = __hmax(max_val.x, max_val.y); - } - __syncthreads(); - - max_val = __half2half2(buf_iw[lane_id]); - max_val = warp_reduce_max(max_val); - } else { - max_val = __half2half2(__hmax(max_val.x, max_val.y)); - } - - half2 tmp = make_half2(0.0f, 0.0f); // partial sums + float slope = 0.0f; -#pragma unroll - for (int col0 = 0; col0 < ncols_smem; col0 += block_size) { - const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id; - - if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) { - break; - } - - const half2 val = h2exp(vals[col_smem] - max_val); - - tmp += val; - vals[col_smem] = val; - } - - // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = 0.0f; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = tmp.x + tmp.y; - } - __syncthreads(); - - tmp = __half2half2(buf_iw[lane_id]); - tmp = warp_reduce_sum(tmp); - } else { - tmp = __half2half2(tmp.x + tmp.y); - } - - const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp; - -#pragma unroll - for (int col0 = 0; col0 < ncols_smem; col0 += block_size) { - const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id; - const int col_smem = vals_smem ? col0 + tid : col_data; - - const int idst = rowx*ncols_data + col_data; - const half2 result = vals[col_smem] * inv_sum; - - if (need_check && col_data + 0 >= ncols_data) { - return; - } - dst[idst] = result.x; + // ALiBi + if (max_bias > 0.0f) { + const int h = rowx/nrows_y; // head index - if (need_check && col_data + WARP_SIZE >= ncols_data) { - return; - } + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - dst[idst + WARP_SIZE] = result.y; + slope = powf(base, exp); } -#else - (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale; - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -} - -template <bool vals_smem, int ncols_template, int block_size_template> -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { - const int ncols = ncols_template == 0 ? ncols_par : ncols_template; - - const int tid = threadIdx.x; - const int rowx = blockIdx.x; - const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension - - const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; - - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication @@ -6117,7 +5999,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col]; + vals[col] = val; max_val = max(max_val, val); } @@ -7589,89 +7472,53 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { - int nth = WARP_SIZE; - while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; - const dim3 block_dims(nth, 1, 1); - const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half); - static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - if (shmem <= g_device_caps[g_main_device].smpb) { - switch (ncols_x) { - case 32: - soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 64: - soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 128: - soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 256: - soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 512: - soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 1024: - soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 2048: - soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - case 4096: - soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - default: - soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - break; - } - } else { - const size_t shmem_low = WARP_SIZE*sizeof(half); - soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale); - } -} - -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + if (shmem < g_device_caps[g_main_device].smpb) { switch (ncols_x) { case 32: - soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } @@ -9090,30 +8937,36 @@ static void ggml_cuda_op_soft_max( GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src0->ne[1]; - float scale = 1.0f; - memcpy(&scale, dst->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX -#ifdef GGML_CUDA_F16 - const bool use_f16_soft_max = true; -#else - const bool use_f16_soft_max = false; -#endif // GGML_CUDA_F16 -#else - const bool use_f16_soft_max = false; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); - } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + // positions tensor + float * src2_dd = dst_dd; // default to avoid null checks in the kernel + cuda_pool_alloc<float> src2_f; + + ggml_tensor * src2 = dst->src[2]; + const bool use_src2 = src2 != nullptr; + + if (use_src2) { + const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU; + ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr; + + if (src2_on_device) { + src2_dd = (float *) src2_extra->data_device[g_main_device]; + } else { + src2_dd = src2_f.alloc(ggml_nelements(src2)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream)); + } } - (void) dst; + soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream); } static void ggml_cuda_op_scale( |