diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 53 |
1 files changed, 42 insertions, 11 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c34e96ab..be75cb79 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -223,6 +223,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co dst[i] = x[i] + y[i]; } +static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = __hadd(x[i], __float2half(y[i])); +} + static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -1459,6 +1468,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k); } +static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k); +} + static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky); @@ -1941,7 +1955,7 @@ inline void ggml_cuda_op_add( float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t & cudaStream_main){ - GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr); GGML_ASSERT(dst_ddf_i != nullptr); @@ -1949,7 +1963,13 @@ inline void ggml_cuda_op_add( const int64_t i01_diff = i01_high - i01_low; // compute - add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main); + } else { + GGML_ASSERT(false); + } CUDA_CHECK(cudaGetLastError()); (void) src1; @@ -2547,8 +2567,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true); + // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op. + // Due to flatten_rows == true this does in practice not make a difference however. + // Better solution would be nice but right now that would require disproportionate changes. + GGML_ASSERT( + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + src1->type == GGML_TYPE_F32 && + (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16)); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true); } void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2801,7 +2827,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { delete extra; } -void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { +void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) { if (scratch && g_scratch_size == 0) { return; } @@ -2810,11 +2836,11 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) { const ggml_op src0_op = tensor->src0->op; if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) { - ggml_cuda_assign_buffers_impl(tensor->src0, scratch); + ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace); } } if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) { - ggml_cuda_assign_buffers_impl(tensor->src1, scratch); + ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace); } tensor->backend = GGML_BACKEND_GPU; @@ -2822,11 +2848,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { memset(extra, 0, sizeof(*extra)); const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) || - tensor->op == GGML_OP_VIEW; + tensor->op == GGML_OP_VIEW || + force_inplace; const size_t size = ggml_nbytes(tensor); CUDA_CHECK(cudaSetDevice(g_main_device)); - if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) { + if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t offset = 0; @@ -2865,11 +2892,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { } void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, true); + ggml_cuda_assign_buffers_impl(tensor, true, false); } void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, false); + ggml_cuda_assign_buffers_impl(tensor, false, false); +} + +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, true); } void ggml_cuda_set_main_device(int main_device) { |