summaryrefslogtreecommitdiff
path: root/ggml-cuda/convert.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda/convert.cu')
-rw-r--r--ggml-cuda/convert.cu74
1 files changed, 37 insertions, 37 deletions
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index 18a31edc..ed4fa274 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -4,14 +4,14 @@
#define CUDA_Q8_0_NE_ALIGN 2048
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
- const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= k) {
return;
}
- const int ib = i/qk; // block index
+ const int64_t ib = i/qk; // block index
const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
template <bool need_check>
-static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
- const int ib = 8*i + ir;
+ const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
- const int ib = 8*i + ir;
+ const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@@ -313,14 +313,14 @@ template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;
- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
- const int tid = threadIdx.x;
- const int ip = tid/32; // ip is 0 or 1
- const int il = tid - 32*ip; // 0...32
- const int is = 8*ip + il/16;
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/32; // ip is 0 or 1
+ const int64_t il = tid - 32*ip; // 0...32
+ const int64_t is = 8*ip + il/16;
dst_t * y = yy + i*QK_K + 128*ip + il;
@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
#else
// assume 32 threads
- const int tid = threadIdx.x;
- const int ip = tid/16; // 0 or 1
- const int il = tid - 16*ip; // 0...15
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/16; // 0 or 1
+ const int64_t il = tid - 16*ip; // 0...15
dst_t * y = yy + i*QK_K + 16*ip + il;
@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
#endif
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
-static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
}
template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
}
template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
}
template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}