author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-25 13:08:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-25 13:08:43 +0200 |
commit | 6b968f38946117552ffed300771c44ba9b39d3e4 (patch) | |
tree | dc6b0df69f31ea77d9941d6798a4ef411c688080 | |
parent | 9114078959b404899fd67e1af45f0dcbee51b47f (diff) |
Bitnet changes (#106)
* Adapting iq2_bn to work without separate scale tensors
Why? It is becoming burdensome to maintain the special Bitnet
conversion in convert_hf_to_gguf.py, so I think it is better
to make iq1_bn and iq2_bn just work with the mainline
conversion script (which does not generate scales); see the
layout sketch after the commit message.
* Adapting iq1_bn to work without separate scale tensors
* Adapting iq2_bn: CUDA dequantize
* Adapting iq2_bn: CUDA works
* Adapting iq1_bn: CUDA works
* Adapting iq1_bn, iq2_bn: NEON
* Adapting iq1_bn, iq2_bn: Metal
Dequantize works, but there is still something wrong
with the dot products.
* WIP
I absolutely cannot see what is wrong with the iq1_bn and iq2_bn
vector dot product kernels.
* Remove iq1_tn and iq2_tn - Part 1
Now that iq1_bn and iq2_bn have per row scales, there is no
reason to also have iq1_tn and iq2_tn.
* Remove iq1_tn and iq2_tn - Part 2
* Bitnet: use the standard llm_build_kv to build self attention
My main motivation was to enable FA. But FA does not work anyway
because the head size is 100 for the Bitnet ternary models
(and I had forgotten this little detail).
* Revert "Avoid rebuild of GGML graph for each token (#98)"
This reverts commit f2d315b46f7aacc7df4b86bd8acba387b30e11ca.
As far as I can tell, the commit breaks Metal TG.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
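For reference, a minimal CPU-side sketch of the per-row layout this commit switches iq2_bn to: each row starts with a float scale, followed by 2-bit packed ternary blocks, mirroring what the CUDA/Metal kernels in the diff below read. The helper name dequantize_row_iq2_bn_ref and the standalone defines are illustrative only, not part of the tree.

```c
#include <stdint.h>
#include <string.h>

#define QK_IQ2BN 64

typedef struct {
    uint8_t qs[QK_IQ2BN/4];   // 4 ternary values per byte, 2 bits each
} block_iq2_bn;

// Dequantize one row laid out as: [float scale][block_iq2_bn x n_per_row/QK_IQ2BN]
static void dequantize_row_iq2_bn_ref(const void * vx, float * y, int64_t n_per_row) {
    const char * cx = (const char *)vx;
    float d;
    memcpy(&d, cx, sizeof(d));                 // per-row scale (replaces the old scale tensor)
    const block_iq2_bn * x = (const block_iq2_bn *)(cx + sizeof(float));
    const int64_t nblock = n_per_row / QK_IQ2BN;
    for (int64_t ib = 0; ib < nblock; ++ib) {
        for (int j = 0; j < QK_IQ2BN/4; ++j) {
            const uint8_t q = x[ib].qs[j];
            // same unpacking as the kernels: value = d*((q >> 2k) & 3) - d, i.e. d * {-1, 0, +1}
            y[j +  0] = d * ((q >> 0) & 3) - d;
            y[j + 16] = d * ((q >> 2) & 3) - d;
            y[j + 32] = d * ((q >> 4) & 3) - d;
            y[j + 48] = d * ((q >> 6) & 3) - d;
        }
        y += QK_IQ2BN;
    }
}
```

With the scale stored in the row itself, the row stride is the float prefix plus the packed blocks, which is why the kernels in the diff index rows via ggml_row_size() instead of consulting a separate scale tensor.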
-rw-r--r-- | examples/quantize/quantize.cpp | 2
-rw-r--r-- | ggml/include/ggml-backend.h | 6
-rw-r--r-- | ggml/include/ggml.h | 11
-rw-r--r-- | ggml/src/ggml-backend.c | 45
-rw-r--r-- | ggml/src/ggml-common.h | 17
-rw-r--r-- | ggml/src/ggml-cuda.cu | 2
-rw-r--r-- | ggml/src/ggml-cuda/binbcast.cu | 2
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 14
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 147
-rw-r--r-- | ggml/src/ggml-cuda/fattn.cu | 5
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 84
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 11
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 22
-rw-r--r-- | ggml/src/ggml-cuda/vecdotq.cuh | 89
-rw-r--r-- | ggml/src/ggml-metal.m | 56
-rw-r--r-- | ggml/src/ggml-metal.metal | 416
-rw-r--r-- | ggml/src/ggml-quants.c | 2
-rw-r--r-- | ggml/src/ggml.c | 46
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 487
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 209
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 12
-rw-r--r-- | include/llama.h | 2
-rw-r--r-- | src/llama.cpp | 202
23 files changed, 274 insertions, 1615 deletions
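The iq1_bn kernels throughout the diff decode five ternary digits from each byte using the multiplier table {81, 27, 9, 3, 1} (the COMPUTE_VS trick). A standalone sketch of that decode, copied from the kernels' arithmetic; the function name decode_iq1_bn_byte is hypothetical:

```c
#include <stdint.h>

// Decode one iq1_bn byte q into five ternary values in {-1, 0, +1}.
// Multiplying q by 3^(4-j) (wrapping mod 256) and taking 3*v >> 8 recovers
// the j-th packed digit as 0, 1 or 2, exactly as the CUDA/Metal/NEON kernels do.
static void decode_iq1_bn_byte(uint8_t q, int8_t out[5]) {
    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
    for (int j = 0; j < 5; ++j) {
        uint8_t v  = (uint8_t)(k_mult[j] * q);  // wraps mod 256, as in the kernels
        int8_t  vs = (int8_t)((3 * v) >> 8);    // digit in {0, 1, 2}
        out[j] = vs - 1;                        // ternary value in {-1, 0, +1}
    }
}
```

The alternative form `(v + (v >> 1)) >> 7` that appears in the diff computes the same digit.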
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c88033b6..b5907e2b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -29,8 +29,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, - { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.63 bpw quantization (TriLM)", }, - { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.00 bpw quantization (TriLM)", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 621620bc..5f3f1e28 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -232,12 +232,6 @@ extern "C" { GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); - // Utility to query whether cached GGML graph is in use - GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched); - - // Set whether or not to use GGML graph caching - GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value); - #ifdef __cplusplus } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a99dc6b5..5ba77012 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -401,8 +401,8 @@ extern "C" { GGML_TYPE_IQ4_K = 139, GGML_TYPE_IQ5_K = 140, GGML_TYPE_IQ6_K = 141, - GGML_TYPE_IQ2_TN = 142, - GGML_TYPE_IQ1_TN = 143, + // depricated: GGML_TYPE_IQ2_TN = 142, + // depricated: GGML_TYPE_IQ1_TN = 143, GGML_TYPE_IQ4_KS = 144, GGML_TYPE_IQ2_KS = 145, GGML_TYPE_IQ4_KSS = 146, @@ -597,13 +597,6 @@ extern "C" { GGML_TENSOR_FLAG_PARAM = 4, }; - // Flag (used on GGML_OP_CPY nodes) on whether node is associated with K or V cache - enum ggml_kv_cache_flag { - GGML_KV_CACHE_FLAG_NONE = 0, - GGML_KV_CACHE_FLAG_K = 1, - GGML_KV_CACHE_FLAG_V = 2 - }; - // ggml object struct ggml_object { size_t offs; diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index 76d37f74..e1651cc6 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -1040,13 +1040,6 @@ struct ggml_backend_sched_split { struct ggml_cgraph graph; }; -// Object to facilitate GML graph caching -struct ggml_cached_graph { - bool is_active; - ggml_backend_t input_backend; - struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS]; -}; - struct ggml_backend_sched { bool is_reset; // true if the scheduler has been reset since the last graph split bool is_alloc; @@ -1092,8 +1085,6 @@ struct ggml_backend_sched { size_t context_buffer_size; bool debug; - - struct ggml_cached_graph cached_graph; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1771,14 +1762,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_tensor * input = split->inputs[j]; struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); - if (!sched->cached_graph.is_active) { - sched->cached_graph.input_backend = input_backend; - sched->cached_graph.input_cpy[j] = input_cpy; - } else { - input_backend = sched->cached_graph.input_backend; - input_cpy = 
sched->cached_graph.input_cpy[j]; - } - if (input->flags & GGML_TENSOR_FLAG_INPUT) { // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { @@ -1910,8 +1893,6 @@ ggml_backend_sched_t ggml_backend_sched_new( ggml_backend_sched_reset(sched); - sched->cached_graph.is_active = false; - return sched; } @@ -1988,16 +1969,16 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st } enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - if(!sched->cached_graph.is_active) { - if (!sched->is_reset && !sched->is_alloc) { - ggml_backend_sched_reset(sched); - } - if (!sched->is_alloc) { - if (!ggml_backend_sched_alloc_graph(sched, graph)) { - return GGML_STATUS_ALLOC_FAILED; - } + if (!sched->is_reset && !sched->is_alloc) { + ggml_backend_sched_reset(sched); + } + + if (!sched->is_alloc) { + if (!ggml_backend_sched_alloc_graph(sched, graph)) { + return GGML_STATUS_ALLOC_FAILED; } } + return ggml_backend_sched_compute_splits(sched); } @@ -2262,13 +2243,3 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } - -bool ggml_use_cached_graph(ggml_backend_sched_t sched) { - return sched->cached_graph.is_active; -} - -void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) { - sched->cached_graph.is_active = set_value; -} - - diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f8824b0e..f0c1ae68 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -389,9 +389,7 @@ typedef struct { static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); // -// Bitnet - implemented as 1.625 bpw -// The block scale is a waste, but it allows us to plug it in without any additional -// changes to ggml. 
+// Bitnet and TriLM - implemented as 1.625 bpw // #define QK_IQ1BN 64 typedef struct { @@ -400,24 +398,13 @@ typedef struct { } block_iq1_bn; static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding"); // -// Bitnet - implemented as 2.0 bpw +// Bitnet and TriLM - implemented as 2.0 bpw // #define QK_IQ2BN 64 typedef struct { uint8_t qs[QK_IQ2BN/4]; } block_iq2_bn; static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding"); -// -// TriLM - implemented as 2.0625 bpw -// -typedef struct { - uint8_t qs[52]; -} block_iq1_tn; -static_assert(sizeof(block_iq1_tn) == 52, "wrong iq1_tn block size/padding"); -typedef struct { - uint8_t qs[QK_K/4]; -} block_iq2_tn; -static_assert(sizeof(block_iq2_tn) == QK_K/4, "wrong iqt_bn block size/padding"); // Used by IQ1_M quants typedef union { diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9051863b..6759e202 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2841,9 +2841,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ1_BN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: return true; default: return false; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 62d115f1..5abbd43c 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -288,7 +288,7 @@ static void scale_f32_cuda_l(const float * x, float * dst, const void * data, co scale_f32_l<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, data, k); } -void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +static void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a5658a24..2eba527f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -474,13 +474,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> { }; template<> -struct ggml_cuda_type_traits<GGML_TYPE_IQ1_TN> { - static constexpr int qk = QK_IQ1BN; - static constexpr int qr = QR1_BN; - static constexpr int qi = QI1_BN; -}; - -template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> { static constexpr int qk = QK_IQ1BN; static constexpr int qr = QR1_BN; @@ -488,13 +481,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> { }; template<> -struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> { - static constexpr int qk = QK_K; - static constexpr int qr = QR2_K; - static constexpr int qi = QI2_K; -}; - -template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> { static constexpr int qk = QK4_NL; static constexpr int qr = QR4_NL; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index e9d15b5d..b9baee1b 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -184,30 +184,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t } template<typename dst_t> -static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy, - int64_t n_per_row, int64_t row_size) { - - int64_t ii = blockIdx.x; - int64_t row = (QK_K * ii) / n_per_row; - const char * cx = (const char *)vx + row * row_size; - float d = *(const float *)cx; - const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float)); - int64_t i = 
ii - (row*n_per_row)/QK_K; - - const int64_t tid = threadIdx.x; - const int64_t n = tid/32; - const int64_t l = tid - 32*n; - - const uint8_t q = x[i].qs[32*n + l]; - dst_t * y = yy + ii*QK_K + 128*n; - - y[l+ 0] = d * ((q >> 0) & 3) - d; - y[l+32] = d * ((q >> 2) & 3) - d; - y[l+64] = d * ((q >> 4) & 3) - d; - y[l+96] = d * ((q >> 6) & 3) - d; -} - -template<typename dst_t> static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int64_t i = blockIdx.x; @@ -481,101 +457,72 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ } template<typename dst_t> -static __global__ void dequantize_block_iq1_tn(const void * __restrict__ vx, dst_t * __restrict__ yy, - int64_t n_per_row, int64_t row_size) { - - int64_t ii = blockIdx.x; - int64_t row = (QK_K * ii) / n_per_row; - const char * cx = (const char *)vx + row * row_size; - float scale = *(const half *)cx; - const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(half)); - - static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; - -//#define COMPUTE_VS(v) 3*v >> 8 -#define COMPUTE_VS(v) (v + (v >> 1)) >> 7 +static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, + int64_t n_per_row, int64_t row_size, int64_t nrows) { + int64_t ii = 256*blockIdx.x; const int tid = threadIdx.x; const int il = tid/4; // 0...7 const int ib = tid%4; // 0...3 - dst_t * y = yy + ii*QK_K + 64*ib + 8*il; - const int i16 = il/2; - int64_t i = QK_K/QK_IQ1BN * (ii - (row*n_per_row)/QK_K) + ib; - uint8_t q = x[i].ql[3*i16+2*(il%2)]; - for (int j = 0; j < 5; ++j) { - uint8_t v = k_mult[j]*q; - int8_t vs = COMPUTE_VS(v); - y[2*(il%2)+j] = scale*(vs - 1); - } - q = x[i].ql[3*i16+1]; - for (int j = 0; j < 2; ++j) { - uint8_t v = k_mult[3*(il%2)+j]*q; - int8_t vs = COMPUTE_VS(v); - y[5*(1-(il%2))+j] = scale*(vs-1); - } - uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q; - int8_t vs = COMPUTE_VS(v); - y[7] = scale*(vs - 1); + dst_t * y = yy + ii + 64*ib + 8*il; -#undef COMPUTE_VS -} - -template<typename dst_t> -static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) { - - const int64_t ii = blockIdx.x; - const block_iq1_bn * x = (const block_iq1_bn *) vx; + int64_t row = ii / n_per_row; + if (row >= nrows) return; + const char * cx = (const char *)vx + row * row_size; + half d16; memcpy(&d16, cx, sizeof(d16)); // in case not 2-byte aligned + float d = d16; + const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(d16)); + ii -= row*n_per_row; + int64_t i = ii/QK_IQ1BN + ib; static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; //#define COMPUTE_VS(v) 3*v >> 8 #define COMPUTE_VS(v) (v + (v >> 1)) >> 7 - const int tid = threadIdx.x; - const int il = tid/4; // 0...7 - const int ib = tid%4; // 0...3 - dst_t * y = yy + ii*QK_K + 64*ib + 8*il; - int64_t i = QK_K/QK_IQ1BN * ii + ib; - if (i >= nb64) return; const int i16 = il/2; uint8_t q = x[i].ql[3*i16+2*(il%2)]; for (int j = 0; j < 5; ++j) { uint8_t v = k_mult[j]*q; int8_t vs = COMPUTE_VS(v); - y[2*(il%2)+j] = vs - 1; + y[2*(il%2)+j] = d*(vs - 1); } q = x[i].ql[3*i16+1]; for (int j = 0; j < 2; ++j) { uint8_t v = k_mult[3*(il%2)+j]*q; int8_t vs = COMPUTE_VS(v); - y[5*(1-(il%2))+j] = vs-1; + y[5*(1-(il%2))+j] = d*(vs-1); } uint8_t v = (il%2) ? 
k_mult[i16]*x[i].extra : k_mult[2]*q; int8_t vs = COMPUTE_VS(v); - y[7] = vs - 1; + y[7] = d*(vs - 1); #undef COMPUTE_VS } template<typename dst_t> -static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) { - - const int64_t ii = blockIdx.x; - const block_iq2_bn * x = (const block_iq2_bn *) vx; +static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size, int64_t nrows) { + int64_t ii = 256*blockIdx.x; const int64_t tid = threadIdx.x; int64_t ib64 = tid%4; // 0...3 int64_t il = tid/4; // 0...7 - dst_t * y = yy + 256*ii + 64*ib64 + 2*il; - int64_t i = 256/QK_IQ1BN * ii + ib64; - if (i >= nb64) return; - const float m = -1; + dst_t * y = yy + ii + 64*ib64 + 2*il; + + int64_t row = ii / n_per_row; + if (row >= nrows) return; + const char * cx = (const char *)vx + row * row_size; + float d = *(const float *)cx; + const block_iq2_bn * x = (const block_iq2_bn *)(cx + sizeof(float)); + ii -= row*n_per_row; + int64_t i = ii/QK_IQ1BN + ib64; + const float m = -d; auto qs = x[i].qs + 2*il; for (int j = 0; j < 2; ++j) { - y[j+ 0] = ((qs[j] >> 0) & 3) + m; - y[j+16] = ((qs[j] >> 2) & 3) + m; - y[j+32] = ((qs[j] >> 4) & 3) + m; - y[j+48] = ((qs[j] >> 6) & 3) + m; + y[j+ 0] = d * ((qs[j] >> 0) & 3) + m; + y[j+16] = d * ((qs[j] >> 2) & 3) + m; + y[j+32] = d * ((qs[j] >> 4) & 3) + m; + y[j+48] = d * ((qs[j] >> 6) & 3) + m; } } @@ -857,14 +804,6 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n } template<typename dst_t> -static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { - const int64_t k = nrows * n_per_row; - const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row); - const int nb = (k + 255) / 256; - dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size); -} - -template<typename dst_t> static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; const int nb = k / QK_K; @@ -975,25 +914,17 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t template<typename dst_t> static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int nb64 = k / QK_IQ1BN; - const int nb = (k + 255) / 256; - dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64); -} - -template<typename dst_t> -static void dequantize_row_iq1_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { - const int64_t k = nrows * n_per_row; - const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row); const int nb = (k + 255) / 256; - dequantize_block_iq1_tn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size); + dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows); } template<typename dst_t> static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int nb64 = k / QK_IQ1BN; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row); const int nb = (k + 255) / 256; - dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64); 
+ dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows); } template<typename dst_t> @@ -1157,8 +1088,6 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; - case GGML_TYPE_IQ2_TN: - return dequantize_row_iq2_tn_cuda; case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: @@ -1181,8 +1110,6 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq1_m_cuda; case GGML_TYPE_IQ1_BN: return dequantize_row_iq1_bn_cuda; - case GGML_TYPE_IQ1_TN: - return dequantize_row_iq1_tn_cuda; case GGML_TYPE_IQ2_BN: return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: @@ -1232,8 +1159,6 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; - case GGML_TYPE_IQ2_TN: - return dequantize_row_iq2_tn_cuda; case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: @@ -1256,8 +1181,6 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq1_m_cuda; case GGML_TYPE_IQ1_BN: return dequantize_row_iq1_bn_cuda; - case GGML_TYPE_IQ1_TN: - return dequantize_row_iq1_tn_cuda; case GGML_TYPE_IQ2_BN: return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1dfb24f9..c15d6c81 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -38,6 +38,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); break; default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); GGML_ABORT("fatal error"); break; } @@ -63,6 +64,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); // break; default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); GGML_ABORT("fatal error"); break; } @@ -86,6 +88,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); GGML_ABORT("fatal error"); break; } @@ -114,6 +117,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); GGML_ABORT("fatal error"); break; } @@ -141,6 +145,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); GGML_ABORT("fatal error"); break; } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index dec54b5e..795243e7 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -626,35 +626,12 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1( } -#define VDR_IQ2_TN_Q8_1_MMVQ 1 -#define VDR_IQ2_TN_Q8_1_MMQ 4 - -static 
__device__ __forceinline__ float vec_dot_iq2_tn_q8_1( +__device__ __forceinline__ float vec_dot_iq1_bn_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - float scale = *(const float *)vbq; - const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx; - - const int bq8_offset = QR2_K * (iqs / QI8_1); - - const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs; - int v = q16[0] | (q16[1] << 16); - - float sumf = 0; - for (int i = 0; i < QR2_K; ++ i) { - int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); - float d8 = __low2float(bq8_1[bq8_offset + i].ds); - sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0)); - v >>= 2; - } - return scale * sumf; -} - -static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - - float scale = *(const half *)vbq; - const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(half)) + kbx; + half d16; memcpy(&d16, vbq, sizeof(d16)); + float scale = d16; + const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(d16)) + kbx; static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; @@ -699,7 +676,48 @@ static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1( q8++; } #endif - return __low2float(bq8_1[iqs].ds) * scale * sumi; + return scale * __low2float(bq8_1[iqs].ds) * sumi; +} + +__device__ __forceinline__ float vec_dot_iq2_bn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + float scale = *(const float *)vbq; + const block_iq2_bn * bq2 = (const block_iq2_bn *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0 or 1 + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + auto qs = (const uint16_t *)bq2->qs + 4*iqs; + auto q8l = (const int *)bq8_1[0].qs + 2*iqs; + auto q8h = (const int *)bq8_1[1].qs + 2*iqs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + for (int j = 0; j < 2; ++j) { + int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16); + int vh = vl >> 4; + sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1); + sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2); + sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3); + sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4); + } + auto d8l = __half22float2(bq8_1[0].ds); + auto d8h = __half22float2(bq8_1[1].ds); +#else + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + auto q8l = bq8_1[0].qs + 8*iqs; + auto q8h = bq8_1[1].qs + 8*iqs; + auto qs = bq2->qs + 8*iqs; + for (int j = 0; j < 8; ++j) { + sumi1 += q8l[j+ 0] * (qs[j] & 0x03); + sumi2 += q8l[j+16] * (qs[j] & 0x0c); + sumi3 += q8h[j+ 0] * (qs[j] & 0x30); + sumi4 += q8h[j+16] * (qs[j] & 0xc0); + } + auto d8l = __half22float2(bq8_1[0].ds); + auto d8h = __half22float2(bq8_1[1].ds); + return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y); +#endif + return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y); } } // namespace @@ -760,16 +778,14 @@ void mul_mat_vec_iq6_k_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } -void mul_mat_vec_iq2_tn_q8_1_cuda( +void mul_mat_vec_iq1_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int 
nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } -void mul_mat_vec_iq1_tn_q8_1_cuda( +void mul_mat_vec_iq2_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_TN, 1, vec_dot_iq1_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 0678c026..1693a73a 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -20,23 +20,22 @@ void mul_mat_vec_iq6_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); -void mul_mat_vec_iq2_tn_q8_1_cuda( +void mul_mat_vec_iq4_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); -void mul_mat_vec_iq1_tn_q8_1_cuda( +void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); -void mul_mat_vec_iq4_ks_q8_1_cuda( +void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); -void mul_mat_vec_iq4_kss_q8_1_cuda( +void mul_mat_vec_iq1_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); -void mul_mat_vec_iq2_ks_q8_1_cuda( +void mul_mat_vec_iq2_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); - diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 107caf45..cdf13533 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -22,8 +22,6 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : - type == GGML_TYPE_IQ1_BN ? vec_dot_iq1_bn_q8_1 : - type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 : type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : type == GGML_TYPE_IQ3_S ? 
vec_dot_iq3_s_q8_1 : @@ -325,20 +323,6 @@ static void mul_mat_vec_iq1_m_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } -static void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - static void mul_mat_vec_iq4_nl_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -438,12 +422,6 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ2_BN: mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; - case GGML_TYPE_IQ2_TN: - mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_TN: - mul_mat_vec_iq1_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; case GGML_TYPE_IQ4_NL: mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 7baabb7a..e9af29b9 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1117,95 +1117,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); } -static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx; - - static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; - - // iqs is 0 or 1 - - int sumi = 0; -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const int * q8 = (const int *)bq8_1[iqs].qs; - int val[4]; - for (int l = 0; l < 2; ++l) { - int8_t * a = (int8_t *)val; - const int i16 = 2*iqs + l; - for (int k = 0; k < 3; ++k) { - uint8_t q = bq1->ql[3*i16+k]; - for (int j = 0; j < 5; ++j) { - uint8_t v = k_mult[j]*q; - int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; - *a++ = vs-1; - } - } - uint8_t v = k_mult[i16]*bq1->extra; - int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; - *a++ = vs-1; - sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi)))); - } -#else - const int8_t * q8 = bq8_1[iqs].qs; - for (int l = 0; l < 2; ++l) { - const int i16 = 2*iqs + l; - for (int k = 0; k < 3; ++k) { - uint8_t q = bq1->ql[3*i16+k]; - for (int j = 0; j < 5; ++j) { - uint8_t v = k_mult[j]*q; - int8_t vs = (v + (v >> 1)) >> 7; - sumi += q8[j]*(vs - 1); - } - q8 += 5; - } - uint8_t v = k_mult[i16]*bq1->extra; - int8_t vs = (v + (v >> 1)) >> 7; - sumi += q8[0]*(vs - 1); - q8++; - } -#endif - return 
__low2float(bq8_1[iqs].ds) * sumi; -} - -static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx; - - // iqs is 0 or 1 - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - auto qs = (const uint16_t *)bq2->qs + 4*iqs; - auto q8l = (const int *)bq8_1[0].qs + 2*iqs; - auto q8h = (const int *)bq8_1[1].qs + 2*iqs; - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - for (int j = 0; j < 2; ++j) { - int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16); - int vh = vl >> 4; - sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1); - sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2); - sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3); - sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4); - } - auto d8l = __half22float2(bq8_1[0].ds); - auto d8h = __half22float2(bq8_1[1].ds); - return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y; -#else - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - auto q8l = bq8_1[0].qs + 8*iqs; - auto q8h = bq8_1[1].qs + 8*iqs; - auto qs = bq2->qs + 8*iqs; - for (int j = 0; j < 8; ++j) { - sumi1 += q8l[j+ 0] * (qs[j] & 0x03); - sumi2 += q8l[j+16] * (qs[j] & 0x0c); - sumi3 += q8h[j+ 0] * (qs[j] & 0x30); - sumi4 += q8h[j+16] * (qs[j] & 0xc0); - } - auto d8l = __half22float2(bq8_1[0].ds); - auto d8h = __half22float2(bq8_1[1].ds); - return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y; -#endif -} - static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) { const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; const int8_t * q0_8 = (const int8_t *) &q0_32; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 9f696383..8d350aa1 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -101,9 +101,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_TN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, @@ -145,9 +143,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, @@ -183,9 +179,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, @@ -218,9 +212,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, 
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, @@ -253,9 +245,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, @@ -649,9 +639,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_TN, get_rows_iq1_tn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, get_rows_iq2_tn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true); @@ -693,9 +681,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_TN_F32, mul_mv_iq1_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, mul_mv_iq2_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction); @@ -731,9 +717,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_TN_F32, mul_mv_id_iq1_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, mul_mv_id_iq2_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, 
ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction); @@ -766,9 +750,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_TN_F32, mul_mm_iq1_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, mul_mm_iq2_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm); @@ -801,9 +783,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_TN_F32, mul_mm_id_iq1_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, mul_mm_id_iq2_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm); @@ -2001,9 +1981,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ1_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_TN_F32 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; @@ 
-2197,24 +2175,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline; } break; - case GGML_TYPE_IQ1_TN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_TN_F32].pipeline; - } break; case GGML_TYPE_IQ2_BN: { nth0 = 4; nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32].pipeline; } break; - case GGML_TYPE_IQ2_TN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32].pipeline; - } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -2306,8 +2272,7 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0 || - src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) { + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { @@ -2417,9 +2382,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ1_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_TN_F32 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; @@ -2601,24 +2564,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; } break; - case GGML_TYPE_IQ1_TN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_TN_F32].pipeline; - } break; case GGML_TYPE_IQ2_BN: { nth0 = 4; nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; } break; - case GGML_TYPE_IQ2_TN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32].pipeline; - } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -2721,8 +2672,7 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| - src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) { + 
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { @@ -2790,9 +2740,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; - case GGML_TYPE_IQ1_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_TN ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; - case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 8981cda9..bc0ea9f5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3948,128 +3948,6 @@ kernel void kernel_mul_mv_q2_K_f32( kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -void kernel_mul_mv_iq2_tn_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int row_size = nb*sizeof(block_iq2_tn) + 4; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size; - - device const char * cx = (device const char *) src0 + first_row*row_size + offset0; - device const float * y = (device const float*) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - float drow[N_DST]; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - - for (int row = 0; row < N_DST; row++) drow[row] = *((device const float *)(cx + row*row_size)); - - for (int ib = ix; ib < nb; ib += 4) { - - float sumy = 0.f; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy += yl[i+16]; - yl[i+24] = y4[i+96]; sumy += yl[i+24]; - } - - device const block_iq2_tn * x = (device const block_iq2_tn *)(cx + 4); - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 
0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - sumf[row] += (acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy; - - qs += row_size/2; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = drow[row]*all_sum; - } - } -} - -[[host_name("kernel_mul_mv_iq2_tn_f32")]] -kernel void kernel_mul_mv_iq2_tn_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_tn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, @@ -5528,49 +5406,6 @@ static inline float iq1bn_fp8_to_float(uint8_t fp8) { return s.f; } -//static constant int8_t iq1bn_values[256*5] = { -// -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0, -// -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1, -// -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1, -// -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1, -// -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1, -// -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1, -// 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0, -// -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1, -// 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1, -// 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1, -// -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0, -// 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1, -// -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1, -// 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1, -// -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0, -// 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, 
-1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1, -// 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1, -// 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0, -// -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1, -// 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -// 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1, -// 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0, -// 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0, -// 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0, -// 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0, -// 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0, -// 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1, -// 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0, -// 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0, -// -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1, -// 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1, -// 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0, -// 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1, -// -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, -// 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1, -// 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1, -// -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1, -// 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, -// 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1, -// 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, -//}; - void kernel_mul_mv_iq1_bn_f32_impl( device const void * src0, device const float * src1, @@ -5595,37 +5430,47 @@ void kernel_mul_mv_iq1_bn_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; + const int row_size = nb*sizeof(block_iq1_bn) + 2; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size; + device const uint8_t * cx = (device const uint8_t *) src0 + first_row*row_size + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[16]; float sumf[N_DST]={0.f}; + 
float scale[N_DST]; const int nb32 = nb * (QK_IQ1BN / 32); const int ix = tiisg/2; const int ir = tiisg%2; + for (int row = 0; row < N_DST; ++row) { + half d16; + thread uint8_t * aux = (thread uint8_t *)&d16; + device const uint8_t * cr = cx + row*row_size; + aux[0] = cr[0]; aux[1] = cr[1]; + scale[row] = d16; + } + device const block_iq1_bn * x = (device const block_iq1_bn *)(cx + 2); + device const float * y4 = (device const float *)y + 32 * ix + 16 * ir; const float values[3] = {-1.f, 0.f, 1.f}; constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + const int ib = ix % (QK_IQ1BN / 32); + const int i16 = 2*ib + ir; + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { for (int j = 0; j < 16; ++j) yl[j] = y4[j]; const int ibl = ib32 / (QK_IQ1BN / 32); - const int ib = ib32 % (QK_IQ1BN / 32); - const int i16 = 2*ib + ir; - device const block_iq1_bn * xr = x + ibl; device const uint8_t * ql = xr->ql + 3*i16; device const uint8_t * extra = (device const uint8_t *)&xr->extra; @@ -5635,8 +5480,6 @@ void kernel_mul_mv_iq1_bn_f32_impl( float acc = 0; int i = 0; for (int k = 0; k < 3; ++k) { - //constant int8_t * vs = iq1bn_values + 5*ql[k]; - //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j]; uint8_t q = ql[k]; for (int j = 0; j < 5; ++j) { uint8_t v = k_mult[j]*q; @@ -5644,32 +5487,29 @@ void kernel_mul_mv_iq1_bn_f32_impl( acc += yl[i++] * values[v]; } } - //constant int8_t * vs = iq1bn_values + 5*extra[0]; - //acc += yl[15] * vs[i16]; uint8_t v = k_mult[i16]*extra[0]; v = 3*v >> 8; //(v + (v >> 1)) >> 7; acc += yl[15] * values[v]; sumf[row] += acc; - extra += nb*sizeof(block_iq1_bn); - ql += nb*sizeof(block_iq1_bn); + extra += row_size; + ql += row_size; } y4 += 32 * 16; } for (int row = 0; row < N_DST; row += 2) { - half2 r = {(half)sumf[row], (half)sumf[row+1]}; + float2 r = {sumf[row], sumf[row+1]}; r = simd_sum(r); if (tiisg < 2) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg]; + dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg] * scale[row + tiisg]; } } } -// TODO: unify with kernel_mul_mv_iq1_bn_f32_impl -void kernel_mul_mv_iq1_tn_f32_impl( +void kernel_mul_mv_iq2_bn_f32_impl( device const void * src0, device const float * src1, device float * dst, @@ -5692,157 +5532,65 @@ void kernel_mul_mv_iq1_tn_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - // Why are we not passing in src0->nb[0]? 
- // But because we are not, we need to use this hack - const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K); + const int row_size = nb*sizeof(block_iq2_bn) + sizeof(float); const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*(ne01*ne02))*row_size; - device const char * cx = (device const char *) src0 + first_row*row_size + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float sumf[N_DST]={0.f}; - - const int nb32 = nb * (QK_IQ1BN / 32); - - const int ix = tiisg/2; - const int ir = tiisg%2; - - device const float * y4 = (device const float *)y + 32 * ix + 16 * ir; - - const float values[3] = {-1.f, 0.f, 1.f}; - - constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; - - for (int ib32 = ix; ib32 < nb32; ib32 += 16) { - - for (int j = 0; j < 16; ++j) yl[j] = y4[j]; - - const int ibl = ib32 / (QK_IQ1BN / 32); - const int ib = ib32 % (QK_IQ1BN / 32); - const int i16 = 2*ib + ir; - - device const half * dh = (device const half *)cx; - device const block_iq1_bn * xr = (device const block_iq1_bn *)(dh + 1) + ibl; - device const uint8_t * ql = xr->ql + 3*i16; - device const uint8_t * extra = (device const uint8_t *)&xr->extra; - - for (int row = 0; row < N_DST; row++) { - - float acc = 0; - int i = 0; - for (int k = 0; k < 3; ++k) { - uint8_t q = ql[k]; - for (int j = 0; j < 5; ++j) { - uint8_t v = k_mult[j]*q; - v = 3*v >> 8; //(v + (v >> 1)) >> 7; - acc += yl[i++] * values[v]; - } - } - uint8_t v = k_mult[i16]*extra[0]; - v = 3*v >> 8; //(v + (v >> 1)) >> 7; - acc += yl[15] * values[v]; - - sumf[row] += acc * (float)dh[0]; + const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size; - extra += row_size; - ql += row_size; - dh += row_size/2; - } + device const uint8_t * cx = (device const uint8_t *) src0 + first_row*row_size + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - y4 += 32 * 16; - } + float4 yl[4]; + float sumf[N_DST]={0.f}; + float scale[N_DST]; - for (int row = 0; row < N_DST; row += 2) { - half2 r = {(half)sumf[row], (half)sumf[row+1]}; - r = simd_sum(r); - if (tiisg < 2) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg]; - } + for (int row = 0; row < N_DST; ++row) { + scale[row] = *((device const float *)(cx + row*row_size)); } -} - -void kernel_mul_mv_iq2_bn_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_IQ1BN; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq2_bn * x = (device const block_iq2_bn *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float sumf[N_DST]={0.f}; const int ix = tiisg/4; // 0...7 const int ir = tiisg%4; // 0...3 - device const float * y4 = y + 64 * ix + 4 * ir; + device const float4 * y4 = (device const float4 *)(y + QK_IQ1BN * ix + 4 * ir); + device const uint8_t * qs0 = cx + sizeof(float) + 
(QK_IQ1BN/4)*ix + 4*ir; for (int ib = ix; ib < nb; ib += 8) { - float sumy = 0.f; - for (int i = 0; i < 4; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; - yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4]; - yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; - yl[i+12] = y4[i+48]; sumy += yl[i+12]; - } + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[8]; yl[3] = y4[12]; + float4 tmp = yl[0] + yl[1] + yl[2] + yl[3]; + const float sumy = tmp[0] + tmp[1] + tmp[2] + tmp[3]; - device const uint8_t * qs = x[ib].qs + 4*ir; + device const uint8_t * qs = qs0; for (int row = 0; row < N_DST; row++) { float4 acc = {0.f}; + for (int j = 0; j < 4; ++j) { - acc[0] += yl[j+ 0] * (qs[j] & 0x03); - acc[1] += yl[j+ 4] * (qs[j] & 0x0c); - acc[2] += yl[j+ 8] * (qs[j] & 0x30); - acc[3] += yl[j+12] * (qs[j] & 0xc0); + acc[0] += yl[0][j] * (qs[j] & 0x03); + acc[1] += yl[1][j] * (qs[j] & 0x0c); + acc[2] += yl[2][j] * (qs[j] & 0x30); + acc[3] += yl[3][j] * (qs[j] & 0xc0); } sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy; - qs += nb*sizeof(block_iq2_bn); + qs += row_size; } - y4 += 64 * 8; + y4 += QK_IQ1BN * 2; + qs0 += QK_IQ1BN * 2; } for (int row = 0; row < N_DST; row += 2) { - half2 r = {(half)sumf[row], (half)sumf[row+1]}; + float2 r = {sumf[row], sumf[row+1]}; r = simd_sum(r); if (tiisg < 2) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg]; + dst[r1*ne0 + im*ne0*ne1 + first_row + row+tiisg] = r[tiisg] * scale[row+tiisg]; } } } @@ -7066,34 +6814,6 @@ kernel void kernel_mul_mv_iq1_bn_f32( kernel_mul_mv_iq1_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -[[host_name("kernel_mul_mv_iq1_tn_f32")]] -kernel void kernel_mul_mv_iq1_tn_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_tn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - [[host_name("kernel_mul_mv_iq2_bn_f32")]] kernel void kernel_mul_mv_iq2_bn_f32( device const void * src0, @@ -7486,19 +7206,6 @@ void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & re } template <typename type4x4> -void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) { - device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1); - - il = (il/2)%4; - - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * (q[i] & mask) - 1; - } -} - -template <typename type4x4> void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { const half d_all = xb->d; device const uint8_t * q = (device const uint8_t *)xb->qs; @@ -8145,11 +7852,16 @@ struct DefaultDequantizer { short il; }; -template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&)> +template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&), bool may_not_be_aligned = false> struct DequantizerRS{ using type4x4 = T4x4; DequantizerRS(device const char * cx, short il = 0) : il(il) { - d = *(device const Scale *)cx; + if (may_not_be_aligned) { + thread char * aux = (thread char *)&d; + for (int i = 0; i < sizeof(d); ++i) aux[i] = cx[i]; + } else { + d = *(device const Scale *)cx; + } x = (device const Block *)(cx + sizeof(Scale)); } inline void convert(thread T4x4& t) const { @@ -8537,10 +8249,8 @@ template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq4_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_k, QK_NL, dequantize_iq4_k>; template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq5_k, QK_NL, dequantize_iq5_k>; template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>; -template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>; -template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>; -template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; -template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; +template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; +template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; template [[host_name("kernel_get_rows_iq4_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>; template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>; template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>; @@ -8582,10 +8292,8 @@ template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>>; template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>>; template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; -template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, 
DD<block_iq1_bn, 4, dequantize_iq1_bn>>; -template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_bn, 4, dequantize_iq2_bn>>; -template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; -template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; +template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; +template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>; template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>; template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>; @@ -8617,8 +8325,6 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_s, QK_NL, dequantize_iq2_s>>; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_s, QK_NL, dequantize_iq1_s>>; template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_m, QK_NL, dequantize_iq1_m>>; -template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_bn, 4, dequantize_iq1_bn>>; -template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_bn, 4, dequantize_iq2_bn>>; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_nl, 2, dequantize_iq4_nl>>; template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>>; template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_k, QK_NL, dequantize_iq2_k>>; @@ -8626,8 +8332,8 @@ template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>; template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>; template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; -template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; -template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; +template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, 
true>>; +template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>; template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>; template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>; @@ -8829,7 +8535,6 @@ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q6_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q6_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>; -template [[host_name("kernel_mul_mv_id_iq2_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_tn_f32_impl>>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>; template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>; @@ -8837,7 +8542,6 @@ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq1_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_bn_f32_impl>>; -template [[host_name("kernel_mul_mv_id_iq1_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_tn_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_bn_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 68ec6126..d18b1981 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15194,8 +15194,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ4_K: break; case GGML_TYPE_IQ5_K: break; case GGML_TYPE_IQ6_K: break; - case GGML_TYPE_IQ2_TN: break; - case GGML_TYPE_IQ1_TN: break; case GGML_TYPE_IQ4_KS: break; case GGML_TYPE_IQ4_KSS: break; case GGML_TYPE_Q4_0_4_4: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 35ed68d0..5570b1fc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1016,7 +1016,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = 
ggml_vec_dot_iq1_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, - .row_meta_size = 0, + .row_meta_size = 2, }, [GGML_TYPE_IQ2_BN] = { .type_name = "iq2_bn", @@ -1029,34 +1029,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, - .row_meta_size = 0, - }, - [GGML_TYPE_IQ2_TN] = { - .type_name = "iq2_tn", - .blck_size = QK_K, - .type_size = sizeof(block_iq2_tn), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_iq2_tn, - .from_float = quantize_row_iq2_tn, - .from_float_ref = (ggml_from_float_t)quantize_row_iq2_tn_ref, - .vec_dot = vec_dot_iq2_tn_q8_k, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, .row_meta_size = 4, }, - [GGML_TYPE_IQ1_TN] = { - .type_name = "iq1_tn", - .blck_size = QK_K, - .type_size = sizeof(block_iq1_tn), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_iq1_tn, - .from_float = quantize_row_iq1_tn, - .from_float_ref = (ggml_from_float_t)quantize_row_iq1_tn_ref, - .vec_dot = vec_dot_iq1_tn_q8_k, - .vec_dot_type = GGML_TYPE_Q8_K64, - .nrows = 1, - .row_meta_size = 2, - }, [GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", .blck_size = QK4_NL, @@ -3926,8 +3900,6 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break; case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break; - case GGML_FTYPE_MOSTLY_IQ2_TN: wtype = GGML_TYPE_IQ2_TN; break; - case GGML_FTYPE_MOSTLY_IQ1_TN: wtype = GGML_TYPE_IQ1_TN; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; @@ -10428,8 +10400,6 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -10819,8 +10789,6 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -10960,8 +10928,6 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -14147,8 +14113,6 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -14528,8 +14492,6 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -14803,8 +14765,6 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -15405,8 +15365,6 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: - case GGML_TYPE_IQ2_TN: - case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case 
GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: @@ -22224,8 +22182,6 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_IQ2_TN: result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_IQ1_TN: result = quantize_iq1_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index b77d08b6..d7682e54 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -755,18 +755,6 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { - DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} - template <typename Q8> - inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) { - new_block(i); - } - inline void new_block(int i) { - bits.prepare(x[i].qs); - } - Q2Bits bits; -}; - struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> { DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template <typename Q8> @@ -1256,7 +1244,7 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { Q4Bits bits; Scales8KBase s8k; const __m512i values; - const __m512i mask15 = _mm512_set1_epi16(0xfffe); + const __m512i mask15 = _mm512_set1_epi16(-2); // value is 0xfffe, but to shut up the stupid compiler warning we use the signed value const __m512i mask1 = _mm512_set1_epi16(1); const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); @@ -1319,22 +1307,13 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { - if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) { - auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); - auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32( - _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0), - deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)), - deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3)); - accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); - } else { - const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); - const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, 
i, 1)); - const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); - const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); - auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); - sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); - accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); - } + const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); + sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } @@ -1347,64 +1326,6 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da } } -template <int nrc_y> -static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n % QK_K == 0); - const int nb = n / QK_K; - - Q8<nrc_y> q8(info); - - DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx); - - __m512 accd[2*nrc_y]; - - for (int ix = 0; ix < nrc_x; ix += 2) { - - for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); - - deq1.new_row(ix+0); - deq2.new_row(ix+1); - - for (int i = 0; i < nb; ++i) { - - deq1.new_block(i); - deq2.new_block(i); - //float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per tensor, so we can use the same scale for both rows - - for (int iy = 0; iy < nrc_y; ++iy) { - auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); - auto sumi_scales_512 = _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales_256, 0); - auto q8q = q8.load_quants64(iy, i, 0); - auto sumi_1 = _mm512_dpbusd_epi32(sumi_scales_512, deq1.bits.values[0], q8q); - auto sumi_2 = _mm512_dpbusd_epi32(sumi_scales_512, deq2.bits.values[0], q8q); - q8q = q8.load_quants64(iy, i, 1); - sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[1], q8q); - sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[1], q8q); - q8q = q8.load_quants64(iy, i, 2); - sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[2], q8q); - sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[2], q8q); - q8q = q8.load_quants64(iy, i, 3); - sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q); - sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q); - // The scale is supposed to be per tensor, so we can use the same scale - auto vd = _mm512_set1_ps(/*d* */q8.scale(iy, i)); - accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]); - accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]); - // Leaving this here just in case ternary models start using per row scales - //accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]); - //accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]); - } - - }
- - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix+0, iy, deq1.d*_mm512_reduce_add_ps(accd[2*iy+0])); - info.store(ix+1, iy, deq2.d*_mm512_reduce_add_ps(accd[2*iy+1])); - } - - } -} - template <typename Dequantizer, int nrc_y> static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); @@ -1478,33 +1399,19 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx); - if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) { - for (int kx = 0; kx < k_nx; ++kx) { - compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd); - } - } else { - for (int kx = 0; kx < k_nx; ++kx) { - compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); - } + for (int kx = 0; kx < k_nx; ++kx) { + compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); } } if (2*(nb/2) < nb) { int i0 = 2*(nb/2); deq[0]->new_block(i0, q8, &accm, scales); - if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) { - compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd); - } else { - compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); - } + compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); } - if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) { - info.store(ix, 0, _mm512_reduce_add_ps(accd)); - } else { - auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); - info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); - } + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); + info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); } } @@ -2066,90 +1973,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> { const __m256i mh = _mm256_set1_epi8(0x30); }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { - DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} - - inline void prepare(int i, int j) { - bits.prepare(x[i].qs, j); - } - - Q2Bits bits; -}; - - -template <int nrc_y> -IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%QK_K == 0); - const int nb = n/QK_K; - - Q8<nrc_y> q8(info); - DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx); - - __m256 accd[nrc_y]; - const auto m1 = _mm256_set1_epi16(1); - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq1.new_row(ix); - deq2.new_row(ix); - - for (int i = 0; i < nb; ++i) { - - if constexpr (nrc_y == 1) { - deq1.prepare(i, 0); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(0, i, 0)), - _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(0, i, 1))); - sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(0, i, 2)), - _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(0, i, 3))), sumi1); - - deq2.prepare(i, 1); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(0, i, 4)), - _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(0, i, 5))); - sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(0, i, 6)), - _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(0, i, 7))), sumi2); - auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(0, 
i))); - auto vd = _mm256_set1_ps(deq1.d*q8.scale(0, i)); - auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi)); - accd[0] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[0]) : _mm256_mul_ps(vd, sf); - } - else { - - deq1.prepare(i, 0); deq2.prepare(i, 1); - for (int iy = 0; iy < nrc_y; ++iy) { - auto vd = _mm256_set1_ps(deq1.d*q8.scale(iy, i)); - auto sumi = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)), - _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1))); - sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)), - _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3))), sumi); - sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)), - _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi); - sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)), - _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi); - sumi = _mm256_sub_epi16(sumi, q8.load_bsums(iy, i)); - - //auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)), - // _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1))); - //auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)), - // _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3))); - //sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)), - // _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi1); - //sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)), - // _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi2); - //auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(iy, i))); - auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi)); - accd[iy] = i > 0 ? 
_mm256_fmadd_ps(vd, sf, accd[iy]) : _mm256_mul_ps(vd, sf); - } - } - - } - - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(accd[iy])); - } - - } -} - template <typename Dequantizer, int nrc_y> static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -2471,7 +2294,7 @@ struct DequantizerIQ1BN { }; -template <int nrc_y, bool is_iq1_tn> +template <int nrc_y> IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; Q8_K64<nrc_y> q8(info); @@ -2486,14 +2309,14 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const const block_iq1_bn * x; const char * cx0 = (const char *)vx; float scale; + ggml_half d16; for (int ix = 0; ix < nrc_x; ++ix) { const char * cx = cx0 + ix*bx; - if constexpr (is_iq1_tn) { - scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx); - cx += sizeof(ggml_half); - } + std::memcpy(&d16, cx, sizeof(d16)); + scale = GGML_FP16_TO_FP32(d16); + cx += sizeof(d16); x = (const block_iq1_bn *)cx; if constexpr (nrc_y == 1) { @@ -2561,17 +2384,13 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const auto vd = q8.scale(iy); auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); - if constexpr (is_iq1_tn) { - info.store(ix, iy, scale*hsum_float_4(sumf)); - } else { - info.store(ix, iy, hsum_float_4(sumf)); - } + info.store(ix, iy, scale*hsum_float_4(sumf)); } } } -struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> { +struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> { DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const { @@ -2671,7 +2490,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const auto vd = q8.scale(iy); auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); - info.store(ix, iy, hsum_float_4(sumf)); + info.store(ix, iy, deq.d*hsum_float_4(sumf)); } } } @@ -4075,30 +3894,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerQ2K>(mm); break; - case GGML_TYPE_IQ2_TN: - assert (ne00 % QK_K == 0); -#ifdef HAVE_FANCY_SIMD - //MulMat::set_functions<DequantizerIQ2TN>(mm); - mm.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<DequantizerIQ2TN>; - //mm.funcs[0] = mul_mat_iq2tn_q8_K_AVX512<1>; - mm.funcs[1] = mul_mat_iq2tn_q8_K_AVX512<2>; - mm.funcs[2] = mul_mat_iq2tn_q8_K_AVX512<3>; - mm.funcs[3] = mul_mat_iq2tn_q8_K_AVX512<4>; - mm.funcs[4] = mul_mat_iq2tn_q8_K_AVX512<5>; - mm.funcs[5] = mul_mat_iq2tn_q8_K_AVX512<6>; - mm.funcs[6] = mul_mat_iq2tn_q8_K_AVX512<7>; - mm.funcs[7] = mul_mat_iq2tn_q8_K_AVX512<8>; -#else - mm.funcs[0] = mul_mat_iq2tn_q8_K<1>; - mm.funcs[1] = mul_mat_iq2tn_q8_K<2>; - mm.funcs[2] = mul_mat_iq2tn_q8_K<3>; - mm.funcs[3] = mul_mat_iq2tn_q8_K<4>; - mm.funcs[4] = mul_mat_iq2tn_q8_K<5>; - mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; - mm.funcs[6] = mul_mat_iq2tn_q8_K<7>; - mm.funcs[7] = mul_mat_iq2tn_q8_K<8>; -#endif - break; case GGML_TYPE_Q3_K: assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerQ3K>(mm); @@ -4173,26 +3968,14 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) 
{ break; case GGML_TYPE_IQ1_BN: assert (ne00 % QK_IQ1BN == 0); - mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>; - mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>; - mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>; - mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>; - mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>; - mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>; - mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>; - mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>; - expected_typeB = GGML_TYPE_Q8_K64; - break; - case GGML_TYPE_IQ1_TN: - assert (ne00 % QK_IQ1BN == 0); - mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>; - mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>; - mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>; - mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>; - mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>; - mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>; - mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>; - mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>; + mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>; + mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>; + mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>; + mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>; + mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>; + mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>; + mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>; + mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>; expected_typeB = GGML_TYPE_Q8_K64; break; case GGML_TYPE_IQ2_BN: @@ -5410,156 +5193,6 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { - DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return true; } - - //template <typename Q8> - //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { - // d = GGML_FP16_TO_FP32(x[i].d); - //} - - inline void new_block(int) { } - - template <typename Q8> - inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - auto q8b_1 = q8.load_quants(iy, i, 4*j+0); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), - vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); - - auto q8b_2 = q8.load_quants(iy, i, 4*j+1); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); - - auto q8b_3 = q8.load_quants(iy, i, 4*j+2); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), - vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); - - auto q8b_4 = q8.load_quants(iy, i, 4*j+3); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), - vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); - } - } - template <typename Q8> - inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) { - auto q8b_1 = q8.load_quants(0, i, 4*j+0); - sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), - vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); - - auto q8b_2 = q8.load_quants(0, i, 4*j+1); - sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); - - q8b_1 = q8.load_quants(0, i, 4*j+2); - sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]), - 
vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]); - - q8b_2 = q8.load_quants(0, i, 4*j+3); - sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]); - } - - IQK_ALWAYS_INLINE void prepare(int i, int j) { - bits.prepare(x[i].qs+32*j); - auto m1 = vdupq_n_s8(1); - for (int k = 0; k < 4; ++k) { - bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1); - bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1); - } - } - - Q2bits bits; -}; - -template <int nrc_y> -void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n % QK_K == 0); - const int nb = n / QK_K; - - Q8<nrc_y, block_q8_K> q8(info); - - DequantizerIQ2TN deq(vx, bx, nrc_y); - float32x4_t acc[nrc_y]; - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq.new_row(ix); - - for (int i = 0; i < nb; ++i) { - - int32x4_t sumi[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); - - deq.new_block(i); - deq.prepare(i, 0); - deq.compute(q8, i, 0, sumi); - deq.prepare(i, 1); - deq.compute(q8, i, 1, sumi); - - if (i > 0) { - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); - } - } else { - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); - } - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(acc[iy])); - } - } -} -void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n % QK_K == 0); - const int nb = n / QK_K; - - Q8<1, block_q8_K> q8(info); - - DequantizerIQ2TN deq(vx, bx, 1); - - auto m1 = vdup_n_s16(-1); - float32x4_t acc[2]; - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq.new_row(ix); - - for (int i = 0; i < nb; ++i) { - - int32x4_t sumi[2] = {}; - deq.new_block(i); - auto bsums = q8.load_bsums(0, i); - bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]); - sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1); - sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1); - deq.bits.prepare(deq.x[i].qs); - deq.compute1(q8, i, 0, sumi); - deq.bits.prepare(deq.x[i].qs+32); - deq.compute1(q8, i, 1, sumi); - - auto vd = vdupq_n_f32(deq.d*q8.scale(0, i)); - if (i > 0) { - acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); - acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); - } else { - acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd); - acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd); - } - - } - - acc[0] = vaddq_f32(acc[0], acc[1]); - info.store(ix, 0, vaddvq_f32(acc[0])); - } -} - - template <int nrc_y, typename Dequantizer> void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); @@ -6630,7 +6263,7 @@ struct DequantizerIQ1BN { } }; -template <int nrc_y, bool is_iq1_tn> +template <int nrc_y> static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; @@ -6641,14 +6274,16 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn int8x16x4_t v1, v2; float scale; + ggml_half d16; + char * c16 = (char *)&d16; for (int ix = 0; ix < nrc_x; ++ix) { const char * cx = ((const char *)vx + ix*bx); - if constexpr (is_iq1_tn) { - scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx); - cx += sizeof(ggml_half); - } + c16[0] = cx[0]; c16[1] = cx[1]; + //std::memcpy(&d16, cx, sizeof(d16)); + cx += sizeof(d16); + 
scale = GGML_FP16_TO_FP32(d16); const block_iq1_bn * x = (const block_iq1_bn *)cx; @@ -6704,11 +6339,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } for (int iy = 0; iy < nrc_y; ++iy) { - if constexpr (is_iq1_tn) { - info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); - } else { - info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); - } + info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); } } @@ -6726,7 +6357,9 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn for (int ix = 0; ix < nrc_x; ++ix) { - const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx); + const float * dptr = (const float *)((const char *)vx + ix*bx); + const float d = *dptr; + const block_iq2_bn * x = (const block_iq2_bn *)(dptr + 1); if constexpr (nrc_y == 1) { int8x16x4_t v1; @@ -6789,7 +6422,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + info.store(ix, iy, -d*vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy])))); } } } @@ -6859,17 +6492,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_Q2_K: MulMat::set_functions<DequantizerQ2K>(m); break; - case GGML_TYPE_IQ2_TN: - //MulMat::set_functions<DequantizerIQ2TN>(m); - m.funcs[0] = mul_mat_iq2tn_K_q8_K_1; - m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>; - m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>; - m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>; - m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>; - m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>; - m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>; - m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>; - break; case GGML_TYPE_Q3_K: MulMat::set_functions<DequantizerQ3K>(m); break; @@ -6925,25 +6547,14 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions<DequantizerIQ3S>(m); break; case GGML_TYPE_IQ1_BN: - m.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>; - m.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>; - m.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>; - m.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>; - m.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>; - m.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>; - m.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>; - m.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>; - expected_Btype = GGML_TYPE_Q8_K64; - break; - case GGML_TYPE_IQ1_TN: - m.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>; - m.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>; - m.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>; - m.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>; - m.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>; - m.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>; - m.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>; - m.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>; + m.funcs[0] = mul_mat_iq1bn_q8_K64<1>; + m.funcs[1] = mul_mat_iq1bn_q8_K64<2>; + m.funcs[2] = mul_mat_iq1bn_q8_K64<3>; + m.funcs[3] = mul_mat_iq1bn_q8_K64<4>; + m.funcs[4] = mul_mat_iq1bn_q8_K64<5>; + m.funcs[5] = mul_mat_iq1bn_q8_K64<6>; + m.funcs[6] = mul_mat_iq1bn_q8_K64<7>; + m.funcs[7] = mul_mat_iq1bn_q8_K64<8>; expected_Btype = GGML_TYPE_Q8_K64; break; case GGML_TYPE_IQ2_BN: diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 26bc5ecb..b9d48237 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ 
b/ggml/src/iqk/iqk_quantize.cpp @@ -119,6 +119,16 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i const int nblock = n_per_row/QK_IQ1BN; + ggml_half * dptr = (ggml_half *)y; + y = (block_iq1_bn *)(dptr + 1); + + float max = 0; + for (int j = 0; j < n_per_row; ++j) max = std::max(max, fabsf(src[j])); + ggml_half d = GGML_FP32_TO_FP16(max); + std::memcpy(dptr, &d, sizeof(d)); + + float thresh = 0.5f*max; + for (int ib = 0; ib < nblock; ++ib) { std::memset(&y[ib], 0, sizeof(block_iq1_bn)); auto xb = src + ib*QK_IQ1BN; @@ -128,14 +138,14 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i int idx = 0; for (int j = 0; j < 5; ++j) { float v = xb[16*i16 + 5*k + j]; - int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2; + int q = fabsf(v) < thresh ? 1 : v < 0 ? 0 : 2; idx += k_nb[j]*q; } idx = (256*idx + k_nb[5] - 1)/k_nb[5]; y[ib].ql[3*i16 + k] = idx; } float v = xb[16*i16 + 15]; - int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2; + int q = fabsf(v) < thresh ? 1 : v < 0 ? 0 : 2; v13 += k_nb[i16]*q; } y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5]; @@ -150,10 +160,18 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i constexpr int Nj = QK_IQ1BN/4; + float max = 0; + for (int j = 0; j < n_per_row; ++j) max = std::max(max, fabsf(src[j])); + + float * dptr = (float *)y; + *dptr = max; + y = (block_iq2_bn *)(dptr + 1); + float thresh = 0.5f*max; + for (int ib = 0; ib < nblock; ++ib) { auto xb = src + QK_IQ1BN*ib; for (int j = 0; j < QK_IQ1BN; ++j) { - L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2; + L[j] = fabsf(xb[j]) < thresh ? 1 : xb[j] < 0 ? 0 : 2; } for (int j = 0; j < Nj; ++j) { y[ib].qs[j] = L[j] | (L[j + Nj] << 2) | (L[j + 2*Nj] << 4) | (L[j + 3*Nj] << 6); @@ -165,13 +183,13 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { IQ1BNQuantizer iq1bn; - int nblock = n_per_row/QK_IQ1BN; - block_iq1_bn * y = (block_iq1_bn *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row); + auto qrow = (char *)dst; for (int row = 0; row < nrows; ++row) { - iq1bn.quantize_one_row_1bn(src + row*n_per_row, y, n_per_row, imatrix); - y += nblock; + iq1bn.quantize_one_row_1bn(src + row*n_per_row, (block_iq1_bn *)qrow, n_per_row, imatrix); + qrow += row_size; } - return sizeof(block_iq1_bn)*nblock*nrows; + return nrows*row_size; } void quantize_row_iq1_bn_ref(const float * x, block_iq1_bn * y, int64_t k) { @@ -182,54 +200,6 @@ void quantize_row_iq1_bn(const float * x, void * y, int64_t k) { quantize_iq1_bn(x, y, 1, k, nullptr); } -void quantize_row_iq1_tn_ref(const float * x, block_iq1_tn * y, int64_t k) { - quantize_iq1_tn(x, (void *)y, 1, k, nullptr); -} - -void quantize_row_iq1_tn(const float * x, void * y, int64_t k) { - quantize_iq1_tn(x, y, 1, k, nullptr); -} - -size_t quantize_iq1_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { - GGML_ASSERT(n_per_row >= 2*QK_K); // so we have space for the scale - int nblock = n_per_row/QK_IQ1BN; - float tmp[QK_IQ1BN]; - char * qrow = (char *)dst; - auto row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row); - IQ1BNQuantizer iq1bn; - for (int row = 0; row < nrows; ++row) { - float max = fabsf(src[0]); - for (int j = 1; j < n_per_row; ++j) max = std::max(max, fabsf(src[j])); - if (!(max > 0)) printf("%s: found max = %g?\n", __func__, max); - //GGML_ASSERT(max > 0); - *(ggml_half *)qrow 
= GGML_FP32_TO_FP16(max);
-        block_iq1_bn * y = (block_iq1_bn *)(qrow + sizeof(ggml_half));
-        const float * xb = src;
-        for (int ib = 0; ib < nblock; ++ib) {
-            for (int j = 0; j < QK_IQ1BN; ++j) tmp[j] = xb[j] < -0.5f*max ? -1 : xb[j] <= 0.5f*max ? 0 : 1;
-            iq1bn.quantize_one_row_1bn(tmp, y, QK_IQ1BN, imatrix);
-            ++y;
-            xb += QK_IQ1BN;
-        }
-        src += n_per_row;
-        qrow += row_size;
-    }
-    return nrows*row_size;
-}
-
-void dequantize_row_iq1_tn(const block_iq1_tn * x, float * y, int64_t k) {
-    float scale = GGML_FP16_TO_FP32(*(const ggml_half *)x);
-    const block_iq1_bn * iq1bn = (const block_iq1_bn *)((const char *)x + sizeof(ggml_half));
-    dequantize_row_iq1_bn(iq1bn, y, k);
-    for (int j = 0; j < int(k); ++j) y[j] *= scale;
-}
-
-void vec_dot_iq1_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
-    float scale = GGML_FP16_TO_FP32(*(const ggml_half *)vx);
-    ggml_vec_dot_iq1_bn_q8_K64(n, s, bs, (const void *)((const char *)vx + sizeof(ggml_half)), bx, vy, by, nrc);
-    *s *= scale;
-}
-
 void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
     assert(k%QK_IQ1BN == 0);
     int nblock = k / QK_IQ1BN;
@@ -255,13 +225,13 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
 size_t quantize_iq2_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
     IQ1BNQuantizer iq1bn;
-    int nblock = n_per_row/QK_IQ1BN;
-    block_iq2_bn * y = (block_iq2_bn *)dst;
+    auto row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
+    auto qrow = (char *)dst;
     for (int row = 0; row < nrows; ++row) {
-        iq1bn.quantize_one_row_2bn(src + row*n_per_row, y, n_per_row, imatrix);
-        y += nblock;
+        iq1bn.quantize_one_row_2bn(src + row*n_per_row, (block_iq2_bn *)qrow, n_per_row, imatrix);
+        qrow += row_size;
     }
-    return sizeof(block_iq2_bn)*nblock*nrows;
+    return nrows*row_size;
 }
 
 void quantize_row_iq2_bn_ref(const float * x, block_iq2_bn * y, int64_t k) {
@@ -2369,114 +2339,6 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
     return nrows * nblock * sizeof(block_iq6_k);
 }
 
-//
-// ========================== IQ2_TN
-//
-
-void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) {
-    GGML_ASSERT(k%QK_K == 0);
-
-    int nb = k/QK_K;
-
-    auto quantize = [] (float xmax, float x) {
-        return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2;
-    };
-    int n = k;
-    float max = x[0];
-    for (int j = 1; j < n; ++j) max = std::max(max, fabsf(x[j]));
-
-    *(float *)y = max;
-    y = (block_iq2_tn *)((float *)y + 1);
-
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        auto xb = x + QK_K*ibl;
-        auto qs = y[ibl].qs;
-        for (int l = 0; l < QK_K/128; ++l) {
-            for (int j = 0; j < 32; ++j) {
-                qs[j] = quantize(max, xb[j]) | (quantize(max, xb[j+32]) << 2) | (quantize(max, xb[j+64]) << 4) | (quantize(max, xb[j+96]) << 6);
-            }
-            xb += 128;
-            qs += 32;
-        }
-    }
-}
-
-void quantize_row_iq2_tn(const float * x, void * y, int64_t k) {
-    quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k);
-}
-
-size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * /*imatrix*/) {
-    auto row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
-    char * qrow = (char *)dst;
-    for (int row = 0; row < nrows; ++row) {
-        quantize_row_iq2_tn_ref(src, (block_iq2_tn *)qrow, n_per_row);
-        qrow += row_size;
-        src += n_per_row;
-    }
-    return row_size*nrows;
-}
-
-void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) {
-    GGML_ASSERT(k%QK_K == 0);
-    const float * dptr = (const float *)x;
-    float d = *dptr;
-    x = (const block_iq2_tn *)(dptr + 1);
-    int nb = k/QK_K;
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        auto qs = x[ibl].qs;
-        for (int l = 0; l < QK_K/128; ++l) {
-            for (int j = 0; j < 32; ++j) {
-                y[j+ 0] = d*((qs[j] >> 0) & 3) - d;
-                y[j+32] = d*((qs[j] >> 2) & 3) - d;
-                y[j+64] = d*((qs[j] >> 4) & 3) - d;
-                y[j+96] = d*((qs[j] >> 6) & 3) - d;
-            }
-            y += 128;
-            qs += 32;
-        }
-    }
-}
-
-void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
-    GGML_UNUSED(bs);
-    GGML_UNUSED(bx);
-    GGML_UNUSED(by);
-    GGML_UNUSED(nrc);
-#if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
-        return;
-    }
-#endif
-
-    const int nb = n / QK_K;
-
-    const float * dptr = (const float *)vx;
-    const float d = *dptr;
-    const block_iq2_tn * x = (const block_iq2_tn *)(dptr + 1);
-    const block_q8_K * y = (const block_q8_K *)vy;
-
-    float sumf = 0;
-
-    for (int i = 0; i < nb; i++) {
-        auto qs = x[i].qs;
-        auto q8 = y[i].qs;
-        int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi1 -= y[i].bsums[j];
-        for (int l = 0; l < QK_K/128; ++l) {
-            for (int j = 0; j < 32; ++j) {
-                sumi1 += q8[j+ 0] * (qs[j] & 0x03);
-                sumi2 += q8[j+32] * (qs[j] & 0x0c);
-                sumi3 += q8[j+64] * (qs[j] & 0x30);
-                sumi4 += q8[j+96] * (qs[j] & 0xc0);
-            }
-            q8 += 128;
-            qs += 32;
-        }
-        sumf += d * (sumi1 + 0.25f*sumi2 + 0.0625f*sumi3 + 0.015625f*sumi4);
-    }
-    *s = sumf;
-}
-
 #ifdef __AVX2__
 namespace {
 inline int hsum_i32_8(const __m256i a) {
@@ -2941,7 +2803,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
             continue;
         }
         float best = 0;
-        bool is_shifted = false;
        float d = -max/iq4k_values[0];
        std::memset(vs, 0, block_size);
        for (int itry = -ntry; itry <= ntry; ++itry) {
@@ -2974,10 +2835,10 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
            }
            bool copy_p = false, copy_m = false;
            if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
-                d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true;
+                d = sumqx_p/sumq2_p; best = d * sumqx_p; copy_p = true;
            }
            if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
-                d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; copy_m = true;
+                d = sumqx_m/sumq2_m; best = d * sumqx_m; copy_m = true;
            }
            if (copy_m) {
                std::memcpy(vs, vms, block_size);
@@ -3014,10 +2875,10 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
            }
            copy_p = copy_m = false;
            if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
-                d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
+                d = sumqx_p/sumq2_p; best = d * sumqx_p; copy_p = true;
            }
            if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
-                d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
+                d = sumqx_m/sumq2_m; best = d * sumqx_m; copy_m = true;
            }
            if (copy_m) {
                std::memcpy(vs, vms, block_size);
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index e0dde0d8..50c425af 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -43,18 +43,6 @@ size_t quantize_iq6_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
 void dequantize_row_iq6_k(const block_iq6_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void vec_dot_iq6_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
-void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
-void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-
-void quantize_row_iq1_tn_ref(const float * GGML_RESTRICT x, block_iq1_tn * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq1_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-size_t quantize_iq1_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-void dequantize_row_iq1_tn(const block_iq1_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
-void vec_dot_iq1_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-
 void quantize_row_iq4_ks_ref(const float * GGML_RESTRICT x, block_iq4_ks * GGML_RESTRICT y, int64_t k);
 void quantize_row_iq4_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 size_t quantize_iq4_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/include/llama.h b/include/llama.h
index b2906693..965e5f50 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -175,8 +175,6 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_K = 140, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ5_K = 141, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ6_K = 142, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ2_TN = 143, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ1_TN = 144, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ4_KS = 145, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 4ca1bd11..27ba5d2f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8,7 +8,6 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
-#include "../ggml/src/ggml-impl.h"
 
 #ifdef GGML_USE_RPC
 #  include "ggml-rpc.h"
@@ -2718,17 +2717,6 @@ struct llama_model {
     }
 };
 
-// Object used to allow caching of GGML graph between tokens where possible.
-struct ggml_cached_graph {
-    bool is_active = false;
-    ggml_cgraph * gf;
-    size_t n;
-    ggml_backend_t backend_res;
-    ggml_backend_t backend_embd;
-    struct ggml_tensor * res;
-    struct ggml_tensor * embd;
-};
-
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -2829,8 +2817,6 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-
-    struct ggml_cached_graph cached_graph;
 };
 
 struct llama_lora_weight {
@@ -3862,8 +3848,6 @@ struct llama_model_loader {
             case GGML_TYPE_IQ1_M:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;   break;
             case GGML_TYPE_IQ1_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN;  break;
             case GGML_TYPE_IQ2_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN;  break;
-            case GGML_TYPE_IQ1_TN:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_TN;  break;
-            case GGML_TYPE_IQ2_TN:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_TN;  break;
             case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
             case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
             case GGML_TYPE_IQ4_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS;  break;
@@ -4579,9 +4563,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_IQ5_K:   return "IQ5_K - 5.5 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ6_K:   return "IQ6_K - 6.6 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ1_BN:  return "IQ1_BN - 1.625 bpw Bitnet";
-        case LLAMA_FTYPE_MOSTLY_IQ1_TN:  return "IQ1_TN - 1.625 bpw TriLM";
         case LLAMA_FTYPE_MOSTLY_IQ2_BN:  return "IQ2_BN - 2.00 bpw Bitnet";
-        case LLAMA_FTYPE_MOSTLY_IQ2_TN:  return "IQ2_TN - 2.00 bpw TriLM";
         case LLAMA_FTYPE_MOSTLY_IQ3_S:   return "IQ3_S - 3.4375 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ3_M:   return "IQ3_S mix - 3.66 bpw";
         case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
@@ -13329,7 +13311,7 @@ struct llm_build_context {
                 float q_scale; std::memcpy(&q_scale, model.layers[il].wq->op_params, sizeof(float));
                 // Note: we could save this scale operation by applying the Q scale on the K * Q product further down
                 // (which also uses a scale). This works on the CPU and Metal backends, but produces NaNs on CUDA.
-                Qcur = ggml_scale(ctx0, Qcur, q_scale);
+                if (fabsf(q_scale-1) > 1e-4f) Qcur = ggml_scale(ctx0, Qcur, q_scale);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -13339,7 +13321,7 @@ struct llm_build_context {
                 // B1.K
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
                 float k_scale; std::memcpy(&k_scale, model.layers[il].wk->op_params, sizeof(float));
-                Kcur = ggml_scale(ctx0, Kcur, k_scale);
+                if (fabsf(k_scale-1) > 1e-4f) Kcur = ggml_scale(ctx0, Kcur, k_scale);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -13349,13 +13331,12 @@ struct llm_build_context {
                 // B1.V
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 float v_scale; std::memcpy(&v_scale, model.layers[il].wv->op_params, sizeof(float));
-                cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
-                    Vcur = ggml_scale(ctx0, Vcur, v_scale);
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
+                    if (fabsf(v_scale-1) > 1e-4f) Vcur = ggml_scale(ctx0, Vcur, v_scale); v_scale = 1;
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 }
+                cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@@ -13371,56 +13352,10 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
-
-                const int64_t n_ctx         = cparams.n_ctx;
-                const int64_t n_head        = hparams.n_head();
-                const int64_t n_head_kv     = hparams.n_head_kv();
-                const int64_t n_embd_head_k = hparams.n_embd_head_k;
-                const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-                const int64_t n_embd_head_v = hparams.n_embd_head_v;
-                const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-
-                float kq_scale = 1.0f/sqrtf(float(n_embd_head));
-                // We would use this if we did not apply the Q scale above. Sadly, this fails on CUDA.
-                //float kq_scale = q_scale/sqrtf(float(n_embd_head));
-                struct ggml_tensor * cur_attn;
-                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                cb(q, "q", il);
-
-                struct ggml_tensor * k =
-                    ggml_view_3d(ctx0, kv_self.k_l[il],
-                        n_embd_head_k, n_kv, n_head_kv,
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                        0);
-                cb(k, "k", il);
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                GGML_ASSERT(kv_self.size == n_ctx);
-
-                // split cached v into n_head heads
-                struct ggml_tensor * v =
-                    ggml_view_3d(ctx0, kv_self.v_l[il],
-                        n_kv, n_embd_head_v, n_head_kv,
-                        ggml_element_size(kv_self.v_l[il])*n_ctx,
-                        ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                        0);
-                cb(v, "v", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-                cb(cur_attn, "kqv_merged_cont", il);
+                ggml_tensor * cur_attn = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        // we cannot pass model.layers[il].wo and model.layers[il].bo because we need to do rms_norm first
+                        nullptr, nullptr,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
 
                 cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
                         model.layers[il].attn_sub_norm, NULL,
@@ -13431,7 +13366,7 @@ struct llm_build_context {
                 cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
                 float wo_scale; std::memcpy(&wo_scale, model.layers[il].wo->op_params, sizeof(float));
-                cur = ggml_scale(ctx0, cur, wo_scale);
+                if (fabsf(wo_scale-1) > 1e-4f) cur = ggml_scale(ctx0, cur, wo_scale);
                 cb(cur, "kqv_out", il);
             }
@@ -13460,7 +13395,7 @@ struct llm_build_context {
                 cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
                 float ffn_gate_scale; std::memcpy(&ffn_gate_scale, model.layers[il].ffn_gate->op_params, sizeof(float));
-                cur = ggml_scale(ctx0, cur, ffn_gate_scale);
+                if (fabsf(ffn_gate_scale-1) > 1e-4f) cur = ggml_scale(ctx0, cur, ffn_gate_scale);
                 cb(cur, "ffn_gate", il);
@@ -13479,7 +13414,7 @@ struct llm_build_context {
                 cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
                 float ffn_down_scale; std::memcpy(&ffn_down_scale, model.layers[il].ffn_down->op_params, sizeof(float));
-                cur = ggml_scale(ctx0, cur, ffn_down_scale);
+                if (fabsf(ffn_down_scale-1) > 1e-4f) cur = ggml_scale(ctx0, cur, ffn_down_scale);
                 cb(cur, "ffn_down", il);
             }
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -15005,44 +14940,11 @@ static int llama_decode_internal(
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf;
-    // the output is always the last tensor in the graph
-    struct ggml_tensor * res;
-    struct ggml_tensor * embd;
-
-    bool n_has_changed_since_last_token = false;
-    if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
-    lctx.cached_graph.n = kv_self.n;
-
-    // Re-build graph only if graph caching is not possible
-    if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
-
-        gf = llama_build_graph(lctx, u_batch, false);
-
-        // Set whether GGML graph caching is in use within GGML module, based on
-        // whether caching was activated here during the previous token
-        ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
-
-        // Disable future graph caching in presence of env var,
-        // if there are multiple devices, if batch size is greater than 1,
-        // or if nsplits is not 2.
-        // TO DO enable graph caching for these cases
-        bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
-            || (llama_get_device_count(model) > 1)
-            || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
-        for (int i = 0 ; i < gf->n_nodes; i++) {
-            if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
-                disable_cached_ggml_graph = true;
-                break;
-            }
-        }
-
-        // Set whether graph caching should be used for future tokens
-        lctx.cached_graph.is_active=!disable_cached_ggml_graph;
+    ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
 
     // the output is always the last tensor in the graph
-        res  = gf->nodes[gf->n_nodes - 1];
-        embd = gf->nodes[gf->n_nodes - 2];
+    struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
 
     if (lctx.n_outputs == 0) {
         // no output
@@ -15062,58 +14964,9 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
-        lctx.cached_graph.res = res;
-        lctx.cached_graph.embd = embd;
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
-    }
-    else {
-        gf = lctx.cached_graph.gf;
-        res = lctx.cached_graph.res;
-        embd = lctx.cached_graph.embd;
-    }
-    lctx.cached_graph.gf = gf;
-
-    // Update K and V cache parameters in cached graph.
-    if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
-
-        const struct llama_hparams & hparams = model.hparams;
-        const int64_t kv_head = kv_self.head;
-
-        for (int i = 0 ; i < gf->n_nodes; i++) {
-            ggml_tensor * node = gf->nodes[i];
-            if (node->op == GGML_OP_CPY) {
-
-                // K cache
-                const char* k_prefix = "k_cache_view-";
-                if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
-                    int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
-                    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-                    ggml_tensor * tmp_tensor = kv_self.k_l[il];
-                    size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-                    node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
-                }
-
-                // V cache
-                const char* v_prefix = "v_cache_view-";
-                if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
-                    int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
-                    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-                    ggml_tensor * tmp_tensor = kv_self.v_l[il];
-                    size_t tmp_offset;
-                    if (cparams.flash_attn) {
-                        tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                    } else {
-                        tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-                    }
-                    node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
-                }
-
-            }
-        }
-
-    }
 
     llama_set_inputs(lctx, u_batch);
 
@@ -15137,18 +14990,12 @@ static int llama_decode_internal(
     // extract logits
     if (res) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(lctx.logits != nullptr);
 
         float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
        const int32_t n_outputs_new = lctx.n_outputs;
 
-        if(!ggml_use_cached_graph(lctx.sched))
-            lctx.cached_graph.backend_res = backend_res;
-        else
-            backend_res = lctx.cached_graph.backend_res;
-
-        GGML_ASSERT(backend_res != nullptr);
-        GGML_ASSERT(lctx.logits != nullptr);
-
         if (n_outputs_new) {
             GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
             GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
@@ -15159,10 +15006,6 @@ static int llama_decode_internal(
     // extract embeddings
     if (embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
-        if(!ggml_use_cached_graph(lctx.sched))
-            lctx.cached_graph.backend_embd = backend_embd;
-        else
-            backend_embd = lctx.cached_graph.backend_embd;
         GGML_ASSERT(backend_embd != nullptr);
 
         switch (cparams.pooling_type) {
@@ -15903,9 +15746,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) {
             new_type = GGML_TYPE_IQ4_NL;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_TN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) {
-            new_type = GGML_TYPE_Q4_K;
-        }
         else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
                  new_type == GGML_TYPE_Q4_0_8_8) {
             new_type = GGML_TYPE_Q4_0;
@@ -16154,8 +15994,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
         new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
         new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
-        new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN ||
-        new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN || new_type == GGML_TYPE_IQ4_KS ||
+        new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K ||
+        new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS ||
         new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
@@ -16182,8 +16022,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ1_TN:
-        case GGML_TYPE_IQ2_TN:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_IQ2_K:
@@ -16297,8 +16135,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_IQ1_M:  default_type = GGML_TYPE_IQ1_M;  break;
         case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break;
         case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break;
-        case LLAMA_FTYPE_MOSTLY_IQ1_TN: default_type = GGML_TYPE_IQ1_TN; break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_TN: default_type = GGML_TYPE_IQ2_TN; break;
         case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
         case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
        case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
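For readers unfamiliar with the ternary formats touched above, the removed IQ2_TN reference code illustrates the underlying scheme: each byte packs four ternary digits (0, 1 or 2), and dequantization maps a digit q to d*q - d, i.e. {-d, 0, +d} with a per-row scale d. Below is a minimal standalone C++ sketch of that idea. It is an illustration only: the helper names are made up, the digits are packed contiguously rather than with the 32-element interleaving the real blocks use, and the actual iq1_bn/iq2_bn row layout (scale placement, block size) is not reproduced here.

    // Simplified ternary 2-bit pack/unpack, in the spirit of the removed
    // quantize_row_iq2_tn_ref / dequantize_row_iq2_tn reference code.
    #include <cstdint>

    // Map a value to {0,1,2} relative to the row maximum (the removed lambda).
    static inline uint8_t ternary_quant(float xmax, float x) {
        return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2;
    }

    // Pack n (a multiple of 4) ternary digits into n/4 bytes, 2 bits each.
    static void pack_ternary(const float * x, int n, float xmax, uint8_t * qs) {
        for (int i = 0; i < n/4; ++i) {
            qs[i] =  ternary_quant(xmax, x[4*i+0])
                  | (ternary_quant(xmax, x[4*i+1]) << 2)
                  | (ternary_quant(xmax, x[4*i+2]) << 4)
                  | (ternary_quant(xmax, x[4*i+3]) << 6);
        }
    }

    // Unpack and dequantize: y = d*q - d, so the three codes become -d, 0, +d.
    static void unpack_ternary(const uint8_t * qs, int n, float d, float * y) {
        for (int i = 0; i < n/4; ++i) {
            for (int k = 0; k < 4; ++k) {
                y[4*i+k] = d*((qs[i] >> (2*k)) & 3) - d;
            }
        }
    }

This is also why dropping IQ1_TN/IQ2_TN is harmless once IQ1_BN/IQ2_BN carry per-row scales: the only difference the TriLM variants added was where that scale was stored.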
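The repeated `if (fabsf(scale-1) > 1e-4f)` guards in the llama.cpp hunks all follow one pattern: the Bitnet weight tensors carry a per-tensor scale in `op_params`, and a `ggml_scale` node is only added to the graph when that scale is meaningfully different from 1. The helper below is a hypothetical sketch of that pattern (`maybe_scale` is not a function in the codebase), shown only to make the intent of the change explicit.

    // Hypothetical helper: add a ggml_scale node only when it does real work.
    #include <cmath>
    #include <cstring>
    #include "ggml.h"

    static struct ggml_tensor * maybe_scale(struct ggml_context * ctx,
                                            struct ggml_tensor * t,
                                            const struct ggml_tensor * w) {
        float s;
        std::memcpy(&s, w->op_params, sizeof(float)); // scale stored in the weight's op_params
        if (fabsf(s - 1.0f) > 1e-4f) {
            t = ggml_scale(ctx, t, s);                // skip the node when s is effectively 1
        }
        return t;
    }

For models whose tensors are stored unscaled (scale == 1), this keeps the extra scale nodes out of the compute graph entirely.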