summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-09 21:17:17 +0300
committerGitHub <noreply@github.com>2024-09-09 21:17:17 +0300
commit918ada20faf7747bbda6b78503b5d72a90157844 (patch)
tree8ea002fde74f65d0ee4f2e1857aa67e8613aee2f /ggml/src/ggml-cuda
parent8c86231f9306c81dc291c4c4a16f88bbc7c97793 (diff)
Add CUDA support for IQ1_TN (#45)
* iq1_tn: adding CUDA dequantize * iq1_tn: adding CUDA dot product * Delete commented out stuff * Delete forgotten TODO --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/common.cuh7
-rw-r--r--ggml/src/ggml-cuda/convert.cu133
-rw-r--r--ggml/src/ggml-cuda/convert.cuh2
-rw-r--r--ggml/src/ggml-cuda/fattn-common.cuh4
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu87
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cuh4
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu3
7 files changed, 200 insertions, 40 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9aff6c13..d75b219b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -467,6 +467,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_TN> {
+ static constexpr int qk = QK_IQ1BN;
+ static constexpr int qr = QR1_BN;
+ static constexpr int qi = QI1_BN;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
static constexpr int qk = QK_IQ1BN;
static constexpr int qr = QR1_BN;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 70305404..03de64ef 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -447,6 +447,46 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
}
template<typename dst_t>
+static __global__ void dequantize_block_iq1_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = *(const half *)cx;
+ const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(half));
+
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
+
+ const int tid = threadIdx.x;
+ const int il = tid/4; // 0...7
+ const int ib = tid%4; // 0...3
+ dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
+ const int i16 = il/2;
+ int64_t i = QK_K/QK_IQ1BN * (ii - (row*n_per_row)/QK_K) + ib;
+ uint8_t q = x[i].ql[3*i16+2*(il%2)];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[2*(il%2)+j] = scale*(vs - 1);
+ }
+ q = x[i].ql[3*i16+1];
+ for (int j = 0; j < 2; ++j) {
+ uint8_t v = k_mult[3*(il%2)+j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[5*(1-(il%2))+j] = scale*(vs-1);
+ }
+ uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[7] = scale*(vs - 1);
+
+#undef COMPUTE_VS
+}
+
+template<typename dst_t>
static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
const int64_t ii = blockIdx.x;
@@ -675,12 +715,14 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
-static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
@@ -692,149 +734,181 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
}
template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb64 = k / QK_IQ1BN;
const int nb = (k + 255) / 256;
dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
}
template<typename dst_t>
-static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq1_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
+ const int nb = (k + 255) / 256;
+ dequantize_block_iq1_tn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb64 = k / QK_IQ1BN;
const int nb = (k + 255) / 256;
dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
}
template<typename dst_t>
-static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq6_k<<<nb, 32, 0, stream>>>(vx, y);
}
@@ -853,7 +927,8 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
}
template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
@@ -899,6 +974,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
return dequantize_row_iq1_bn_cuda;
+ case GGML_TYPE_IQ1_TN:
+ return dequantize_row_iq1_tn_cuda;
case GGML_TYPE_IQ2_BN:
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
@@ -962,6 +1039,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
return dequantize_row_iq1_bn_cuda;
+ case GGML_TYPE_IQ1_TN:
+ return dequantize_row_iq1_tn_cuda;
case GGML_TYPE_IQ2_BN:
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index 5394be9f..1fb53900 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -3,7 +3,7 @@
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
template<typename T>
-using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t nrows, int64_t n_per_row, cudaStream_t stream);
typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index e4021764..0bcd1ff7 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -624,7 +624,7 @@ void launch_fattn(
if (need_f16_K && K->type != GGML_TYPE_F16) {
K_f16.alloc(ggml_nelements(K));
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
- to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+ to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
K_data = (char *) K_f16.ptr;
const size_t bs = ggml_blck_size(K->type);
@@ -638,7 +638,7 @@ void launch_fattn(
if (need_f16_V && V->type != GGML_TYPE_F16) {
V_f16.alloc(ggml_nelements(V));
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
- to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+ to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
V_data = (char *) V_f16.ptr;
const size_t bs = ggml_blck_size(V->type);
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index c567ad1a..a890f6b3 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -8,6 +8,11 @@
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+// Reminder:
+// constexpr int qk = ggml_cuda_type_traits<type>::qk;
+// constexpr int qi = ggml_cuda_type_traits<type>::qi;
+// constexpr int vdr = get_vdr_mmvq(type);
+
namespace {
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,7 +21,7 @@ __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__global__ void iqk_mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -50,7 +55,8 @@ __global__ void iqk_mul_mat_vec_q(
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
- tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+ tmp[j][i] += vec_dot_q_cuda((const void *)((const char *)vx + (row0 + i)*row_size),
+ &y[j*blocks_per_col_y + kby], kbx, kqs);
}
}
}
@@ -129,30 +135,32 @@ void iqk_mul_mat_vec_q_cuda(
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const int64_t row_size = ggml_row_size(type, ncols_x);
+
switch (ncols_y) {
case 1:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 2:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 3:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 4:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 5:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 6:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 7:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 8:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
default:
GGML_ASSERT(false);
@@ -540,6 +548,58 @@ static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
}
+static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ float scale = *(const half *)vbq;
+ const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(half)) + kbx;
+
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ // iqs is 0 or 1
+
+ int sumi = 0;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const int * q8 = (const int *)bq8_1[iqs].qs;
+ int val[4];
+ for (int l = 0; l < 2; ++l) {
+ int8_t * a = (int8_t *)val;
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ }
+ }
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
+ }
+#else
+ const int8_t * q8 = bq8_1[iqs].qs;
+ for (int l = 0; l < 2; ++l) {
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[j]*(vs - 1);
+ }
+ q8 += 5;
+ }
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[0]*(vs - 1);
+ q8++;
+ }
+#endif
+ return __low2float(bq8_1[iqs].ds) * scale * sumi;
+}
+
} // namespace
void mul_mat_vec_iq2_k_q8_1_cuda(
@@ -583,3 +643,10 @@ void mul_mat_vec_iq2_tn_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+
+void mul_mat_vec_iq1_tn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_TN, 1, vec_dot_iq1_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 7af8e570..7fb76ff6 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -24,3 +24,7 @@ void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+void mul_mat_vec_iq1_tn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 2586ab7e..5f932fef 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -429,6 +429,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ2_TN:
mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ1_TN:
+ mul_mat_vec_iq1_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;