author    Kawrakow <48489457+ikawrakow@users.noreply.github.com>    2024-08-07 07:56:09 +0200
committer GitHub <noreply@github.com>                               2024-08-07 07:56:09 +0200
commit    a9f302ebe2373321c12b01d8760904901aa064a4 (patch)
tree      7953bbff2ebd6bf9130cea52d17995aea3cd65d5 /ggml/src/ggml-cuda
parent    b409c153636d27473970abd3a9c9400b6287d400 (diff)
Adding IQ2_TN for use with ternary models (#13)
* iq2_tn: TriLM-specific 2.0625 bpw quantization. Quantize/dequantize/scale dot product. I get 46 t/s for TriLM-3.9B without any SIMD! Finally a compiler doing a decent job auto-vectorizing the scalar implementation.
* iq2_tn: AVX512. Just reusing the k-quants template gets us to PP-512 = 376 t/s, TG-128 = 47.6 t/s for TriLM-3.9B.
* iq2_tn: AVX512. With this tweak we get to PP-512 = 431 t/s.
* iq2_tn: AVX512. With this tweak we get TG-128 = 19.58 / 35.18 t/s for 1 / 2 threads. At 4 threads we saturate at 48.41 t/s, and then performance slowly degrades with an increasing number of threads.
* iq2_tn: AVX2. PP-512 = 440 t/s on the Ryzen-5975WX. We should be able to do better.
* iq2_tn: initial NEON version.
* iq2_tn: NEON. For TriLM-3.9B running on the M2-Max we get PP-512 = 193.5 t/s, TG-128 = 75.5 t/s. This is in line with what we have for iq2_bn and the 3.3B Bitnet.
* iq2_tn: Metal. For TriLM-3.9B on a 30-core M2-Max we get PP-512 = 890 t/s, TG-128 = 98.5 t/s.
* iq2_tn: CUDA. For TriLM-3.9B running on an RTX-4080 we get PP-512 = 9936 t/s, TG-128 = 299.2 t/s.
* iq2_tn: AVX2 PP improvement. We now get PP-512 = 490.73 t/s for TriLM-3.9B on the Ryzen-5975WX. We have PP-512 = 636.61 t/s for Bitnet-3B quantized with iq2_bn. Bitnet-3B is actually 3.43B and TriLM-3.9B is 3.99B, so we would expect 3.43/3.99 * 636 = 546 t/s; it seems we still have something that is not quite optimal in iq2_tn.
* iq2_tn: small NEON improvement. For TriLM-3.9B we now get PP-512 = 206.6 t/s and TG-128 = 76.4 t/s.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
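For reference, the arithmetic behind the "2.0625 bpw" figure follows from the block layout the new CUDA kernels assume. A minimal sketch (the field order is an assumption; the real definition lives in ggml-common.h, which this diff does not touch):

    // Sketch of an IQ2_TN superblock as implied by the kernels below.
    // Field order is an assumption, not taken from the commit.
    #define QK_K 256
    typedef struct {
        ggml_half d;            // one fp16 scale per superblock
        uint8_t   qs[QK_K/4];   // 256 weights, 2 bits each, 4 per byte
    } block_iq2_tn;
    // 2 + 64 = 66 bytes per 256 weights -> 66*8/256 = 2.0625 bpw.
    // A 2-bit code q in {0,1,2} decodes to d*(q-1), i.e. {-d, 0, +d}.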
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--  ggml/src/ggml-cuda/common.cuh    7
-rw-r--r--  ggml/src/ggml-cuda/convert.cu    31
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu   42
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cuh  4
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu       3
5 files changed, 87 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index fbc52aa9..c18e865a 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -656,6 +656,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_K;
+ static constexpr int qi = QI2_K;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
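Note: the new traits reuse the Q2_K constants unchanged (qk = QK_K, qr = QR2_K, qi = QI2_K), so the templated CUDA kernels tile IQ2_TN rows exactly as they tile Q2_K rows; only the dequantize and dot-product callbacks added below differ.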
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index ed7e4bd0..47ab92f0 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -154,6 +154,27 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
}
template<typename dst_t>
+static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_tn * x = (const block_iq2_tn *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t n = tid/32;
+ const int64_t l = tid - 32*n;
+ const int64_t is = 8*n + l/16;
+
+ const uint8_t q = x[i].qs[32*n + l];
+ dst_t * y = yy + i*QK_K + 128*n;
+
+ float d = __half2float(x[i].d);
+ y[l+ 0] = d * ((q >> 0) & 3) - d;
+ y[l+32] = d * ((q >> 2) & 3) - d;
+ y[l+64] = d * ((q >> 4) & 3) - d;
+ y[l+96] = d * ((q >> 6) & 3) - d;
+}
+
+template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -647,6 +668,12 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k
}
template<typename dst_t>
+static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -812,6 +839,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_IQ2_TN:
+ return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
@@ -871,6 +900,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_IQ2_TN:
+ return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
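The index arithmetic in dequantize_block_iq2_tn above is terse (note that `is` is computed but never used, apparently a leftover from the q2_K kernel it was adapted from). A hand-derived scalar equivalent for one superblock, using the block_iq2_tn sketch from the notes above; illustration only, not code from the commit:

    // Scalar reference: each byte of qs holds four 2-bit codes whose
    // decoded values land 32 elements apart in the output row.
    static void dequantize_iq2_tn_ref(const block_iq2_tn * x, float * y) {
        const float d = GGML_FP16_TO_FP32(x->d);
        for (int n = 0; n < 2; ++n) {            // two 128-element halves
            for (int l = 0; l < 32; ++l) {
                const uint8_t q   = x->qs[32*n + l];
                float *       dst = y + 128*n;
                dst[l +  0] = d * ((q >> 0) & 3) - d;
                dst[l + 32] = d * ((q >> 2) & 3) - d;
                dst[l + 64] = d * ((q >> 4) & 3) - d;
                dst[l + 96] = d * ((q >> 6) & 3) - d;
            }
        }
    }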
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index acb495d1..8def1547 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -469,6 +469,41 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
}
+#define VDR_IQ2_TN_Q8_1_MMVQ 1
+#define VDR_IQ2_TN_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx;
+
+ const int bq8_offset = QR2_K * (iqs / QI8_1);
+
+ const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs;
+ int v = q16[0] | (q16[1] << 16);
+
+ float sumf = 0;
+ for (int i = 0; i < QR2_K; ++ i) {
+ int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
+ float d8 = __low2float(bq8_1[bq8_offset + i].ds);
+ sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
+ v >>= 2;
+ }
+ return __half2float(bq2->d) * sumf;
+
+ //float sumf_d = 0;
+ //float sumf_m = 0;
+ //for (int i = 0; i < QR2_K; ++ i) {
+ // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
+ // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds);
+ // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0);
+ // sumf_m += d8.y;
+ // v >>= 2;
+ //}
+ //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m);
+
+}
+
} // namespace
void mul_mat_vec_iq2_k_q8_1_cuda(
@@ -499,3 +534,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+void mul_mat_vec_iq2_tn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
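The live branch of vec_dot_iq2_tn_q8_1 above handles the implicit -1 offset of the ternary codes with a dp4a identity: dp4a(v & 0x03030303, u, 0) - dp4a(0x01010101, u, 0) equals sum_k (q_k - 1) * u_k, so signed weights never need to be materialized. A hand-derived scalar restatement for one group of four packed values (illustration only, not code from the commit):

    #include <stdint.h>

    // v holds four 2-bit codes q_k in {0,1,2}, one in the low bits of
    // each byte; u holds four int8 activations from the q8_1 block.
    static int ternary_dot4_ref(uint32_t v, const int8_t u[4]) {
        int sum = 0;
        for (int k = 0; k < 4; ++k) {
            const int q = (int)((v >> (8*k)) & 3);
            sum += (q - 1) * u[k];   // ternary weight {-1, 0, +1}
        }
        return sum;
    }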
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 9a33af0d..3dc5f41c 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -16,3 +16,7 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+void mul_mat_vec_iq2_tn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 56bf3ebe..428d822f 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -426,6 +426,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ2_BN:
mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ2_TN:
+ mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;