author     Kawrakow <iwankawrakow@gmail.com>    2025-06-05 19:13:51 +0300
committer  GitHub <noreply@github.com>          2025-06-05 19:13:51 +0300
commit     eded4e20d4decdc6e8c18e645fd1db0833ad251d (patch)
tree       b97f31f30d263b3d2ca610d75bcdd58793f7e04d
parent     8ffad187abbb93b74db8ef813b6fdceec80e02b0 (diff)
IQ1_M_R4 CUDA implementation (#494)
* iq1_m_r4: CUDA dequantize

* iq1_m_r4: CUDA dequantize

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-cuda.cu            |  1
-rw-r--r--  ggml/src/ggml-cuda/convert.cu    | 52
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu   | 50
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cuh  |  5
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu       |  4
5 files changed, 109 insertions, 3 deletions
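
For readers skimming the diff: the heart of the change is the per-group dequantization of the interleaved IQ1_M_R4 format (4 rows packed together, 32 values per block, processed in 8-value groups). Below is a minimal scalar sketch of that arithmetic, assuming the block_iq1_m_r4 layout, the iq1s_grid_gpu lookup table and the IQ1M_DELTA constant defined elsewhere in the ggml CUDA sources; the helper name is hypothetical and the snippet is illustrative only, not part of the commit.

    // Scalar sketch of one 8-value group of IQ1_M_R4 dequantization (mirrors convert.cu below).
    // blk : one block_iq1_m_r4 (field layout as defined in ggml's quantization headers)
    // d   : per-row super-block scale, stored as half in front of each group of 4 rows
    // ir  : row index within the interleaved group of 4 rows (0..3)
    // il  : 8-value group index within the 32-value block (0..3)
    static void dequant_iq1_m_r4_group(float d, const block_iq1_m_r4 & blk, int ir, int il, float * y) {
        const uint8_t qh = blk.qh[4*(il/2) + ir] >> 4*(il%2);           // nibble: 3 grid bits + 1 sign bit
        const float   dl = d * ((blk.scales[ir] >> 4*(il/2)) & 0xf);    // 4-bit per-group scale
        const float delta = dl * (qh & 0x8 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); // sign-dependent shift

        uint32_t grid32[2];
        const int8_t * q = (const int8_t *)grid32;
        grid32[0] = iq1s_grid_gpu[blk.qs[4*il + ir] | ((qh & 7) << 8)]; // 8 low + 3 high bits of grid index
        grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;                      // high nibbles -> q[4..7]
        grid32[0] &= 0x0f0f0f0f;                                        // low nibbles  -> q[0..3]
        for (int j = 0; j < 8; ++j) y[j] = dl * q[j] + delta;           // scale and shift
    }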
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index de9816fe..6bdeb465 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3477,6 +3477,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
+ case GGML_TYPE_IQ1_M_R4:
return true;
default:
return false;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index b2f77e09..01b7250e 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -542,7 +542,7 @@ static __global__ void dequantize_block_iq1_s_r4(const void * __restrict__ vx, d
const int ib = tid%8; // 0...7
const half * dptr = (const half *)((const char *)vx + 4*row4*row_size);
- const float d = (float)dptr[ir];
+ const float d = __half2float(dptr[ir]);
const block_iq1_s_r4 * x = (const block_iq1_s_r4 *)(dptr + 4) + ibl;
dst_t * y = yy + 256*ii + 32*ib + 8*il;
@@ -562,6 +562,42 @@ static __global__ void dequantize_block_iq1_s_r4(const void * __restrict__ vx, d
}
template<typename dst_t>
+static __global__ void dequantize_block_iq1_m_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+
+ int64_t nblock = n_per_row/32;
+ int64_t row = (8*ii)/nblock;
+ int64_t row4 = row/4;
+ int64_t ir = row%4;
+ int64_t ibl = (8*ii)%nblock;
+
+ const int tid = threadIdx.x;
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+
+ const half * dptr = (const half *)((const char *)vx + 4*row4*row_size);
+ const float d = __half2float(dptr[ir]);
+ const block_iq1_m_r4 * x = (const block_iq1_m_r4 *)(dptr + 4) + ibl;
+ dst_t * y = yy + 256*ii + 32*ib + 8*il;
+
+ uint8_t qh = x[ib].qh[4*(il/2)+ir] >> 4*(il%2);
+ float dl = d*((x[ib].scales[ir] >> 4*(il/2)) & 0xf);
+ float delta = dl * (qh & 0x8 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA);
+
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[ib].qs[4*il+ir] | ((qh & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+
+ if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+ for (int j = 0; j < 8; ++j) y[j] = __float2bfloat16(dl*q[j] + delta);
+ } else {
+ for (int j = 0; j < 8; ++j) y[j] = dl*q[j] + delta;
+ }
+}
+
+template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -1442,6 +1478,14 @@ static void dequantize_row_iq1_s_r4_cuda(const void * vx, dst_t * y, const int64
}
template<typename dst_t>
+static void dequantize_row_iq1_m_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_M_R4, n_per_row);
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq1_m_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+}
+
+template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
@@ -1696,6 +1740,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
return dequantize_row_iq5_ks_r4_cuda<nv_bfloat16>;
case GGML_TYPE_IQ1_S_R4:
return dequantize_row_iq1_s_r4_cuda<nv_bfloat16>;
+ case GGML_TYPE_IQ1_M_R4:
+ return dequantize_row_iq1_m_r4_cuda<nv_bfloat16>;
default:
return nullptr;
}
@@ -1746,6 +1792,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_S_R4:
return dequantize_row_iq1_s_r4_cuda;
+ case GGML_TYPE_IQ1_M_R4:
+ return dequantize_row_iq1_m_r4_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
@@ -1839,6 +1887,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_S_R4:
return dequantize_row_iq1_s_r4_cuda;
+ case GGML_TYPE_IQ1_M_R4:
+ return dequantize_row_iq1_m_r4_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 747af5a7..e5a224b4 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -36,6 +36,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K_R4> {
static constexpr int qi = QI5_XS;
};
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M_R4> {
+ static constexpr int qk = 32;
+ static constexpr int qr = 2;
+ static constexpr int qi = 4;
+};
+
// Reminder:
// constexpr int qk = ggml_cuda_type_traits<type>::qk;
// constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -338,7 +345,6 @@ __device__ __forceinline__ void vec_dot_iq4_ks_r4_q8_1(
}
}
-// TODO
__device__ __forceinline__ void vec_dot_iq1_s_r4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
@@ -356,7 +362,7 @@ __device__ __forceinline__ void vec_dot_iq1_s_r4_q8_1(
for (int k = 0; k < 4; ++k) minus = ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+k], minus);
for (int i = 0; i < 4; ++i) {
- float dl = (float)dptr[i]*(2*((bq1->qh[i] >> 12) & 7) + 1) * d8;
+ float dl = __half2float(dptr[i])*(2*((bq1->qh[i] >> 12) & 7) + 1) * d8;
float ml = dl * (bq1->qh[i] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA);
grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i] | (((bq1->qh[i] >> 3*iqs) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
@@ -370,6 +376,38 @@ __device__ __forceinline__ void vec_dot_iq1_s_r4_q8_1(
}
}
+__device__ __forceinline__ void vec_dot_iq1_m_r4_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ const half * dptr = (const half *)vbq;
+ const block_iq1_m_r4 * bq1 = (const block_iq1_m_r4 *)(dptr + 4) + kbx;
+
+ // iqs is 0 or 2
+ const float d8 = __low2float(bq8_1->ds);
+ const int32_t * q8 = (const int *)bq8_1->qs;
+
+ int32_t grid32[2];
+ const int * igrid = (const int *)grid32;
+
+ int minus1 = ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+0], ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+1], 0));
+ int minus2 = ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+2], ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+3], 0));
+
+ for (int i = 0; i < 4; ++i) {
+ float dl = __half2float(dptr[i])*((bq1->scales[i] >> 4*(iqs/2)) & 0xf) * d8;
+ float ml1 = dl * (bq1->qh[4*(iqs/2)+i] & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA);
+ float ml2 = dl * (bq1->qh[4*(iqs/2)+i] & 0x80 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA);
+ grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i] | ((bq1->qh[4*(iqs/2)+i] & 0x07) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ int sumi = ggml_cuda_dp4a(igrid[0], q8[4*(iqs/2)+0], ggml_cuda_dp4a(igrid[1], q8[4*(iqs/2)+1], 0));
+ grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i+4] | ((bq1->qh[4*(iqs/2)+i] & 0x70) << 4)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ sumi = ggml_cuda_dp4a(igrid[0], q8[4*(iqs/2)+2], ggml_cuda_dp4a(igrid[1], q8[4*(iqs/2)+3], sumi));
+ result[i] += dl * sumi + ml1 * minus1 + ml2*minus2;
+ }
+}
+
#define VDR_IQ4_KS_Q8_1_MMVQ 4
#define VDR_IQ4_KS_Q8_1_MMQ 4
@@ -1131,6 +1169,14 @@ void mul_mat_vec_iq1_s_r4_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S_R4, 2, vec_dot_iq1_s_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
+void mul_mat_vec_iq1_m_r4_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M_R4, 2, vec_dot_iq1_m_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
void mul_mat_vec_iq5_k_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 1e4257e8..17bf5ad2 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -95,3 +95,8 @@ void mul_mat_vec_iq1_s_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq1_m_r4_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index cc00d278..73caabab 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -563,6 +563,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ1_S_R4:
mul_mat_vec_iq1_s_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
+ case GGML_TYPE_IQ1_M_R4:
+ mul_mat_vec_iq1_m_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
default:
GGML_ABORT("fatal error");
break;
@@ -683,6 +686,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
+ case GGML_TYPE_IQ1_M_R4:
return true;
default:
return false;
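
A note on vec_dot_iq1_m_r4_q8_1 above (a reading of the code, not part of the commit message): every dequantized value in an 8-value group has the form dl*q_j + delta, with delta constant across the group, so the dot product against the q8_1 activations factorizes as

    sum_j (dl*q_j + delta) * q8_j = dl * sum_j q_j*q8_j + delta * sum_j q8_j

The kernel therefore computes sum_j q8_j once per 16-byte half of the block (the minus1/minus2 accumulators, via dp4a against 0x01010101) and folds the delta terms in at the end as ml1*minus1 + ml2*minus2.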