path: root/ggml-cuda.cu
author    Georgi Gerganov <ggerganov@gmail.com>  2023-05-12 00:23:08 +0300
committer GitHub <noreply@github.com>            2023-05-12 00:23:08 +0300
commit    b9fd7eee57df101d4a3e3eabc9fd6c2cb13c9ca1 (patch)
tree      bbc617734354146eeffed3ba58cd1b4cfebb4aa6 /ggml-cuda.cu
parent    b608b55a3ea8e4760c617418538465449175bdb8 (diff)
ggml : remove bit shuffling (#1405)
* ggml : remove Q4_0 bit shuffling (ARM NEON)
* ggml : remove Q4_1 bit shuffling (ARM NEON + reference)
* ggml : nibbles_from_floats() + bytes_from_nibbles() (ARM NEON)
* ggml : remove Q4_2 bit shuffling (WIP, BROKEN)
* ggml : remove Q5_0 bit shuffling (ARM NEON)
* ggml : 2x faster scalar implementations
* ggml : remove Q5_1 bit shuffling (ARM NEON + scalar)
* ggml : simplify scalar dot
* ggml : remove WASM SIMD bit shuffling + remove vzip for ARM 32-bit
* ggml : fix Q4_1 quantization
* ggml : update cuBLAS + normalize variable names
* ggml : remove Q4_2 mode
* ggml : minor formatting
* ggml : fix Q5_0 quantization
* scripts : add script for measuring the time per token
* AVX implementations (#1370)
* ggml : uniform 5th bit extraction
* llama : produce error upon loading old model files
* llama : fix model magic/version write
* ggml : speed-up Q5_0 + Q5_1 at 4 threads
* ggml : preserve old Q4 and Q5 formats
* ggml : simplify Q8_1 - no need for low / high sums anymore
* ggml : fix Q8_0 and Q8_1 rounding
* Revert "AVX implementations (#1370)"

  This reverts commit 948d124837f9d287d8490f41338e0e4cceb0814f.

* ggml : fix AVX2 implementation
* sha : update hashes for 7B and 13B
* readme : update timings + remove warning banner
* llama : update v2 PR number to 1405
* ggml : fix WASM comments
* ggml : back to original bit order
* readme : add note that Q4 and Q5 have been changed
* llama : fix return for unknown version

---------

Co-authored-by: Stephan Walter <stephan@walter.name>
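The core of the change, as the CUDA hunks further down show, is the quant layout: instead of interleaving the low/high nibble pair of each byte into adjacent output elements, the low nibble of qs[j] now dequantizes into the first half of the block and the high nibble into the second half. The following host-side C sketch of the new Q4_0 ordering is an illustration only, not part of the patch; block_q4_0_ref and dequantize_q4_0_block_ref are hypothetical names, and the struct layout is assumed by analogy with the block_q4_1 definition visible in the diff.

#include <stdint.h>

#define QK4_0 32

// Hypothetical reference block, modeled on the q4_1 struct shown in the diff (not the real type).
typedef struct {
    float   d;              // delta (scale)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0_ref;

void dequantize_q4_0_block_ref(const block_q4_0_ref * x, float * y) {
    const float d = x->d;
    for (int j = 0; j < QK4_0/2; ++j) {
        const int x0 = (x->qs[j] & 0x0f) - 8;  // low nibble  -> element j
        const int x1 = (x->qs[j] >>   4) - 8;  // high nibble -> element j + QK4_0/2
        y[j]           = x0*d;
        y[j + QK4_0/2] = x1*d;
    }
}

Before this commit, both nibbles of qs[j] produced the adjacent outputs y[2*j] and y[2*j + 1]; that interleaving is the "bit shuffling" being removed. The same first-half/second-half split is applied to Q4_1, Q5_0 and Q5_1 below; the kernel bodies differ only in the offset m and in the handling of the fifth bit.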
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  131
1 file changed, 36 insertions, 95 deletions
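For Q5_0 and Q5_1, the kernels in the diff below additionally OR a fifth bit from the packed 32-bit qh field onto each 4-bit nibble. The sketch below restates the xh_0/xh_1 expressions from the diff as standalone helpers; the helper names are made up for illustration and do not exist in the codebase.

#include <stdint.h>

// For half-block index j (0 <= j < qk/2, with qk = 32):
//   element j        takes bit j        of qh,
//   element j + qk/2 takes bit j + qk/2 of qh,
// each moved to bit position 4 (0x10) so it can be OR-ed onto a 4-bit nibble.

static inline uint8_t q5_high_bit_lo(uint32_t qh, int j) {
    return (uint8_t)(((qh >> (j + 0)) << 4) & 0x10);  // bit j, shifted up to 0x10
}

static inline uint8_t q5_high_bit_hi(uint32_t qh, int j) {
    // (j + 12) + 4 = j + 16, so masking with 0x10 reads bit j + 16 without a second shift
    return (uint8_t)((qh >> (j + 12)) & 0x10);
}

For example, with j = 3, q5_high_bit_lo tests bit 3 of qh and q5_high_bit_hi tests bit 19, i.e. the fifth bits of elements 3 and 19 of the 32-element block.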
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 127b352a..8a3beb0e 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -49,13 +49,6 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
-#define QK4_2 16
-typedef struct {
- half d; // delta
- uint8_t qs[QK4_2 / 2]; // nibbles / quants
-} block_q4_2;
-static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
-
#define QK5_0 32
typedef struct {
half d; // delta
@@ -81,29 +74,26 @@ typedef struct {
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+ static const int qk = QK4_0;
+
const block_q4_0 * x = (const block_q4_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
- const uint8_t * pp = x[i].qs;
-
- for (int l = 0; l < QK4_0; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vi0 = vi & 0xf;
- const int8_t vi1 = vi >> 4;
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0xf) - 8;
+ const int x1 = (x[i].qs[j] >> 4) - 8;
- const float v0 = (vi0 - 8)*d;
- const float v1 = (vi1 - 8)*d;
-
- y[i*QK4_0 + l + 0] = v0;
- y[i*QK4_0 + l + 1] = v1;
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
}
}
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+ static const int qk = QK4_1;
+
const block_q4_1 * x = (const block_q4_1 *) vx;
const int i = blockIdx.x;
@@ -111,75 +101,42 @@ static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
const float d = x[i].d;
const float m = x[i].m;
- const uint8_t * pp = x[i].qs;
-
- for (int l = 0; l < QK4_1; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vi0 = vi & 0xf;
- const int8_t vi1 = vi >> 4;
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0xf);
+ const int x1 = (x[i].qs[j] >> 4);
- const float v0 = vi0*d + m;
- const float v1 = vi1*d + m;
-
- y[i*QK4_1 + l + 0] = v0;
- y[i*QK4_1 + l + 1] = v1;
- }
-}
-
-static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
- const block_q4_2 * x = (const block_q4_2 *) vx;
-
- const int i = blockIdx.x;
-
- const float d = x[i].d;
-
- const uint8_t * pp = x[i].qs;
-
- for (int l = 0; l < QK4_2; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vi0 = vi & 0xf;
- const int8_t vi1 = vi >> 4;
-
- const float v0 = (vi0 - 8)*d;
- const float v1 = (vi1 - 8)*d;
-
- y[i*QK4_2 + l + 0] = v0;
- y[i*QK4_2 + l + 1] = v1;
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
}
}
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
+ static const int qk = QK5_0;
+
const block_q5_0 * x = (const block_q5_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
- const uint8_t * pp = x[i].qs;
-
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
- for (int l = 0; l < QK5_0; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
- const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
- const int8_t vi0 = ((vi & 0xf) | vh0);
- const int8_t vi1 = ((vi >> 4) | vh1);
+ const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
- const float v0 = (vi0 - 16)*d;
- const float v1 = (vi1 - 16)*d;
-
- y[i*QK5_0 + l + 0] = v0;
- y[i*QK5_0 + l + 1] = v1;
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
}
}
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
+ static const int qk = QK5_1;
+
const block_q5_1 * x = (const block_q5_1 *) vx;
const int i = blockIdx.x;
@@ -187,41 +144,32 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
const float d = x[i].d;
const float m = x[i].m;
- const uint8_t * pp = x[i].qs;
-
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
- for (int l = 0; l < QK5_1; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
- const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
- const int8_t vi0 = (vi & 0xf) | vh0;
- const int8_t vi1 = (vi >> 4) | vh1;
+ const int x0 = (x[i].qs[j] & 0xf) | xh_0;
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
- const float v0 = vi0*d + m;
- const float v1 = vi1*d + m;
-
- y[i*QK5_1 + l + 0] = v0;
- y[i*QK5_1 + l + 1] = v1;
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
}
}
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
+ static const int qk = QK8_0;
+
const block_q8_0 * x = (const block_q8_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
- const int8_t * pp = x[i].qs;
-
- for (int l = 0; l < QK8_0; l++) {
- const int8_t vi = pp[l];
-
- y[i*QK8_0 + l] = vi*d;
+ for (int j = 0; j < qk; ++j) {
+ y[i*qk + j] = x[i].qs[j]*d;
}
}
@@ -235,11 +183,6 @@ static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStre
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}
-static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
- const int nb = k / QK4_2;
- dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
-}
-
static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_0;
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -274,8 +217,6 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
- case GGML_TYPE_Q4_2:
- return dequantize_row_q4_2_cuda;
case GGML_TYPE_Q5_0:
return dequantize_row_q5_0_cuda;
case GGML_TYPE_Q5_1: