summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu200
1 files changed, 189 insertions, 11 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7695b86b..ac7d0ecb 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -191,6 +191,10 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+ return __vsubss4(a, b);
+}
+
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
c = __builtin_amdgcn_sdot4(a, b, c, false);
@@ -505,6 +509,14 @@ typedef struct {
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+#define QR3_XXS 8
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+typedef struct {
+ half d;
+ uint8_t qs[3*(QK_K/8)];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -1613,6 +1625,41 @@ static const __device__ uint64_t iq2xs_grid[512] = {
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
+static const __device__ uint32_t iq3xxs_grid[256] = {
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
static const __device__ uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -1624,6 +1671,43 @@ static const __device__ uint8_t ksigns_iq2xs[128] = {
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+static const __device__ uint64_t ksigns64[128] = {
+ 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+ 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+ 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+ 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+ 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+ 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+ 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+ 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+ 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+ 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+ 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+ 0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+ 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+ 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+ 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+ 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+ 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+ 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+ 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+ 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+ 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+ 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+ 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+ 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+ 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+ 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+ 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+ 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+ 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+ 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+ 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+ 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+};
+//#endif
+
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
@@ -1690,6 +1774,34 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
}
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int i = blockIdx.x;
+ const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
+
+ const int tid = threadIdx.x;
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * q3 = x[i].qs + 8*ib;
+ const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+#else
+ assert(false);
+#endif
+
+}
+
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -4313,6 +4425,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
@@ -4323,20 +4436,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
const uint8_t ls2 = bq2->scales[ib32] >> 4;
int sumi1 = 0;
for (int l = 0; l < 2; ++l) {
- const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
- const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
- for (int j = 0; j < 8; ++j) {
- sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
- }
+ const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+ const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+ sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+ sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
q8 += 8;
}
int sumi2 = 0;
for (int l = 2; l < 4; ++l) {
- const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
- const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
- for (int j = 0; j < 8; ++j) {
- sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
- }
+ const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+ const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+ sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+ sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
q8 += 8;
}
const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
@@ -4345,6 +4460,45 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
assert(false);
return 0.f;
#endif
+#else
+ assert(false);
+ return 0.f;
+#endif
+}
+
+static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+ const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+
+ const int ib32 = iqs;
+ const uint8_t * q3 = bq2->qs + 8*ib32;
+ const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ uint32_t aux32 = gas[0] | (gas[1] << 16);
+ int sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+ const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+ const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
+ sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
+ sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
+ q8 += 8;
+ aux32 >>= 7;
+ }
+ const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
+ return d * sumi;
+#else
+ assert(false);
+ return 0.f;
+#endif
+#else
+ assert(false);
+ return 0.f;
+#endif
}
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
@@ -6381,6 +6535,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6418,6 +6578,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
default:
@@ -6451,6 +6613,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
default:
@@ -6663,6 +6827,15 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -8213,6 +8386,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
@@ -8235,6 +8409,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
@@ -8306,6 +8481,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ2_XS:
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
default:
GGML_ASSERT(false);
break;
@@ -10934,7 +11112,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return false;
}
ggml_type a_type = a->type;
- if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+ if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}