summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xconvert-hf-to-gguf.py2
-rw-r--r--ggml-cuda.cu6
-rw-r--r--ggml-cuda/common.cuh2
-rw-r--r--ggml-cuda/convert.cu74
-rw-r--r--ggml-cuda/convert.cuh2
-rw-r--r--ggml-cuda/dequantize.cuh10
-rw-r--r--ggml-cuda/dmmv.cu6
-rw-r--r--ggml-cuda/quantize.cu16
-rw-r--r--ggml-cuda/quantize.cuh2
-rw-r--r--ggml-quants.c310
-rw-r--r--ggml-quants.h164
-rw-r--r--ggml.c16
-rw-r--r--ggml.h14
-rw-r--r--gguf-py/gguf/constants.py2
-rw-r--r--gguf-py/gguf/tensor_mapping.py2
-rw-r--r--llama.cpp64
16 files changed, 366 insertions, 326 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 7e601170..37af6328 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -160,7 +160,7 @@ class Model(ABC):
data = data.astype(np.float32)
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
- if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ce28cb55..bff8ad9d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
- int ldc = id == ctx.device ? ne0 : row_diff;
+ int64_t ldc = id == ctx.device ? ne0 : row_diff;
const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
- const int nb2 = dst->nb[2];
- const int nb3 = dst->nb[3];
+ const int64_t nb2 = dst->nb[2];
+ const int64_t nb3 = dst->nb[3];
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index b98d7cbd..481065b2 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
//////////////////////
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index 18a31edc..ed4fa274 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -4,14 +4,14 @@
#define CUDA_Q8_0_NE_ALIGN 2048
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
- const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= k) {
return;
}
- const int ib = i/qk; // block index
+ const int64_t ib = i/qk; // block index
const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
template <bool need_check>
-static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
- const int ib = 8*i + ir;
+ const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
- const int ib = 8*i + ir;
+ const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@@ -313,14 +313,14 @@ template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;
- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
- const int tid = threadIdx.x;
- const int ip = tid/32; // ip is 0 or 1
- const int il = tid - 32*ip; // 0...32
- const int is = 8*ip + il/16;
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/32; // ip is 0 or 1
+ const int64_t il = tid - 32*ip; // 0...32
+ const int64_t is = 8*ip + il/16;
dst_t * y = yy + i*QK_K + 128*ip + il;
@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
#else
// assume 32 threads
- const int tid = threadIdx.x;
- const int ip = tid/16; // 0 or 1
- const int il = tid - 16*ip; // 0...15
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/16; // 0 or 1
+ const int64_t il = tid - 16*ip; // 0...15
dst_t * y = yy + i*QK_K + 16*ip + il;
@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
#endif
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
-static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
}
template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
}
template<typename dst_t>
-static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
-static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
}
template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
}
template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
diff --git a/ggml-cuda/convert.cuh b/ggml-cuda/convert.cuh
index db34c0be..5394be9f 100644
--- a/ggml-cuda/convert.cuh
+++ b/ggml-cuda/convert.cuh
@@ -3,7 +3,7 @@
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
template<typename T>
-using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh
index b5440063..bd3c2d9d 100644
--- a/ggml-cuda/dequantize.cuh
+++ b/ggml-cuda/dequantize.cuh
@@ -1,6 +1,6 @@
#include "common.cuh"
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
#endif // GGML_CUDA_F16
}
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
#endif // GGML_CUDA_F16
}
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx;
const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
#endif // GGML_CUDA_F16
}
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
#endif // GGML_CUDA_F16
}
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;
const dfloat d = x[ib].d;
diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu
index 0b17e3cb..7313e3e1 100644
--- a/ggml-cuda/dmmv.cu
+++ b/ggml-cuda/dmmv.cu
@@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
}
}
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;
// automatic half -> float type cast if dfloat == float
@@ -577,7 +577,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
@@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
- const int ib = (row*ncols + col)/qk; // x block index
+ const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index
diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu
index a1fbc993..7578c4b6 100644
--- a/ggml-cuda/quantize.cu
+++ b/ggml-cuda/quantize.cu
@@ -1,20 +1,20 @@
#include "quantize.cuh"
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
- const int ix = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
+ const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (ix >= kx_padded) {
return;
}
- const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+ const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
- const int i_padded = iy*kx_padded + ix;
+ const int64_t i_padded = (int64_t)iy*kx_padded + ix;
block_q8_1 * y = (block_q8_1 *) vy;
- const int ib = i_padded / QK8_1; // block index
- const int iqs = i_padded % QK8_1; // quant index
+ const int64_t ib = i_padded / QK8_1; // block index
+ const int64_t iqs = i_padded % QK8_1; // quant index
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
float amax = fabsf(xi);
@@ -36,8 +36,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
- const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
+ const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ky, 1);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
diff --git a/ggml-cuda/quantize.cuh b/ggml-cuda/quantize.cuh
index adb89c83..b37a4752 100644
--- a/ggml-cuda/quantize.cuh
+++ b/ggml-cuda/quantize.cuh
@@ -2,4 +2,4 @@
#define CUDA_QUANTIZE_BLOCK_SIZE 256
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream);
+void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);
diff --git a/ggml-quants.c b/ggml-quants.c
index f2e6c4bd..32e84434 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -544,7 +544,7 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
// reference implementation for deterministic creation of model files
-void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;
assert(k % qk == 0);
@@ -581,12 +581,12 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
}
}
-void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_0_reference(x, y, k);
}
-void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
const int qk = QK4_1;
assert(k % qk == 0);
@@ -623,11 +623,11 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
}
}
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_1_reference(x, y, k);
}
-void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
static const int qk = QK5_0;
assert(k % qk == 0);
@@ -671,11 +671,11 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
}
}
-void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_0_reference(x, y, k);
}
-void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
const int qk = QK5_1;
assert(k % qk == 0);
@@ -719,12 +719,12 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
}
}
-void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_1_reference(x, y, k);
}
// reference implementation for deterministic creation of model files
-void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
@@ -749,7 +749,7 @@ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict
}
}
-void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
@@ -938,7 +938,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
}
// reference implementation for deterministic creation of model files
-void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
assert(QK8_1 == 32);
assert(k % QK8_1 == 0);
const int nb = k / QK8_1;
@@ -973,7 +973,7 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict
}
}
-void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK8_1 == 0);
const int nb = k / QK8_1;
@@ -1192,7 +1192,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
#endif
}
-void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK4_0;
assert(k % qk == 0);
@@ -1212,7 +1212,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int
}
}
-void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK4_1;
assert(k % qk == 0);
@@ -1233,7 +1233,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int
}
}
-void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK5_0;
assert(k % qk == 0);
@@ -1259,7 +1259,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int
}
}
-void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK5_1;
assert(k % qk == 0);
@@ -1286,7 +1286,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int
}
}
-void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
static const int qk = QK8_0;
assert(k % qk == 0);
@@ -1581,7 +1581,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
//========================- 2-bit (de)-quantization
-void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -1658,7 +1658,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
}
}
-void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -1704,7 +1704,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
}
}
-void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q2_K_reference(x, vy, k);
}
@@ -1960,14 +1960,14 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
}
}
-size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
if (!quant_weights) {
- quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
+ quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -1978,7 +1978,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow,
//========================= 3-bit (de)-quantization
-void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -2092,7 +2092,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
}
#if QK_K == 256
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -2142,7 +2142,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
}
}
#else
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
assert(QK_K == 64);
const int nb = k / QK_K;
@@ -2175,11 +2175,11 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
}
#endif
-void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q3_K_reference(x, vy, k);
}
-static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) {
+static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q3_K_reference(x, y, n_per_row);
@@ -2268,14 +2268,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
#endif
}
-size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
if (!quant_weights) {
- quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
+ quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -2286,7 +2286,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow,
// ====================== 4-bit (de)-quantization
-void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -2393,7 +2393,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
}
}
-void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -2432,19 +2432,19 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
}
}
-void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_q4_K * restrict y = vy;
quantize_row_q4_K_reference(x, y, k);
}
-static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q4_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
- const int nb = n_per_row / QK_K;
+ const int64_t nb = n_per_row / QK_K;
uint8_t L[QK_K];
uint8_t Laux[32];
@@ -2516,14 +2516,14 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
#endif
}
-size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
if (!quant_weights) {
- quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
+ quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -2534,9 +2534,9 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow,
// ====================== 5-bit (de)-quantization
-void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
#if QK_K == 256
uint8_t L[QK_K];
@@ -2676,9 +2676,9 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
}
}
-void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
@@ -2721,19 +2721,19 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
}
}
-void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_q5_K * restrict y = vy;
quantize_row_q5_K_reference(x, y, k);
}
-static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q5_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
- const int nb = n_per_row / QK_K;
+ const int64_t nb = n_per_row / QK_K;
uint8_t L[QK_K];
uint8_t Laux[32];
@@ -2825,14 +2825,14 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
#endif
}
-size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
if (!quant_weights) {
- quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
+ quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -2843,9 +2843,9 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow,
// ====================== 6-bit (de)-quantization
-void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
int8_t L[QK_K];
float scales[QK_K/16];
@@ -2925,9 +2925,9 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
}
}
-void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
@@ -2972,19 +2972,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
}
}
-void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_q6_K * restrict y = vy;
quantize_row_q6_K_reference(x, y, k);
}
-static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q6_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
- const int nb = n_per_row / QK_K;
+ const int64_t nb = n_per_row / QK_K;
int8_t L[QK_K];
float scales[QK_K/16];
@@ -3067,14 +3067,14 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
#endif
}
-size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
if (!quant_weights) {
- quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
+ quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -3083,7 +3083,7 @@ size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow,
return nrow * row_size;
}
-static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
static_assert(QK4_0 == 32, "QK4_0 must be 32");
if (!quant_weights) {
@@ -3098,7 +3098,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
float sigma2 = sum_x2/n_per_row;
- const int nb = n_per_row/QK4_0;
+ const int64_t nb = n_per_row/QK4_0;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK4_0 * ib;
const float * qw = quant_weights + QK4_0 * ib;
@@ -3111,14 +3111,14 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
}
}
-size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
- quantize_row_q4_0_reference(src, dst, nrow*n_per_row);
+ quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -3126,7 +3126,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow,
return nrow * row_size;
}
-static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
static_assert(QK4_1 == 32, "QK4_1 must be 32");
if (!quant_weights) {
@@ -3141,7 +3141,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
float sigma2 = sum_x2/n_per_row;
- const int nb = n_per_row/QK4_1;
+ const int64_t nb = n_per_row/QK4_1;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK4_1 * ib;
const float * qw = quant_weights + QK4_1 * ib;
@@ -3156,14 +3156,14 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
}
}
-size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
- quantize_row_q4_1_reference(src, dst, nrow*n_per_row);
+ quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -3171,7 +3171,7 @@ size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow,
return nrow * row_size;
}
-static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
static_assert(QK5_0 == 32, "QK5_0 must be 32");
if (!quant_weights) {
@@ -3186,7 +3186,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
float sigma2 = sum_x2/n_per_row;
- const int nb = n_per_row/QK5_0;
+ const int64_t nb = n_per_row/QK5_0;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK5_0 * ib;
const float * qw = quant_weights + QK5_0 * ib;
@@ -3210,14 +3210,14 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
}
}
-size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
- quantize_row_q5_0_reference(src, dst, nrow*n_per_row);
+ quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -3225,7 +3225,7 @@ size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow,
return nrow * row_size;
}
-static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
static_assert(QK5_1 == 32, "QK5_1 must be 32");
if (!quant_weights) {
@@ -3240,7 +3240,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
float sigma2 = sum_x2/n_per_row;
- const int nb = n_per_row/QK5_1;
+ const int64_t nb = n_per_row/QK5_1;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK5_1 * ib;
const float * qw = quant_weights + QK5_1 * ib;
@@ -3263,14 +3263,14 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
}
}
-size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
- quantize_row_q5_1_reference(src, dst, nrow*n_per_row);
+ quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
@@ -3278,18 +3278,18 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow,
return nrow * row_size;
}
-size_t quantize_q8_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
- quantize_row_q8_0_reference(src, dst, nrow*n_per_row);
+ quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * row_size;
}
// ====================== "True" 2-bit (de)-quantization
-void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
@@ -3315,9 +3315,9 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y
// ====================== 2.3125 bpw (de)-quantization
-void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
float db[2];
@@ -3342,9 +3342,9 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
// ====================== 2.5625 bpw (de)-quantization
-void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int k) {
+void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
float db[2];
@@ -3374,9 +3374,9 @@ void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, in
// ====================== 3.0625 bpw (de)-quantization
-void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
uint32_t aux32;
@@ -3406,9 +3406,9 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
// ====================== 3.3125 bpw (de)-quantization
-void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
+void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
@@ -3449,9 +3449,9 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
// ====================== 1.5625 bpw (de)-quantization
-void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
+void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
@@ -3474,9 +3474,9 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
}
}
-void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) {
+void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
float delta[4];
uint16_t idx[4];
@@ -3535,9 +3535,9 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
+void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) {
assert(k % QK4_NL == 0);
- const int nb = k / QK4_NL;
+ const int64_t nb = k / QK4_NL;
for (int i = 0; i < nb; i++) {
@@ -3553,12 +3553,12 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
}
}
-void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
#if QK_K == 64
dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k);
#else
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
@@ -3582,9 +3582,9 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
//===================================== Q8_K ==============================================
-void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
@@ -3621,9 +3621,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
}
}
-void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) {
assert(k % QK_K == 0);
- const int nb = k / QK_K;
+ const int64_t nb = k / QK_K;
for (int i = 0; i < nb; i++) {
for (int j = 0; j < QK_K; ++j) {
@@ -3632,7 +3632,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int
}
}
-void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q8_K_reference(x, y, k);
}
@@ -10648,7 +10648,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
return grid_index;
}
-static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
@@ -10664,7 +10664,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
const int kMaxQ = 3;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
block_iq2_xxs * y = vy;
@@ -10821,7 +10821,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
}
}
-static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
@@ -10837,7 +10837,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
const int kMaxQ = 3;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
block_iq2_xs * y = vy;
@@ -11001,11 +11001,11 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
}
}
-size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq2_xxs);
@@ -11013,11 +11013,11 @@ size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nro
return nrow * nblock * sizeof(block_iq2_xxs);
}
-size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq2_xs);
@@ -11242,7 +11242,7 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
return grid_index;
}
-static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n,
+static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n,
const float * restrict quant_weights) {
const int gindex = iq3_data_index(grid_size);
@@ -11259,7 +11259,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
const int kMaxQ = 8;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
ggml_fp16_t * dh;
uint8_t * qs;
@@ -11455,11 +11455,11 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
}
}
-size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq3_xxs);
@@ -11467,13 +11467,13 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nro
return nrow * nblock * sizeof(block_iq3_xxs);
}
-void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_iq3_xxs * restrict y = vy;
quantize_row_iq3_xxs_reference(x, y, k);
}
-void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
assert(k % QK_K == 0);
quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
}
@@ -11504,7 +11504,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
const int kMaxQ = 8;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
block_iq3_s * y = vy;
@@ -11661,9 +11661,9 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
}
#define IQ3S_BLOCK_SIZE 32
-size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
float scales[QK_K/IQ3S_BLOCK_SIZE];
float weight[IQ3S_BLOCK_SIZE];
float xval[IQ3S_BLOCK_SIZE];
@@ -11674,7 +11674,7 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow,
bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
src += n_per_row;
@@ -11683,13 +11683,13 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow,
return nrow * nblock * sizeof(block_iq3_s);
}
-void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_iq3_s * restrict y = vy;
quantize_row_iq3_s_reference(x, y, k);
}
-void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) {
+void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
assert(k % QK_K == 0);
quantize_iq3_s(x, y, 1, k, NULL);
}
@@ -11822,7 +11822,7 @@ static int iq1_sort_helper(const void * left, const void * right) {
#define IQ1S_BLOCK_SIZE 32
#define IQ1M_BLOCK_SIZE 16
-static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
+static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
float * scales,
float * weight,
float * sumx,
@@ -11846,7 +11846,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
block_iq1_s * y = vy;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
const int block_size = IQ1S_BLOCK_SIZE;
@@ -11980,7 +11980,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
}
}
-size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
float scales[QK_K/IQ1S_BLOCK_SIZE];
float weight[IQ1S_BLOCK_SIZE];
@@ -11990,9 +11990,9 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow,
float pairs[2*IQ1S_BLOCK_SIZE];
uint16_t index[IQ1S_BLOCK_SIZE/8];
int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
src += n_per_row;
qrow += nblock*sizeof(block_iq1_s);
@@ -12000,7 +12000,7 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow,
return nrow * nblock * sizeof(block_iq1_s);
}
-static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
+static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
float * scales,
float * weight,
float * pairs,
@@ -12022,7 +12022,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
block_iq1_m * y = vy;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
const int block_size = IQ1M_BLOCK_SIZE;
@@ -12265,7 +12265,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
}
}
-size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
float scales[QK_K/IQ1M_BLOCK_SIZE];
float weight[IQ1M_BLOCK_SIZE];
@@ -12273,9 +12273,9 @@ size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow,
float pairs[2*IQ1M_BLOCK_SIZE];
uint16_t index[IQ1M_BLOCK_SIZE/8];
int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
src += n_per_row;
qrow += nblock*sizeof(block_iq1_m);
@@ -12407,16 +12407,16 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
}
}
-size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK4_NL == 0);
- int nblock = n_per_row/QK4_NL;
+ int64_t nblock = n_per_row/QK4_NL;
char * qrow = (char *)dst;
uint8_t L[QK4_NL];
float weight[QK4_NL];
uint16_t unused_h;
uint8_t * unused_l = NULL;
float scale;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
@@ -12429,9 +12429,9 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
return nrow * nblock * sizeof(block_iq4_nl);
}
-void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) {
GGML_ASSERT(k%QK4_NL == 0);
- int nblock = k/QK4_NL;
+ int64_t nblock = k/QK4_NL;
uint8_t L[QK4_NL];
float weight[QK4_NL];
uint16_t unused_h;
@@ -12444,22 +12444,22 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
}
}
-void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
+void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
assert(k % QK4_NL == 0);
quantize_row_iq4_nl(x, y, k);
}
-size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
#if QK_K == 64
return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights);
#else
GGML_ASSERT(n_per_row%QK_K == 0);
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
uint8_t L[QK_K];
float weight[32];
float scales[QK_K/32];
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
@@ -12473,20 +12473,20 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
#endif
}
-void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_iq4_xs * restrict y = vy;
quantize_row_iq4_xs_reference(x, y, k);
}
-void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) {
+void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
assert(k % QK_K == 0);
quantize_iq4_xs(x, y, 1, k, NULL);
}
// =============================== 2.5625 bpw
-static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);
@@ -12501,7 +12501,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
const int kMaxQ = 3;
- const int nbl = n/QK_K;
+ const int64_t nbl = n/QK_K;
block_iq2_s * y = vy;
@@ -12654,11 +12654,11 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
}
}
-size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_ASSERT(n_per_row%QK_K == 0);
- int nblock = n_per_row/QK_K;
+ int64_t nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
- for (int row = 0; row < nrow; ++row) {
+ for (int64_t row = 0; row < nrow; ++row) {
quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq2_s);
@@ -12666,12 +12666,12 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow,
return nrow * nblock * sizeof(block_iq2_s);
}
-void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int k) {
+void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
assert(k % QK_K == 0);
quantize_iq2_s(x, y, 1, k, NULL);
}
-void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0);
block_iq2_s * restrict y = vy;
quantize_row_iq2_s_reference(x, y, k);
diff --git a/ggml-quants.h b/ggml-quants.h
index ac1091c3..4d436a8f 100644
--- a/ggml-quants.h
+++ b/ggml-quants.h
@@ -12,70 +12,70 @@ extern "C" {
#endif
// Quantization
-void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int k);
-void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int k);
-void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int k);
-void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int k);
-void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int k);
-void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int k);
-
-void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k);
-void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int k);
-void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int k);
-void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int k);
-void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k);
-void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
-
-void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
-void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
-void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k);
-void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
-void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k);
-
-void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-
-void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-
-void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization
-void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-
-void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-
-void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -101,26 +101,26 @@ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
-size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-
-size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
+size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void iq2xs_init_impl(enum ggml_type type);
void iq2xs_free_impl(enum ggml_type type);
diff --git a/ggml.c b/ggml.c
index c9b0a6a0..793b67f4 100644
--- a/ggml.c
+++ b/ggml.c
@@ -338,14 +338,14 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
return GGML_FP32_TO_FP16(x);
}
-void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) {
- for (int i = 0; i < n; i++) {
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
+ for (int64_t i = 0; i < n; i++) {
y[i] = GGML_FP16_TO_FP32(x[i]);
}
}
-void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) {
- int i = 0;
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
+ int64_t i = 0;
#if defined(__F16C__)
for (; i + 7 < n; i += 8) {
__m256 x_vec = _mm256_loadu_ps(x + i);
@@ -20331,11 +20331,11 @@ size_t ggml_quantize_chunk(
enum ggml_type type,
const float * src,
void * dst,
- int start,
- int nrows,
- int n_per_row,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
const float * imatrix) {
- const int n = nrows * n_per_row;
+ const int64_t n = (int64_t) nrows * n_per_row;
if (ggml_quantize_requires_imatrix(type)) {
GGML_ASSERT(imatrix != NULL);
diff --git a/ggml.h b/ggml.h
index 5cef45c0..abe3767f 100644
--- a/ggml.h
+++ b/ggml.h
@@ -332,8 +332,8 @@ extern "C" {
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
- GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
- GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
+ GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n);
+ GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n);
struct ggml_object;
struct ggml_context;
@@ -2210,9 +2210,9 @@ extern "C" {
enum ggml_type type,
const float * src,
void * dst,
- int start,
- int nrows,
- int n_per_row,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
const float * imatrix);
//
@@ -2377,8 +2377,8 @@ extern "C" {
#else
#define GGML_RESTRICT restrict
#endif
- typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
- typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+ typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+ typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc);
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index c44d8abe..a6454a10 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -639,6 +639,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_Q_NORM,
],
# TODO
}
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 345b1b0c..4f02d298 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -285,12 +285,14 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
+ "model.layers.{bid}.self_attn.q_norm", # cohere
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
),
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
+ "model.layers.{bid}.self_attn.k_norm", # cohere
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
),
diff --git a/llama.cpp b/llama.cpp
index b16ddc64..6a090d1b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -926,6 +926,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
},
},
{
@@ -5406,6 +5408,11 @@ static bool llm_load_tensors(
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ if (n_layer >= 64){
+ layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head});
+ layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv});
+ }
+
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
@@ -9454,6 +9461,31 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
+ if (model.layers[il].attn_q_norm) {
+ Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+ ggml_element_size(Qcur) * n_embd_head,
+ ggml_element_size(Qcur) * n_embd_head * n_head,
+ 0);
+ cb(Qcur, "Qcur", il);
+ Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+ ggml_element_size(Kcur) * n_embd_head,
+ ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+ 0);
+ cb(Kcur, "Kcur", il);
+
+ Qcur = llm_build_norm(ctx0, Qcur, hparams,
+ model.layers[il].attn_q_norm,
+ NULL,
+ LLM_NORM, cb, il);
+ cb(Qcur, "Qcur", il);
+
+ Kcur = llm_build_norm(ctx0, Kcur, hparams,
+ model.layers[il].attn_k_norm,
+ NULL,
+ LLM_NORM, cb, il);
+ cb(Kcur, "Kcur", il);
+ }
+
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -13323,9 +13355,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
return new_type;
}
-static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
+static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
std::mutex mutex;
- int counter = 0;
+ int64_t counter = 0;
size_t new_size = 0;
if (nthread < 2) {
// single-thread
@@ -13333,11 +13365,11 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
}
auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size,
nrows, n_per_row, imatrix]() {
- const int nrows_per_chunk = chunk_size / n_per_row;
+ const int64_t nrows_per_chunk = chunk_size / n_per_row;
size_t local_size = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
- int first_row = counter; counter += nrows_per_chunk;
+ int64_t first_row = counter; counter += nrows_per_chunk;
if (first_row >= nrows) {
if (local_size > 0) {
new_size += local_size;
@@ -13345,7 +13377,7 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
break;
}
lock.unlock();
- const int this_nrow = std::min(nrows - first_row, nrows_per_chunk);
+ const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
}
};
@@ -13539,6 +13571,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);
+
+ // do not quantize norm tensors
+ quantize &= name.find("_norm.weight") == std::string::npos;
+
quantize &= params->quantize_output_tensor || name != "output.weight";
quantize &= !params->only_copy;
@@ -13585,7 +13621,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
new_size = ggml_nbytes(tensor);
LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
} else {
- const size_t nelements = ggml_nelements(tensor);
+ const int64_t nelements = ggml_nelements(tensor);
const float * imatrix = nullptr;
if (imatrix_data) {
@@ -13637,20 +13673,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
fflush(stdout);
- if (work.size() < nelements * 4) {
+ if (work.size() < (size_t)nelements * 4) {
work.resize(nelements * 4); // upper bound on size
}
new_data = work.data();
- const int n_per_row = tensor->ne[0];
- const int nrows = tensor->ne[1];
+ const int64_t n_per_row = tensor->ne[0];
+ const int64_t nrows = tensor->ne[1];
- static const int min_chunk_size = 32 * 512;
- const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
+ static const int64_t min_chunk_size = 32 * 512;
+ const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
- const int nelements_matrix = tensor->ne[0] * tensor->ne[1];
- const int nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
- const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
+ const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
+ const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
+ const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
// quantize each expert separately since they have different importance matrices
new_size = 0;