summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-10-25 13:08:43 +0200
committerGitHub <noreply@github.com>2024-10-25 13:08:43 +0200
commit6b968f38946117552ffed300771c44ba9b39d3e4 (patch)
treedc6b0df69f31ea77d9941d6798a4ef411c688080 /ggml/src/ggml-cuda
parent9114078959b404899fd67e1af45f0dcbee51b47f (diff)
Bitnet changes (#106)
* Adapting iq2_bn to work without separate scale tensors Why? It is becoming burdensome to maintain the special Bitnet conversion in convert_hf_to_gguf.py, so I thnk it is better to make iq1_bn and iq2_bn just work with the mainline conversion script (which does not generate scales). * Adapting iq1_bn to work without separate scale tensors * Adapting iq2_bn: CUDA dequantize * Adapting iq2_bn: CUDA works * Adapting iq1_bn: CUDA works * Adapting iq1_bn, iq2_bn: NEON * Adapting iq1_bn, iq2_bn: Metal Dequantize works, but there is still something wrong with the dot products. * WIP Absoolutely don't see what is wrong with the iq1_bn and iq2_bn vector dot product kernels. * Remove iq1_tn and iq2_tn - Part 1 Now that iq1_bn and iq2_bn have per row scales, there is no reason to also have iq1_tn and iq2_tn. * Remove iq1_tn and iq2_tn - Part 2 * Bitnet: use the standard llm_build_kv to build self attention My main motivation was to enable FA. But FA does not work anyway because head size is 100 for the Botnet ternary models (and I had forgotten this little detail). * Revert "Avoid rebuild of GGML graph for each token (#98)" This reverts commit f2d315b46f7aacc7df4b86bd8acba387b30e11ca. As far as I can tell, the commit breaks Metal TG. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/binbcast.cu2
-rw-r--r--ggml/src/ggml-cuda/common.cuh14
-rw-r--r--ggml/src/ggml-cuda/convert.cu147
-rw-r--r--ggml/src/ggml-cuda/fattn.cu5
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu84
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cuh11
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu22
-rw-r--r--ggml/src/ggml-cuda/vecdotq.cuh89
8 files changed, 96 insertions, 278 deletions
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 62d115f1..5abbd43c 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -288,7 +288,7 @@ static void scale_f32_cuda_l(const float * x, float * dst, const void * data, co
scale_f32_l<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, data, k);
}
-void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index a5658a24..2eba527f 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -474,13 +474,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
};
template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ1_TN> {
- static constexpr int qk = QK_IQ1BN;
- static constexpr int qr = QR1_BN;
- static constexpr int qi = QI1_BN;
-};
-
-template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
static constexpr int qk = QK_IQ1BN;
static constexpr int qr = QR1_BN;
@@ -488,13 +481,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
};
template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> {
- static constexpr int qk = QK_K;
- static constexpr int qr = QR2_K;
- static constexpr int qi = QI2_K;
-};
-
-template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index e9d15b5d..b9baee1b 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -184,30 +184,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
}
template<typename dst_t>
-static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
- int64_t n_per_row, int64_t row_size) {
-
- int64_t ii = blockIdx.x;
- int64_t row = (QK_K * ii) / n_per_row;
- const char * cx = (const char *)vx + row * row_size;
- float d = *(const float *)cx;
- const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float));
- int64_t i = ii - (row*n_per_row)/QK_K;
-
- const int64_t tid = threadIdx.x;
- const int64_t n = tid/32;
- const int64_t l = tid - 32*n;
-
- const uint8_t q = x[i].qs[32*n + l];
- dst_t * y = yy + ii*QK_K + 128*n;
-
- y[l+ 0] = d * ((q >> 0) & 3) - d;
- y[l+32] = d * ((q >> 2) & 3) - d;
- y[l+64] = d * ((q >> 4) & 3) - d;
- y[l+96] = d * ((q >> 6) & 3) - d;
-}
-
-template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -481,101 +457,72 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
}
template<typename dst_t>
-static __global__ void dequantize_block_iq1_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
- int64_t n_per_row, int64_t row_size) {
-
- int64_t ii = blockIdx.x;
- int64_t row = (QK_K * ii) / n_per_row;
- const char * cx = (const char *)vx + row * row_size;
- float scale = *(const half *)cx;
- const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(half));
-
- static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
-
-//#define COMPUTE_VS(v) 3*v >> 8
-#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
+static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ int64_t n_per_row, int64_t row_size, int64_t nrows) {
+ int64_t ii = 256*blockIdx.x;
const int tid = threadIdx.x;
const int il = tid/4; // 0...7
const int ib = tid%4; // 0...3
- dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
- const int i16 = il/2;
- int64_t i = QK_K/QK_IQ1BN * (ii - (row*n_per_row)/QK_K) + ib;
- uint8_t q = x[i].ql[3*i16+2*(il%2)];
- for (int j = 0; j < 5; ++j) {
- uint8_t v = k_mult[j]*q;
- int8_t vs = COMPUTE_VS(v);
- y[2*(il%2)+j] = scale*(vs - 1);
- }
- q = x[i].ql[3*i16+1];
- for (int j = 0; j < 2; ++j) {
- uint8_t v = k_mult[3*(il%2)+j]*q;
- int8_t vs = COMPUTE_VS(v);
- y[5*(1-(il%2))+j] = scale*(vs-1);
- }
- uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
- int8_t vs = COMPUTE_VS(v);
- y[7] = scale*(vs - 1);
+ dst_t * y = yy + ii + 64*ib + 8*il;
-#undef COMPUTE_VS
-}
-
-template<typename dst_t>
-static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
-
- const int64_t ii = blockIdx.x;
- const block_iq1_bn * x = (const block_iq1_bn *) vx;
+ int64_t row = ii / n_per_row;
+ if (row >= nrows) return;
+ const char * cx = (const char *)vx + row * row_size;
+ half d16; memcpy(&d16, cx, sizeof(d16)); // in case not 2-byte aligned
+ float d = d16;
+ const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(d16));
+ ii -= row*n_per_row;
+ int64_t i = ii/QK_IQ1BN + ib;
static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
//#define COMPUTE_VS(v) 3*v >> 8
#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
- const int tid = threadIdx.x;
- const int il = tid/4; // 0...7
- const int ib = tid%4; // 0...3
- dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
- int64_t i = QK_K/QK_IQ1BN * ii + ib;
- if (i >= nb64) return;
const int i16 = il/2;
uint8_t q = x[i].ql[3*i16+2*(il%2)];
for (int j = 0; j < 5; ++j) {
uint8_t v = k_mult[j]*q;
int8_t vs = COMPUTE_VS(v);
- y[2*(il%2)+j] = vs - 1;
+ y[2*(il%2)+j] = d*(vs - 1);
}
q = x[i].ql[3*i16+1];
for (int j = 0; j < 2; ++j) {
uint8_t v = k_mult[3*(il%2)+j]*q;
int8_t vs = COMPUTE_VS(v);
- y[5*(1-(il%2))+j] = vs-1;
+ y[5*(1-(il%2))+j] = d*(vs-1);
}
uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
int8_t vs = COMPUTE_VS(v);
- y[7] = vs - 1;
+ y[7] = d*(vs - 1);
#undef COMPUTE_VS
}
template<typename dst_t>
-static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
-
- const int64_t ii = blockIdx.x;
- const block_iq2_bn * x = (const block_iq2_bn *) vx;
+static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size, int64_t nrows) {
+ int64_t ii = 256*blockIdx.x;
const int64_t tid = threadIdx.x;
int64_t ib64 = tid%4; // 0...3
int64_t il = tid/4; // 0...7
- dst_t * y = yy + 256*ii + 64*ib64 + 2*il;
- int64_t i = 256/QK_IQ1BN * ii + ib64;
- if (i >= nb64) return;
- const float m = -1;
+ dst_t * y = yy + ii + 64*ib64 + 2*il;
+
+ int64_t row = ii / n_per_row;
+ if (row >= nrows) return;
+ const char * cx = (const char *)vx + row * row_size;
+ float d = *(const float *)cx;
+ const block_iq2_bn * x = (const block_iq2_bn *)(cx + sizeof(float));
+ ii -= row*n_per_row;
+ int64_t i = ii/QK_IQ1BN + ib64;
+ const float m = -d;
auto qs = x[i].qs + 2*il;
for (int j = 0; j < 2; ++j) {
- y[j+ 0] = ((qs[j] >> 0) & 3) + m;
- y[j+16] = ((qs[j] >> 2) & 3) + m;
- y[j+32] = ((qs[j] >> 4) & 3) + m;
- y[j+48] = ((qs[j] >> 6) & 3) + m;
+ y[j+ 0] = d * ((qs[j] >> 0) & 3) + m;
+ y[j+16] = d * ((qs[j] >> 2) & 3) + m;
+ y[j+32] = d * ((qs[j] >> 4) & 3) + m;
+ y[j+48] = d * ((qs[j] >> 6) & 3) + m;
}
}
@@ -857,14 +804,6 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n
}
template<typename dst_t>
-static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
- const int64_t k = nrows * n_per_row;
- const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
- const int nb = (k + 255) / 256;
- dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size);
-}
-
-template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
@@ -975,25 +914,17 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t
template<typename dst_t>
static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
- const int nb64 = k / QK_IQ1BN;
- const int nb = (k + 255) / 256;
- dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
-}
-
-template<typename dst_t>
-static void dequantize_row_iq1_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
- const int64_t k = nrows * n_per_row;
- const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row);
const int nb = (k + 255) / 256;
- dequantize_block_iq1_tn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+ dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows);
}
template<typename dst_t>
static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
- const int nb64 = k / QK_IQ1BN;
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
const int nb = (k + 255) / 256;
- dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
+ dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows);
}
template<typename dst_t>
@@ -1157,8 +1088,6 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
- case GGML_TYPE_IQ2_TN:
- return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
@@ -1181,8 +1110,6 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
return dequantize_row_iq1_bn_cuda;
- case GGML_TYPE_IQ1_TN:
- return dequantize_row_iq1_tn_cuda;
case GGML_TYPE_IQ2_BN:
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
@@ -1232,8 +1159,6 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
- case GGML_TYPE_IQ2_TN:
- return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
@@ -1256,8 +1181,6 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
return dequantize_row_iq1_bn_cuda;
- case GGML_TYPE_IQ1_TN:
- return dequantize_row_iq1_tn_cuda;
case GGML_TYPE_IQ2_BN:
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 1dfb24f9..c15d6c81 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -38,6 +38,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
default:
+ fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
GGML_ABORT("fatal error");
break;
}
@@ -63,6 +64,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
// break;
default:
+ fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
GGML_ABORT("fatal error");
break;
}
@@ -86,6 +88,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
+ fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
GGML_ABORT("fatal error");
break;
}
@@ -114,6 +117,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
+ fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
GGML_ABORT("fatal error");
break;
}
@@ -141,6 +145,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
+ fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
GGML_ABORT("fatal error");
break;
}
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index dec54b5e..795243e7 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -626,35 +626,12 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
}
-#define VDR_IQ2_TN_Q8_1_MMVQ 1
-#define VDR_IQ2_TN_Q8_1_MMQ 4
-
-static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
+__device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
- float scale = *(const float *)vbq;
- const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx;
-
- const int bq8_offset = QR2_K * (iqs / QI8_1);
-
- const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs;
- int v = q16[0] | (q16[1] << 16);
-
- float sumf = 0;
- for (int i = 0; i < QR2_K; ++ i) {
- int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
- float d8 = __low2float(bq8_1[bq8_offset + i].ds);
- sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
- v >>= 2;
- }
- return scale * sumf;
-}
-
-static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
-
- float scale = *(const half *)vbq;
- const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(half)) + kbx;
+ half d16; memcpy(&d16, vbq, sizeof(d16));
+ float scale = d16;
+ const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(d16)) + kbx;
static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
@@ -699,7 +676,48 @@ static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
q8++;
}
#endif
- return __low2float(bq8_1[iqs].ds) * scale * sumi;
+ return scale * __low2float(bq8_1[iqs].ds) * sumi;
+}
+
+__device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ float scale = *(const float *)vbq;
+ const block_iq2_bn * bq2 = (const block_iq2_bn *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0 or 1
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ auto qs = (const uint16_t *)bq2->qs + 4*iqs;
+ auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
+ auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+ for (int j = 0; j < 2; ++j) {
+ int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
+ int vh = vl >> 4;
+ sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
+ sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
+ sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
+ sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
+ }
+ auto d8l = __half22float2(bq8_1[0].ds);
+ auto d8h = __half22float2(bq8_1[1].ds);
+#else
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+ auto q8l = bq8_1[0].qs + 8*iqs;
+ auto q8h = bq8_1[1].qs + 8*iqs;
+ auto qs = bq2->qs + 8*iqs;
+ for (int j = 0; j < 8; ++j) {
+ sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
+ sumi2 += q8l[j+16] * (qs[j] & 0x0c);
+ sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
+ sumi4 += q8h[j+16] * (qs[j] & 0xc0);
+ }
+ auto d8l = __half22float2(bq8_1[0].ds);
+ auto d8h = __half22float2(bq8_1[1].ds);
+ return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
+#endif
+ return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
}
} // namespace
@@ -760,16 +778,14 @@ void mul_mat_vec_iq6_k_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
-void mul_mat_vec_iq2_tn_q8_1_cuda(
+void mul_mat_vec_iq1_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
-void mul_mat_vec_iq1_tn_q8_1_cuda(
+void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_TN, 1, vec_dot_iq1_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 0678c026..1693a73a 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -20,23 +20,22 @@ void mul_mat_vec_iq6_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-void mul_mat_vec_iq2_tn_q8_1_cuda(
+void mul_mat_vec_iq4_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-void mul_mat_vec_iq1_tn_q8_1_cuda(
+void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-void mul_mat_vec_iq4_ks_q8_1_cuda(
+void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-void mul_mat_vec_iq4_kss_q8_1_cuda(
+void mul_mat_vec_iq1_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-void mul_mat_vec_iq2_ks_q8_1_cuda(
+void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 107caf45..cdf13533 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -22,8 +22,6 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
- type == GGML_TYPE_IQ1_BN ? vec_dot_iq1_bn_q8_1 :
- type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 :
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
@@ -325,20 +323,6 @@ static void mul_mat_vec_iq1_m_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
-static void mul_mat_vec_iq1_bn_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_bn_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -438,12 +422,6 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ2_BN:
mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
- case GGML_TYPE_IQ2_TN:
- mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
- break;
- case GGML_TYPE_IQ1_TN:
- mul_mat_vec_iq1_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
- break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 7baabb7a..e9af29b9 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1117,95 +1117,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
-static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
- const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
-
- static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
-
- // iqs is 0 or 1
-
- int sumi = 0;
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
- const int * q8 = (const int *)bq8_1[iqs].qs;
- int val[4];
- for (int l = 0; l < 2; ++l) {
- int8_t * a = (int8_t *)val;
- const int i16 = 2*iqs + l;
- for (int k = 0; k < 3; ++k) {
- uint8_t q = bq1->ql[3*i16+k];
- for (int j = 0; j < 5; ++j) {
- uint8_t v = k_mult[j]*q;
- int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
- *a++ = vs-1;
- }
- }
- uint8_t v = k_mult[i16]*bq1->extra;
- int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
- *a++ = vs-1;
- sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
- }
-#else
- const int8_t * q8 = bq8_1[iqs].qs;
- for (int l = 0; l < 2; ++l) {
- const int i16 = 2*iqs + l;
- for (int k = 0; k < 3; ++k) {
- uint8_t q = bq1->ql[3*i16+k];
- for (int j = 0; j < 5; ++j) {
- uint8_t v = k_mult[j]*q;
- int8_t vs = (v + (v >> 1)) >> 7;
- sumi += q8[j]*(vs - 1);
- }
- q8 += 5;
- }
- uint8_t v = k_mult[i16]*bq1->extra;
- int8_t vs = (v + (v >> 1)) >> 7;
- sumi += q8[0]*(vs - 1);
- q8++;
- }
-#endif
- return __low2float(bq8_1[iqs].ds) * sumi;
-}
-
-static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
- const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx;
-
- // iqs is 0 or 1
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
- auto qs = (const uint16_t *)bq2->qs + 4*iqs;
- auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
- auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
- int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
- for (int j = 0; j < 2; ++j) {
- int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
- int vh = vl >> 4;
- sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
- sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
- sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
- sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
- }
- auto d8l = __half22float2(bq8_1[0].ds);
- auto d8h = __half22float2(bq8_1[1].ds);
- return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
-#else
- int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
- auto q8l = bq8_1[0].qs + 8*iqs;
- auto q8h = bq8_1[1].qs + 8*iqs;
- auto qs = bq2->qs + 8*iqs;
- for (int j = 0; j < 8; ++j) {
- sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
- sumi2 += q8l[j+16] * (qs[j] & 0x0c);
- sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
- sumi4 += q8h[j+16] * (qs[j] & 0xc0);
- }
- auto d8l = __half22float2(bq8_1[0].ds);
- auto d8h = __half22float2(bq8_1[1].ds);
- return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
-#endif
-}
-
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;