diff options
-rw-r--r-- | examples/quantize/quantize.cpp | 4 | ||||
-rw-r--r-- | ggml/include/ggml.h | 2 | ||||
-rw-r--r-- | ggml/src/ggml-common.h | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda.cu | 12 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 19 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 17 | ||||
-rw-r--r-- | ggml/src/ggml-metal.metal | 66 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 2 | ||||
-rw-r--r-- | ggml/src/ggml.c | 84 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 69 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 23 | ||||
-rw-r--r-- | src/llama.cpp | 4 |
12 files changed, 171 insertions, 138 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c6153e45..c11b8631 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -28,8 +28,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, - { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.69 bpw quantization (TriLM)", }, - { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", }, + { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.63 bpw quantization (TriLM)", }, + { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.00 bpw quantization (TriLM)", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5b46a70d..6ac30b0f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -744,6 +744,7 @@ extern "C" { GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + // TODO: remove the following from the public API to avoid unnecessary assumptions about data layout GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row @@ -2517,6 +2518,7 @@ extern "C" { int64_t ncols; // number of columns to process simultaneously ggml_gemv_t gemv; ggml_gemm_t gemm; + int64_t row_meta_size; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 40a4b53c..bb0c4864 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -400,14 +400,13 @@ static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/paddi // TriLM - implemented as 2.0625 bpw // typedef struct { - uint8_t qs[54]; + uint8_t qs[52]; } block_iq1_tn; -static_assert(sizeof(block_iq1_tn) == 54, "wrong iq1_tn block size/padding"); +static_assert(sizeof(block_iq1_tn) == 52, "wrong iq1_tn block size/padding"); typedef struct { - ggml_half d; uint8_t qs[QK_K/4]; } block_iq2_tn; -static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding"); +static_assert(sizeof(block_iq2_tn) == QK_K/4, "wrong iqt_bn block size/padding"); // Used by IQ1_M quants typedef union { diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 87d7e17e..ca57efbd 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1180,19 +1180,22 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const int64_t nb3 = src->nb[3]; const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); + const int64_t rs = ggml_row_size(type, ne0); const int64_t bs = ggml_blck_size(type); int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*ne0/bs) { + if (nb0 == ts && nb1 == rs) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream); } else if (nb0 == ts) { + // TODO: this only works if the row does not contain meta data return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream); } else { for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + void * rd = (void *) (dst_ptr + i1*rs); // pretend the row is a matrix with cols=1 + // TODO: this only works if the row does not contain meta data cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream); if (r != cudaSuccess) { return r; @@ -1441,8 +1444,7 @@ static void ggml_cuda_op_mul_mat( const int64_t i02_divisor = ne12 / ne02; - const size_t src0_ts = ggml_type_size(src0->type); - const size_t src0_bs = ggml_blck_size(src0->type); + const size_t src0_rs = ggml_row_size(src0->type, ne00); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; @@ -1608,7 +1610,7 @@ static void ggml_cuda_op_mul_mat( } // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * ne01*src0_rs; float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 4b1be7c1..c74b030b 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -154,19 +154,23 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t } template<typename dst_t> -static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) { +static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy, + int64_t n_per_row, int64_t row_size) { - const int64_t i = blockIdx.x; - const block_iq2_tn * x = (const block_iq2_tn *) vx; + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float d = *(const float *)cx; + const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float)); + int64_t i = ii - (row*n_per_row)/QK_K; const int64_t tid = threadIdx.x; const int64_t n = tid/32; const int64_t l = tid - 32*n; const uint8_t q = x[i].qs[32*n + l]; - dst_t * y = yy + i*QK_K + 128*n; + dst_t * y = yy + ii*QK_K + 128*n; - float d = __half2float(x[i].d); y[l+ 0] = d * ((q >> 0) & 3) - d; y[l+32] = d * ((q >> 2) & 3) - d; y[l+64] = d * ((q >> 4) & 3) - d; @@ -743,8 +747,9 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n template<typename dst_t> static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int nb = k / QK_K; - dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row); + const int nb = (k + 255) / 256; + dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size); } template<typename dst_t> diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index a890f6b3..b2c32c0c 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -519,7 +519,8 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1( static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx; + float scale = *(const float *)vbq; + const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx; const int bq8_offset = QR2_K * (iqs / QI8_1); @@ -533,19 +534,7 @@ static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1( sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0)); v >>= 2; } - return __half2float(bq2->d) * sumf; - - //float sumf_d = 0; - //float sumf_m = 0; - //for (int i = 0; i < QR2_K; ++ i) { - // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); - // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds); - // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0); - // sumf_m += d8.y; - // v >>= 2; - //} - //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m); - + return scale * sumf; } static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1( diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 259fa609..e2e45029 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3778,21 +3778,21 @@ void kernel_mul_mv_iq2_tn_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; + const int row_size = nb*sizeof(block_iq2_tn) + 4; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size; - device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const char * cx = (device const char *) src0 + first_row*row_size + offset0; + device const float * y = (device const float*) src1 + r1*ne10 + im*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_iq2_tn) * nb / 2; + float drow[N_DST]; const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 @@ -3802,6 +3802,8 @@ void kernel_mul_mv_iq2_tn_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + for (int row = 0; row < N_DST; row++) drow[row] = *((device const float *)(cx + row*row_size)); + for (int ib = ix; ib < nb; ib += 4) { float sumy = 0.f; @@ -3812,7 +3814,7 @@ void kernel_mul_mv_iq2_tn_f32_impl( yl[i+24] = y4[i+96]; sumy += yl[i+24]; } - device const half * dh = &x[ib].d; + device const block_iq2_tn * x = (device const block_iq2_tn *)(cx + 4); device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; for (int row = 0; row < N_DST; row++) { @@ -3829,14 +3831,12 @@ void kernel_mul_mv_iq2_tn_f32_impl( acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); } - float dall = dh[0]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy); + sumf[row] += (acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy; - qs += step; - dh += step; + qs += row_size/2; } y4 += 4 * QK_K; @@ -3845,7 +3845,7 @@ void kernel_mul_mv_iq2_tn_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = drow[row]*all_sum; } } } @@ -5504,7 +5504,7 @@ void kernel_mul_mv_iq1_tn_f32_impl( // Why are we not passing in src0->nb[0]? // But because we are not, we need to use this hack - const uint row_size = sizeof(block_iq1_tn)*(ne00/QK_K); + const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K); const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; @@ -6850,16 +6850,14 @@ void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & re template <typename type4x4> void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) { - const half d = xb->d; device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1); il = (il/2)%4; half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const half dl = d * coef; for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - d; + reg[i/4][i%4] = coef * (q[i] & mask) - 1; } } @@ -7446,30 +7444,28 @@ struct DefaultDequantizer { short il; }; -template <typename T4x4> -struct DequantizerIQ1TN { +template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&)> +struct DequantizerRS{ using type4x4 = T4x4; - using Block = block_iq1_bn;; - DequantizerIQ1TN(device const char * cx, short il = 0) : il(il) { - d = *(device const half *)cx; - x = (device const Block *)(cx + sizeof(half)); + DequantizerRS(device const char * cx, short il = 0) : il(il) { + d = *(device const Scale *)cx; + x = (device const Block *)(cx + sizeof(Scale)); } inline void convert(thread T4x4& t) const { - dequantize_iq1_bn(x, il, t); + dequantize(x, il, t); t *= d; } inline void convert(int64_t ind, thread T4x4& t) { - dequantize_iq1_bn(x + ind/4, ind%4, t); + dequantize(x + ind/nl, ind%nl, t); t *= d; } inline void next() { - constexpr int short nl = 4; il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2+nl-1)/nl : x; } device const Block * x; short il; - half d; + Scale d; }; // each block_q contains 16*nl weights @@ -7821,7 +7817,6 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>; -template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_tn, QK_NL, dequantize_iq2_tn>; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>; @@ -7842,7 +7837,8 @@ template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>; template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>; -template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerIQ1TN<float4x4>>; +template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; +template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; // // matrix-matrix multiplication @@ -7862,7 +7858,6 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>; -template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>>; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>>; @@ -7883,7 +7878,8 @@ template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_bn, 4, dequantize_iq1_bn>>; template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_bn, 4, dequantize_iq2_bn>>; -template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerIQ1TN<half4x4>>; +template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; +template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; // // indirect matrix-matrix multiplication @@ -7900,7 +7896,6 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>; -template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_K, QK_NL, dequantize_q4_K>>; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_K, QK_NL, dequantize_q5_K>>; @@ -7921,7 +7916,8 @@ template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>; template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>; template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; -template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerIQ1TN<half4x4>>; +template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>; +template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>; // // matrix-vector multiplication diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index a9a25761..b0e70bcc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14798,7 +14798,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte return false; } - if (nbytes % ggml_type_size(type) != 0) { + if (type != GGML_TYPE_IQ2_TN && type != GGML_TYPE_IQ1_TN && nbytes % ggml_type_size(type) != 0) { fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type)); return false; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 08b292b7..2804accd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -651,24 +651,28 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = 1, .type_size = sizeof(int8_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_I16] = { .type_name = "i16", .blck_size = 1, .type_size = sizeof(int16_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_I32] = { .type_name = "i32", .blck_size = 1, .type_size = sizeof(int32_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_I64] = { .type_name = "i64", .blck_size = 1, .type_size = sizeof(int64_t), .is_quantized = false, + .row_meta_size = 0, }, [GGML_TYPE_F64] = { .type_name = "f64", @@ -676,6 +680,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(double), .is_quantized = false, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_F32] = { .type_name = "f32", @@ -685,6 +690,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, .vec_dot_type = GGML_TYPE_F32, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_F16] = { .type_name = "f16", @@ -697,6 +703,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", @@ -717,6 +724,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .row_meta_size = 0, }, [GGML_TYPE_Q4_1] = { .type_name = "q4_1", @@ -733,6 +741,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .row_meta_size = 0, }, [4] = { // GGML_TYPE_Q4_2 .type_name = "DEPRECATED", @@ -745,6 +754,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, + .row_meta_size = 0, }, [5] = { // GGML_TYPE_Q4_3 .type_name = "DEPRECATED", @@ -757,6 +767,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q5_0] = { .type_name = "q5_0", @@ -773,6 +784,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, #endif .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q5_1] = { .type_name = "q5_1", @@ -785,6 +797,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q8_0] = { .type_name = "q8_0", @@ -806,6 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .nrows = 1, #endif + .row_meta_size = 0, }, [GGML_TYPE_Q8_1] = { .type_name = "q8_1", @@ -816,6 +830,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -828,6 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q3_K] = { .type_name = "q3_K", @@ -840,6 +856,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q4_K] = { .type_name = "q4_K", @@ -852,6 +869,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q5_K] = { .type_name = "q5_K", @@ -864,6 +882,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q6_K] = { .type_name = "q6_K", @@ -876,6 +895,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_XXS] = { .type_name = "iq2_xxs", @@ -888,6 +908,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_XS] = { .type_name = "iq2_xs", @@ -900,6 +921,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ3_XXS] = { .type_name = "iq3_xxs", @@ -912,6 +934,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ3_S] = { .type_name = "iq3_s", @@ -924,6 +947,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq3_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_S] = { .type_name = "iq2_s", @@ -936,6 +960,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", @@ -948,6 +973,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ1_M] = { .type_name = "iq1_m", @@ -960,6 +986,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ1_BN] = { .type_name = "iq1_bn", @@ -972,6 +999,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq1_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_BN] = { .type_name = "iq2_bn", @@ -984,6 +1012,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_TN] = { .type_name = "iq2_tn", @@ -996,6 +1025,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq2_tn_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 4, }, [GGML_TYPE_IQ1_TN] = { .type_name = "iq1_tn", @@ -1008,6 +1038,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq1_tn_q8_k, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, + .row_meta_size = 2, }, [GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", @@ -1020,6 +1051,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ4_XS] = { .type_name = "iq4_xs", @@ -1032,6 +1064,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq4_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", @@ -1039,6 +1072,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_K), .is_quantized = true, .from_float = quantize_row_q8_K, + .row_meta_size = 0, }, [GGML_TYPE_Q8_K64] = { .type_name = "q8_K64", @@ -1046,6 +1080,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_K64), .is_quantized = true, .from_float = quantize_row_q8_K64, + .row_meta_size = 0, }, [GGML_TYPE_BF16] = { .type_name = "bf16", @@ -1058,6 +1093,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", @@ -1074,6 +1110,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 4, .gemv = ggml_gemv_q4_0_4x4_q8_0, .gemm = ggml_gemm_q4_0_4x4_q8_0, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0_4_8] = { .type_name = "q4_0_4x8", @@ -1090,6 +1127,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 4, .gemv = ggml_gemv_q4_0_4x8_q8_0, .gemm = ggml_gemm_q4_0_4x8_q8_0, + .row_meta_size = 0, }, [GGML_TYPE_Q4_0_8_8] = { .type_name = "q4_0_8x8", @@ -1106,6 +1144,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, + .row_meta_size = 0, }, [GGML_TYPE_IQ2_K] = { .type_name = "iq2_k", @@ -1118,6 +1157,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq2_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", @@ -1130,6 +1170,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq3_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ4_K] = { .type_name = "iq4_k", @@ -1142,6 +1183,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq4_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ5_K] = { .type_name = "iq5_k", @@ -1154,6 +1196,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq5_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, [GGML_TYPE_IQ6_K] = { .type_name = "iq6_k", @@ -1166,6 +1209,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq6_k_q8_k, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, + .row_meta_size = 0, }, }; @@ -3585,6 +3629,10 @@ GGML_CALL int64_t ggml_nrows(const struct ggml_tensor * tensor) { return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } +GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { + return type_traits[type].blck_size; +} + GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t nbytes; size_t blck_size = ggml_blck_size(tensor->type); @@ -3595,7 +3643,7 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { } } else { - nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + nbytes = tensor->nb[1]; //tensor->ne[0]*tensor->nb[0]/blck_size; for (int i = 1; i < GGML_MAX_DIMS; ++i) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } @@ -3608,17 +3656,13 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } -GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { - return type_traits[type].blck_size; -} - GGML_CALL size_t ggml_type_size(enum ggml_type type) { return type_traits[type].type_size; } GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) { assert(ne % ggml_blck_size(type) == 0); - return ggml_type_size(type)*ne/ggml_blck_size(type); + return type_traits[type].row_meta_size + ggml_type_size(type)*ne/ggml_blck_size(type); } double ggml_type_sizef(enum ggml_type type) { @@ -3764,7 +3808,7 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { return false; } - next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); + next_nb = ggml_row_size(tensor->type, tensor->ne[0]); //next_nb*tensor->ne[0]/ggml_blck_size(tensor->type) + type_traits[tensor->type].row_meta_size; for (int i = 1; i < GGML_MAX_DIMS; i++) { if (tensor->ne[i] != 1) { if (i > n) { @@ -4227,7 +4271,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( } result->nb[0] = ggml_type_size(type); - result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type)); + result->nb[1] = ggml_row_size(type, ne[0]); for (int i = 2; i < GGML_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -13023,8 +13067,8 @@ static void ggml_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { if (counter++ % nth == ith) { if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), 0, 1)) goto IQK_MulMat_Not_Available1; } @@ -13036,8 +13080,8 @@ static void ggml_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), ith, nth)) goto IQK_MulMat_Not_Available1; return; @@ -13123,8 +13167,8 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), ith, nth)) goto IQK_MulMat_Not_Available2; return; @@ -13353,8 +13397,8 @@ static void ggml_compute_forward_mul_mat_id( #if GGML_USE_IQK_MULMAT if (ne13 == 1 && dst->type == GGML_TYPE_F32) { if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type), + src0->type, (const char *)src0_cur, nb01, ///ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata, row_size, ///ggml_type_size(vec_dot_type), (float *)dst->data, nb1, nb2, matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available; continue; @@ -13870,7 +13914,7 @@ static void ggml_compute_forward_softcap( default: { GGML_ASSERT(false); - } break; + } } } @@ -13986,7 +14030,7 @@ static void ggml_compute_forward_softcap_max( default: { GGML_ASSERT(false); - } break; + } } } @@ -18652,11 +18696,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_SOFTCAP: { GGML_ASSERT(false); // TODO: not implemented - } break; + } case GGML_OP_SOFT_CAP_MAX: { GGML_ASSERT(false); // TODO: not implemented - } break; + } case GGML_OP_SET: { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7543d895..33b0a0d5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -140,8 +140,9 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, return false; } - auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); - auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); + size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA)); + size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); + //if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA)); auto nrc_x = (Nx + nth - 1)/nth; auto first_x = ith*nrc_x; @@ -165,8 +166,8 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { return false; } - auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); - auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); + size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA)); + size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); int nrc_x = (Nx + nth - 1)/nth; int first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; @@ -378,11 +379,17 @@ struct ScaleIQ4XS { const __m128i m32 = _mm_set1_epi16(-32); }; -template <typename Block> +template <typename Block, bool per_row_scale = false> struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {} inline void new_row(int ix) { - x = (const Block *)((const char *)vx + bx*ix); + if constexpr (per_row_scale) { + const float * dptr = (const float *)((const char *)vx + bx*ix); + d = *dptr; + x = (const Block *)(dptr + 1); + } else { + x = (const Block *)((const char *)vx + bx*ix); + } } const void * vx; @@ -700,14 +707,13 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { +struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template <typename Q8> inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) { new_block(i); } inline void new_block(int i) { - d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); } Q2Bits bits; @@ -1158,7 +1164,7 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D deq1.new_block(i); deq2.new_block(i); - float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use the same scale for both rows + //float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use the same scale for both rows for (int iy = 0; iy < nrc_y; ++iy) { auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); @@ -1176,7 +1182,7 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q); sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q); // The scale is supposed to be per per tensor, so we can use the same scale - auto vd = _mm512_set1_ps(d*q8.scale(iy, i)); + auto vd = _mm512_set1_ps(/*d* */q8.scale(iy, i)); accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]); accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]); // Leaving this here just in case ternary models start using per row scales @@ -1187,8 +1193,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0])); - info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1])); + info.store(ix+0, iy, deq1.d*_mm512_reduce_add_ps(accd[2*iy+0])); + info.store(ix+1, iy, deq2.d*_mm512_reduce_add_ps(accd[2*iy+1])); } } @@ -4104,14 +4110,23 @@ struct Q2bits { } }; -template <typename block_q> +template <typename block_q, bool has_row_scale = false> struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} - inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); } + inline void new_row(int ix) { + if constexpr (has_row_scale) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + d = *dptr; + x = (const block_q *)(dptr + 1); + } else { + x = (const block_q *)((const char *)vx + ix*bx); + } + } const void * vx; const block_q * x; const size_t bx; const int nrc; + float d; }; struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> { @@ -4133,7 +4148,6 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> { Q4bits bits; Scales8 s8; - float d; }; struct HighBit5 { @@ -4202,7 +4216,6 @@ struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> { uint8x16x2_t hbits; - float d; }; inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { @@ -4256,7 +4269,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> { const uint8x16_t mhb = vdupq_n_u8(0x30); - float d; }; struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> { @@ -4317,7 +4329,6 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> { uint8x16_t mask; HighBit3 h; - float d; }; struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { @@ -4389,7 +4400,6 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> { Q2bits bits; - float d; }; // ============================= i-quants @@ -4453,7 +4463,6 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { const int8x16_t values; const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); - float d; }; struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { @@ -4503,7 +4512,6 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { const uint8x16_t hm = vdupq_n_u8(0x10); uint8x16x2_t hbits; - float d; }; struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { @@ -4538,7 +4546,6 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { const int8x16x4_t values; const uint8x16_t hm = vdupq_n_u8(0x30); - float d; }; struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { @@ -4570,7 +4577,6 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1)); const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); - float d; }; struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { @@ -4630,7 +4636,6 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010}); const uint8x16_t sign_shuffle = load_sign_shuffle(); - float d; }; struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { @@ -4688,7 +4693,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602}; - float d; }; struct SimpleBits { @@ -4747,7 +4751,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { uint32x4x4_t data; SimpleBits bits; - float d; }; inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { @@ -4793,7 +4796,6 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> { SimpleBits bits; - float d; }; @@ -4857,7 +4859,6 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { SimpleBits bits; SignHelper sh; - float d; }; @@ -4891,8 +4892,6 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> { SimpleBits bits; uint32x4x2_t gas; - float d; - }; struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { @@ -4951,11 +4950,9 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { SignHelper sh; uint32x4x2_t gas; - float d; - }; -struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { +struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> { DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } @@ -4966,9 +4963,7 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { // d = GGML_FP16_TO_FP32(x[i].d); //} - inline void new_block(int i) { - d = GGML_FP16_TO_FP32(x[i].d); - } + inline void new_block(int) { } template <typename Q8> inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { @@ -5019,8 +5014,6 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> { } Q2bits bits; - - float d; }; template <int nrc_y> diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 42441584..28bad18e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1972,15 +1972,15 @@ void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) { auto quantize = [] (float xmax, float x) { return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2; }; + int n = k; + float max = x[0]; + for (int j = 1; j < n; ++j) max = std::max(max, fabsf(x[j])); + + *(float *)y = max; + y = (block_iq2_tn *)((float *)y + 1); for (int ibl = 0; ibl < nb; ++ibl) { auto xb = x + QK_K*ibl; - float max = xb[0]; - for (int j = 0; j < QK_K; ++j) { - float ax = fabsf(xb[j]); - max = std::max(ax, max); - } - y[ibl].d = GGML_FP32_TO_FP16(max); auto qs = y[ibl].qs; for (int l = 0; l < QK_K/128; ++l) { for (int j = 0; j < 32; ++j) { @@ -1992,7 +1992,7 @@ void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) { } } -void quantize_row_iq2_tn(const float * x, void * y, int64_t k) { +void quantize_row_iq2_tn(const float * x, void * y, int64_t k) { quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k); } @@ -2009,9 +2009,11 @@ size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) { GGML_ASSERT(k%QK_K == 0); + const float * dptr = (const float *)x; + float d = *dptr; + x = (const block_iq2_tn *)(dptr + 1); int nb = k/QK_K; for (int ibl = 0; ibl < nb; ++ibl) { - float d = GGML_FP16_TO_FP32(x[ibl].d); auto qs = x[ibl].qs; for (int l = 0; l < QK_K/128; ++l) { for (int j = 0; j < 32; ++j) { @@ -2039,13 +2041,14 @@ void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t const int nb = n / QK_K; - const block_iq2_tn * x = (const block_iq2_tn *)vx; + const float * dptr = (const float *)vx; + const float d = *dptr; + const block_iq2_tn * x = (const block_iq2_tn *)(dptr + 1); const block_q8_K * y = (const block_q8_K *)vy; float sumf = 0; for (int i = 0; i < nb; i++) { - float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; auto qs = x[i].qs; auto q8 = y[i].qs; int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0; diff --git a/src/llama.cpp b/src/llama.cpp index 0eea948a..df57c071 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4498,9 +4498,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet"; - case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.6875 bpw TriLM"; + case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.625 bpw TriLM"; case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet"; - case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.06 bpw TriLM"; + case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.00 bpw TriLM"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; |