author     Kawrakow <iwankawrakow@gmail.com>        2024-09-27 08:16:06 +0300
committer  GitHub <noreply@github.com>              2024-09-27 08:16:06 +0300
commit     6dec4af4b6e65eb72e646a6f8b10d77c9d306281 (patch)
tree       b69a6dfdd024ccf6a4d7490666664cbac4bc65ce
parent     546f3ef349a7082fbc349897c3c7246baed2a6c6 (diff)
Adding ability to have meta data per tensor row (#61)
* POC: per row scale
This is a POC of how to work around opinionated ggml to
have scales per row rather than per block (a layout sketch
follows the commit message).
Only implemented for Zen4 and only for iq2_tn.
* POC per row scale: iq2_tn on NEON
* POC per row scale: iq2_tn on Metal
* Per row scale Metal templates
* iq1_tn: shrink to 1.625 bpw (NEON and Metal)
* POC per row scale: CUDA
* POC per row scale: add CUDA TODOs
There are two places left in ggml-cuda.cu where it is assumed
that type_size * n_per_row / block_size is the way to compute
and handle row sizes. This does not affect simple usage,
but it will lead to issues when tensors are split between GPUs.
* Per row scales - CUDA
The only place left where unnecessary assumptions are being made
is the Flash Attention code. As we are not using any quants with
per-row scales for the quantized KV cache, it should be OK for now.
* Update IQ1_TN and IQ2_TN bpw shown to user
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
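
For orientation, a minimal sketch of what "meta data per tensor row" means for IQ2_TN, in C, using the types and values from this diff (ggml-common.h, the new row_meta_size trait, QK_K = 256); the helper at the end is hypothetical and only illustrates how the backends in this diff locate the data:

#include <stdint.h>
#include <stddef.h>

#define QK_K 256  // weights per super-block

// Before this commit: one fp16 scale per block -> (2 + 64)*8/256 = 2.0625 bpw.
typedef struct {
    uint16_t d;          // ggml_half block scale, removed by this commit
    uint8_t  qs[QK_K/4]; // 64 bytes of 2-bit quants
} block_iq2_tn_old;

// After: blocks carry quants only -> 64*8/256 = 2.0 bpw ...
typedef struct {
    uint8_t qs[QK_K/4];
} block_iq2_tn;

// ... and each row is prefixed with a single float scale (row_meta_size = 4):
//
//     row = [ float d | block_iq2_tn x (n_per_row / QK_K) ]
//
// which is why ggml_row_size() changes to
//     row_meta_size + ggml_type_size(type) * ne / ggml_blck_size(type).

// Hypothetical helper: return the quant blocks of row ix and store its scale in *d.
static inline const block_iq2_tn * iq2_tn_row(const void * vx, size_t row_size,
                                              int64_t ix, float * d) {
    const char * cx = (const char *) vx + ix*row_size;   // start of row ix
    *d = *(const float *) cx;                            // per-row scale
    return (const block_iq2_tn *) (cx + sizeof(float));  // blocks follow the scale
}

(IQ1_TN does the same with a ggml_half scale, hence row_meta_size = 2.)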
-rw-r--r--  examples/quantize/quantize.cpp  |  4
-rw-r--r--  ggml/include/ggml.h             |  2
-rw-r--r--  ggml/src/ggml-common.h          |  7
-rw-r--r--  ggml/src/ggml-cuda.cu           | 12
-rw-r--r--  ggml/src/ggml-cuda/convert.cu   | 19
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu  | 17
-rw-r--r--  ggml/src/ggml-metal.metal       | 66
-rw-r--r--  ggml/src/ggml-quants.c          |  2
-rw-r--r--  ggml/src/ggml.c                 | 84
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp    | 69
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp   | 23
-rw-r--r--  src/llama.cpp                   |  4
12 files changed, 171 insertions, 138 deletions
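
The bpw figures updated in quantize.cpp and llama.cpp below follow directly from the new block sizes. As a quick sanity check (block sizes from the ggml-common.h hunk; the per-row meta data amortizes to essentially zero over a long row, so it is ignored here):

// IQ1_TN: was 54*8/256 = 1.6875 bpw, now 52*8/256 = 1.625 bpw (+ one ggml_half per row)
// IQ2_TN: was 66*8/256 = 2.0625 bpw, now 64*8/256 = 2.000 bpw (+ one float per row)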
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c6153e45..c11b8631 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -28,8 +28,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, - { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.69 bpw quantization (TriLM)", }, - { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", }, + { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.63 bpw quantization (TriLM)", }, + { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.00 bpw quantization (TriLM)", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5b46a70d..6ac30b0f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -744,6 +744,7 @@ extern "C" { GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + // TODO: remove the following from the public API to avoid unnecessary assumptions about data layout GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row @@ -2517,6 +2518,7 @@ extern "C" { int64_t ncols; // number of columns to process simultaneously ggml_gemv_t gemv; ggml_gemm_t gemm; + int64_t row_meta_size; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 40a4b53c..bb0c4864 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -400,14 +400,13 @@ static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/paddi // TriLM - implemented as 2.0625 bpw // typedef struct { - uint8_t qs[54]; + uint8_t qs[52]; } block_iq1_tn; -static_assert(sizeof(block_iq1_tn) == 54, "wrong iq1_tn block size/padding"); +static_assert(sizeof(block_iq1_tn) == 52, "wrong iq1_tn block size/padding"); typedef struct { - ggml_half d; uint8_t qs[QK_K/4]; } block_iq2_tn; -static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding"); +static_assert(sizeof(block_iq2_tn) == QK_K/4, "wrong iqt_bn block size/padding"); // Used by IQ1_M quants typedef union { diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 87d7e17e..ca57efbd 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1180,19 +1180,22 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const int64_t nb3 = src->nb[3]; const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); + const int64_t rs = ggml_row_size(type, ne0); const int64_t bs = ggml_blck_size(type); int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*ne0/bs) { + if (nb0 == ts && nb1 == rs) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream); } else if (nb0 == ts) { + // 
TODO: this only works if the row does not contain meta data return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream); } else { for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + void * rd = (void *) (dst_ptr + i1*rs); // pretend the row is a matrix with cols=1 + // TODO: this only works if the row does not contain meta data cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream); if (r != cudaSuccess) { return r; @@ -1441,8 +1444,7 @@ static void ggml_cuda_op_mul_mat( const int64_t i02_divisor = ne12 / ne02; - const size_t src0_ts = ggml_type_size(src0->type); - const size_t src0_bs = ggml_blck_size(src0->type); + const size_t src0_rs = ggml_row_size(src0->type, ne00); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; @@ -1608,7 +1610,7 @@ static void ggml_cuda_op_mul_mat( } // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * ne01*src0_rs; float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 4b1be7c1..c74b030b 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -154,19 +154,23 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t } template<typename dst_t> -static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) { +static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy, + int64_t n_per_row, int64_t row_size) { - const int64_t i = blockIdx.x; - const block_iq2_tn * x = (const block_iq2_tn *) vx; + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float d = *(const float *)cx; + const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float)); + int64_t i = ii - (row*n_per_row)/QK_K; const int64_t tid = threadIdx.x; const int64_t n = tid/32; const int64_t l = tid - 32*n; const uint8_t q = x[i].qs[32*n + l]; - dst_t * y = yy + i*QK_K + 128*n; + dst_t * y = yy + ii*QK_K + 128*n; - float d = __half2float(x[i].d); y[l+ 0] = d * ((q >> 0) & 3) - d; y[l+32] = d * ((q >> 2) & 3) - d; y[l+64] = d * ((q >> 4) & 3) - d; @@ -743,8 +747,9 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n template<typename dst_t> static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int nb = k / QK_K; - dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row); + const int nb = (k + 255) / 256; + dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size); } template<typename dst_t> diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index a890f6b3..b2c32c0c 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -519,7 +519,8 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1( static 
__device__ __forceinline__ float vec_dot_iq2_tn_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx; + float scale = *(const float *)vbq; + const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx; const int bq8_offset = QR2_K * (iqs / QI8_1); @@ -533,19 +534,7 @@ static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1( sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0)); v >>= 2; } - return __half2float(bq2->d) * sumf; - - //float sumf_d = 0; - //float sumf_m = 0; - //for (int i = 0; i < QR2_K; ++ i) { - // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); - // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds); - // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0); - // sumf_m += d8.y; - // v >>= 2; - //} - //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m); - + return scale * sumf; } static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1( diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 259fa609..e2e45029 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3778,21 +3778,21 @@ void kernel_mul_mv_iq2_tn_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; + const int row_size = nb*sizeof(block_iq2_tn) + 4; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size; - device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const char * cx = (device const char *) src0 + first_row*row_size + offset0; + device const float * y = (device const float*) src1 + r1*ne10 + im*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_iq2_tn) * nb / 2; + float drow[N_DST]; const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 @@ -3802,6 +3802,8 @@ void kernel_mul_mv_iq2_tn_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + for (int row = 0; row < N_DST; row++) drow[row] = *((device const float *)(cx + row*row_size)); + for (int ib = ix; ib < nb; ib += 4) { float sumy = 0.f; @@ -3812,7 +3814,7 @@ void kernel_mul_mv_iq2_tn_f32_impl( yl[i+24] = y4[i+96]; sumy += yl[i+24]; } - device const half * dh = &x[ib].d; + device const block_iq2_tn * x = (device const block_iq2_tn *)(cx + 4); device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; for (int row = 0; row < N_DST; row++) { @@ -3829,14 +3831,12 @@ void kernel_mul_mv_iq2_tn_f32_impl( acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); } - float dall = dh[0]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy); + sumf[row] += (acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy; - qs += step; - dh += step; + qs += row_size/2; } y4 += 4 * QK_K; @@ -3845,7 +3845,7 @@ void kernel_mul_mv_iq2_tn_f32_impl( for (int row = 0; 
row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = drow[row]*all_sum; } } } @@ -5504,7 +5504,7 @@ void kernel_mul_mv_iq1_tn_f32_impl( // Why are we not passing in src0->nb[0]? // But because we are not, we need to use this hack - const uint row_size = sizeof(block_iq1_tn)*(ne00/QK_K); + const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K); const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; @@ -6850,16 +6850,14 @@ void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & re template <typename type4x4> void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) { - const half d = xb->d; device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1); il = (il/2)%4; half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const half dl = d * coef; for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - d; + reg[i/4][i%4] = coef * (q[i] & mask) - 1; } } @@ -7446,30 +7444,28 @@ struct DefaultDequantizer { short il; }; -template <typename T4x4> -struct DequantizerIQ1TN { +template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&)> +struct DequantizerRS{ using type4x4 = T4x4; - using Block = block_iq1_bn;; - DequantizerIQ1TN(device const char * cx, short il = 0) : il(il) { - d = *(device const half *)cx; - x = (device const Block *)(cx + sizeof(half)); + DequantizerRS(device const char * cx, short il = 0) : il(il) { + d = *(device const Scale *)cx; + x = (device const Block *)(cx + sizeof(Scale)); } inline void convert(thread T4x4& t) const { - dequantize_iq1_bn(x, il, t); + dequantize(x, il, t); t *= d; } inline void convert(int64_t ind, thread T4x4& t) { - dequantize_iq1_bn(x + ind/4, ind%4, t); + dequantize(x + ind/nl, ind%nl, t); t *= d; } inline void next() { - constexpr int short nl = 4; il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? 
x + (2+nl-1)/nl : x; } device const Block * x; short il; - half d; + Scale d; }; // each block_q contains 16*nl weights @@ -7821,7 +7817,6 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>; -template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_tn, QK_NL, dequantize_iq2_tn>; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>; @@ -7842,7 +7837,8 @@ template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>; template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>; -template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerIQ1TN<float4x4>>; +template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; +template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; // // matrix-matrix multiplication @@ -7862,7 +7858,6 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>; -template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>>; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>>; @@ -7883,7 +7878,8 @@ template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_bn, 4, dequantize_iq1_bn>>; template 
[[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_bn, 4, dequantize_iq2_bn>>; -template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerIQ1TN<half4x4>>; +template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; +template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; // // indirect matrix-matrix multiplication @@ -7900,7 +7896,6 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>; -template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_K, QK_NL, dequantize_q4_K>>; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_K, QK_NL, dequantize_q5_K>>; @@ -7921,7 +7916,8 @@ template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>; template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>; template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; -template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerIQ1TN<half4x4>>; +template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; +template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; // // matrix-vector multiplication diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index a9a25761..b0e70bcc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14798,7 +14798,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte return false; } - if (nbytes % ggml_type_size(type) != 0) { + if (type != GGML_TYPE_IQ2_TN && type != GGML_TYPE_IQ1_TN && nbytes % ggml_type_size(type) != 0) { fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type)); return false; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 08b292b7..2804accd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -651,24 +651,28 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = 1, .type_size = sizeof(int8_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_I16] = { .type_name = "i16", 
.blck_size = 1, .type_size = sizeof(int16_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_I32] = { .type_name = "i32", .blck_size = 1, .type_size = sizeof(int32_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_I64] = { .type_name = "i64", .blck_size = 1, .type_size = sizeof(int64_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_F64] = { .type_name = "f64", @@ -676,6 +680,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(double), .is_quantized = false, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_F32] = { .type_name = "f32", @@ -685,6 +690,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, .vec_dot_type = GGML_TYPE_F32, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_F16] = { .type_name = "f16", @@ -697,6 +703,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", @@ -717,6 +724,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .row_meta_size = 0, }, [GGML_TYPE_Q4_1] = { .type_name = "q4_1", @@ -733,6 +741,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .row_meta_size = 0, }, [4] = { // GGML_TYPE_Q4_2 .type_name = "DEPRECATED", @@ -745,6 +754,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, + .row_meta_size = 0, }, [5] = { // GGML_TYPE_Q4_3 .type_name = "DEPRECATED", @@ -757,6 +767,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q5_0] = { .type_name = "q5_0", @@ -773,6 +784,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, #endif .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q5_1] = { .type_name = "q5_1", @@ -785,6 +797,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q8_0] = { .type_name = "q8_0", @@ -806,6 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .row_meta_size = 0, }, [GGML_TYPE_Q8_1] = { .type_name = "q8_1", @@ -816,6 +830,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -828,6 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q3_K] = { .type_name = "q3_K", @@ -840,6 +856,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q4_K] = { .type_name = "q4_K", @@ -852,6 +869,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q5_K] = { .type_name = "q5_K", @@ -864,6 +882,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot 
= ggml_vec_dot_q5_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q6_K] = { .type_name = "q6_K", @@ -876,6 +895,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_XXS] = { .type_name = "iq2_xxs", @@ -888,6 +908,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_XS] = { .type_name = "iq2_xs", @@ -900,6 +921,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ3_XXS] = { .type_name = "iq3_xxs", @@ -912,6 +934,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ3_S] = { .type_name = "iq3_s", @@ -924,6 +947,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq3_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_S] = { .type_name = "iq2_s", @@ -936,6 +960,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", @@ -948,6 +973,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ1_M] = { .type_name = "iq1_m", @@ -960,6 +986,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ1_BN] = { .type_name = "iq1_bn", @@ -972,6 +999,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq1_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_BN] = { .type_name = "iq2_bn", @@ -984,6 +1012,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_TN] = { .type_name = "iq2_tn", @@ -996,6 +1025,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq2_tn_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 4, }, [GGML_TYPE_IQ1_TN] = { .type_name = "iq1_tn", @@ -1008,6 +1038,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq1_tn_q8_k, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, + .row_meta_size = 2, }, [GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", @@ -1020,6 +1051,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ4_XS] = { .type_name = "iq4_xs", @@ -1032,6 +1064,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq4_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", @@ -1039,6 +1072,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = 
sizeof(block_q8_K), .is_quantized = true, .from_float = quantize_row_q8_K, + .row_meta_size = 0, }, [GGML_TYPE_Q8_K64] = { .type_name = "q8_K64", @@ -1046,6 +1080,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_K64), .is_quantized = true, .from_float = quantize_row_q8_K64, + .row_meta_size = 0, }, [GGML_TYPE_BF16] = { .type_name = "bf16", @@ -1058,6 +1093,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", @@ -1074,6 +1110,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 4, .gemv = ggml_gemv_q4_0_4x4_q8_0, .gemm = ggml_gemm_q4_0_4x4_q8_0, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0_4_8] = { .type_name = "q4_0_4x8", @@ -1090,6 +1127,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 4, .gemv = ggml_gemv_q4_0_4x8_q8_0, .gemm = ggml_gemm_q4_0_4x8_q8_0, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0_8_8] = { .type_name = "q4_0_8x8", @@ -1106,6 +1144,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_K] = { .type_name = "iq2_k", @@ -1118,6 +1157,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq2_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", @@ -1130,6 +1170,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq3_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ4_K] = { .type_name = "iq4_k", @@ -1142,6 +1183,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq4_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ5_K] = { .type_name = "iq5_k", @@ -1154,6 +1196,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq5_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ6_K] = { .type_name = "iq6_k", @@ -1166,6 +1209,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq6_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, }; @@ -3585,6 +3629,10 @@ GGML_CALL int64_t ggml_nrows(const struct ggml_tensor * tensor) { return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } +GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { + return type_traits[type].blck_size; +} + GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t nbytes; size_t blck_size = ggml_blck_size(tensor->type); @@ -3595,7 +3643,7 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { } } else { - nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + nbytes = tensor->nb[1]; //tensor->ne[0]*tensor->nb[0]/blck_size; for (int i = 1; i < GGML_MAX_DIMS; ++i) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } @@ -3608,17 +3656,13 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } -GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { - return type_traits[type].blck_size; -} - GGML_CALL size_t ggml_type_size(enum ggml_type type) { return type_traits[type].type_size; } GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) { 
assert(ne % ggml_blck_size(type) == 0); - return ggml_type_size(type)*ne/ggml_blck_size(type); + return type_traits[type].row_meta_size + ggml_type_size(type)*ne/ggml_blck_size(type); } double ggml_type_sizef(enum ggml_type type) { @@ -3764,7 +3808,7 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { return false; } - next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); + next_nb = ggml_row_size(tensor->type, tensor->ne[0]); //next_nb*tensor->ne[0]/ggml_blck_size(tensor->type) + type_traits[tensor->type].row_meta_size; for (int i = 1; i < GGML_MAX_DIMS; i++) { if (tensor->ne[i] != 1) { if (i > n) { @@ -4227,7 +4271,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( } result->nb[0] = ggml_type_size(type); - result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type)); + result->nb[1] = ggml_row_size(type, ne[0]); for (int i = 2; i < GGML_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -13023,8 +13067,8 @@ static void ggml_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { if (counter++ % nth == ith) { if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), 0, 1)) goto IQK_MulMat_Not_Available1; } @@ -13036,8 +13080,8 @@ static void ggml_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), ith, nth)) goto IQK_MulMat_Not_Available1; return; @@ -13123,8 +13167,8 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), ith, nth)) goto IQK_MulMat_Not_Available2; return; @@ -13353,8 +13397,8 @@ static void ggml_compute_forward_mul_mat_id( #if GGML_USE_IQK_MULMAT if (ne13 == 1 && dst->type == GGML_TYPE_F32) { if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type), + src0->type, (const char *)src0_cur, nb01, 
///ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata, row_size, ///ggml_type_size(vec_dot_type), (float *)dst->data, nb1, nb2, matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available; continue; @@ -13870,7 +13914,7 @@ static void ggml_compute_forward_softcap( default: { GGML_ASSERT(false); - } break; + } } } @@ -13986,7 +14030,7 @@ static void ggml_compute_forward_softcap_max( default: { GGML_ASSERT(false); - } break; + } } } @@ -18652,11 +18696,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_SOFTCAP: { GGML_ASSERT(false); // TODO: not implemented - } break; + } case GGML_OP_SOFT_CAP_MAX: { GGML_ASSERT(false); // TODO: not implemented - } break; + } case GGML_OP_SET: { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7543d895..33b0a0d5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -140,8 +140,9 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, return false; } - auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); - auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); + size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA)); + size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); + //if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA)); auto nrc_x = (Nx + nth - 1)/nth; auto first_x = ith*nrc_x; @@ -165,8 +166,8 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { return false; } - auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); - auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); + size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA)); + size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); int nrc_x = (Nx + nth - 1)/nth; int first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; @@ -378,11 +379,17 @@ struct ScaleIQ4XS { const __m128i m32 = _mm_set1_epi16(-32); }; -template <typename Block> +template <typename Block, bool per_row_scale = false> struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {} inline void new_row(int ix) { - x = (const Block *)((const char *)vx + bx*ix); + if constexpr (per_row_scale) { + const float * dptr = (const float *)((const char *)vx + bx*ix); + d = *dptr; + x = (const Block *)(dptr + 1); + } else { + x = (const Block *)((const char *)vx + bx*ix); + } } const void * vx; @@ -700,14 +707,13 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { +struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template <typename Q8> inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) { new_block(i); } inline void new_block(int i) { - d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); } Q2Bits bits; @@ -1158,7 +1164,7 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D deq1.new_block(i); deq2.new_block(i); - float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use the same scale for both rows + //float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use 
the same scale for both rows for (int iy = 0; iy < nrc_y; ++iy) { auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); @@ -1176,7 +1182,7 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q); sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q); // The scale is supposed to be per per tensor, so we can use the same scale - auto vd = _mm512_set1_ps(d*q8.scale(iy, i)); + auto vd = _mm512_set1_ps(/*d* */q8.scale(iy, i)); accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]); accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]); // Leaving this here just in case ternary models start using per row scales @@ -1187,8 +1193,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0])); - info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1])); + info.store(ix+0, iy, deq1.d*_mm512_reduce_add_ps(accd[2*iy+0])); + info.store(ix+1, iy, deq2.d*_mm512_reduce_add_ps(accd[2*iy+1])); } } @@ -4104,14 +4110,23 @@ struct Q2bits { } }; -template <typename block_q> +template <typename block_q, bool has_row_scale = false> struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} - inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); } + inline void new_row(int ix) { + if constexpr (has_row_scale) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + d = *dptr; + x = (const block_q *)(dptr + 1); + } else { + x = (const block_q *)((const char *)vx + ix*bx); + } + } const void * vx; const block_q * x; const size_t bx; const int nrc; + float d; }; struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> { @@ -4133,7 +4148,6 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> { Q4bits bits; Scales8 s8; - float d; }; struct HighBit5 { @@ -4202,7 +4216,6 @@ struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> { uint8x16x2_t hbits; - float d; }; inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { @@ -4256,7 +4269,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> { const uint8x16_t mhb = vdupq_n_u8(0x30); - float d; }; struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> { @@ -4317,7 +4329,6 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> { uint8x16_t mask; HighBit3 h; - float d; }; struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { @@ -4389,7 +4400,6 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { Q2bits bits; - float d; }; // ============================= i-quants @@ -4453,7 +4463,6 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { const int8x16_t values; const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); - float d; }; struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { @@ -4503,7 +4512,6 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { const uint8x16_t hm = vdupq_n_u8(0x10); uint8x16x2_t hbits; - float d; }; struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { @@ -4538,7 +4546,6 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { const int8x16x4_t values; const uint8x16_t hm = vdupq_n_u8(0x30); - float d; }; struct DequantizerIQ2K 
final : public BaseDequantizer<block_iq2_k> { @@ -4570,7 +4577,6 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1)); const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); - float d; }; struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { @@ -4630,7 +4636,6 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010}); const uint8x16_t sign_shuffle = load_sign_shuffle(); - float d; }; struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { @@ -4688,7 +4693,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602}; - float d; }; struct SimpleBits { @@ -4747,7 +4751,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { uint32x4x4_t data; SimpleBits bits; - float d; }; inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { @@ -4793,7 +4796,6 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> { SimpleBits bits; - float d; }; @@ -4857,7 +4859,6 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { SimpleBits bits; SignHelper sh; - float d; }; @@ -4891,8 +4892,6 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> { SimpleBits bits; uint32x4x2_t gas; - float d; - }; struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { @@ -4951,11 +4950,9 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { SignHelper sh; uint32x4x2_t gas; - float d; - }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { +struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } @@ -4966,9 +4963,7 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { // d = GGML_FP16_TO_FP32(x[i].d); //} - inline void new_block(int i) { - d = GGML_FP16_TO_FP32(x[i].d); - } + inline void new_block(int) { } template <typename Q8> inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { @@ -5019,8 +5014,6 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { } Q2bits bits; - - float d; }; template <int nrc_y> diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 42441584..28bad18e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1972,15 +1972,15 @@ void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) { auto quantize = [] (float xmax, float x) { return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 
1 : 2; }; + int n = k; + float max = x[0]; + for (int j = 1; j < n; ++j) max = std::max(max, fabsf(x[j])); + + *(float *)y = max; + y = (block_iq2_tn *)((float *)y + 1); for (int ibl = 0; ibl < nb; ++ibl) { auto xb = x + QK_K*ibl; - float max = xb[0]; - for (int j = 0; j < QK_K; ++j) { - float ax = fabsf(xb[j]); - max = std::max(ax, max); - } - y[ibl].d = GGML_FP32_TO_FP16(max); auto qs = y[ibl].qs; for (int l = 0; l < QK_K/128; ++l) { for (int j = 0; j < 32; ++j) { @@ -1992,7 +1992,7 @@ void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) { } } -void quantize_row_iq2_tn(const float * x, void * y, int64_t k) { +void quantize_row_iq2_tn(const float * x, void * y, int64_t k) { quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k); } @@ -2009,9 +2009,11 @@ size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) { GGML_ASSERT(k%QK_K == 0); + const float * dptr = (const float *)x; + float d = *dptr; + x = (const block_iq2_tn *)(dptr + 1); int nb = k/QK_K; for (int ibl = 0; ibl < nb; ++ibl) { - float d = GGML_FP16_TO_FP32(x[ibl].d); auto qs = x[ibl].qs; for (int l = 0; l < QK_K/128; ++l) { for (int j = 0; j < 32; ++j) { @@ -2039,13 +2041,14 @@ void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t const int nb = n / QK_K; - const block_iq2_tn * x = (const block_iq2_tn *)vx; + const float * dptr = (const float *)vx; + const float d = *dptr; + const block_iq2_tn * x = (const block_iq2_tn *)(dptr + 1); const block_q8_K * y = (const block_q8_K *)vy; float sumf = 0; for (int i = 0; i < nb; i++) { - float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; auto qs = x[i].qs; auto q8 = y[i].qs; int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0; diff --git a/src/llama.cpp b/src/llama.cpp index 0eea948a..df57c071 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4498,9 +4498,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet"; - case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.6875 bpw TriLM"; + case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.625 bpw TriLM"; case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet"; - case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.06 bpw TriLM"; + case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.00 bpw TriLM"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; |
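
To close the loop, a minimal CPU-side sketch of consuming the new layout, modeled on dequantize_row_iq2_tn and the CUDA kernel in this diff; the packing follows quantize_row_iq2_tn_ref, where within each 128-weight group byte j carries the 2-bit values for positions j, j+32, j+64 and j+96. This illustrates the layout and is not the repository's exact code:

#include <stdint.h>

#define QK_K 256
typedef struct { uint8_t qs[QK_K/4]; } block_iq2_tn;  // as in ggml-common.h after this commit

// Dequantize one IQ2_TN row of k weights (k % QK_K == 0) under the new layout.
static void dequantize_iq2_tn_row_sketch(const void * vx, float * y, int64_t k) {
    const float d = *(const float *) vx;                  // per-row scale prefix
    const block_iq2_tn * x = (const block_iq2_tn *) ((const float *) vx + 1);
    for (int64_t ibl = 0; ibl < k/QK_K; ++ibl) {
        const uint8_t * qs = x[ibl].qs;
        for (int l = 0; l < QK_K/128; ++l) {              // two 128-weight groups per block
            for (int j = 0; j < 32; ++j) {
                y[j +  0] = d * ((qs[j] >> 0) & 3) - d;   // each value lands in {-d, 0, +d}
                y[j + 32] = d * ((qs[j] >> 2) & 3) - d;
                y[j + 64] = d * ((qs[j] >> 4) & 3) - d;
                y[j + 96] = d * ((qs[j] >> 6) & 3) - d;
            }
            y  += 128;
            qs += 32;
        }
    }
}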