summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/convert.cu19
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu17
2 files changed, 15 insertions, 21 deletions
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 4b1be7c1..c74b030b 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -154,19 +154,23 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
}
template<typename dst_t>
-static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ int64_t n_per_row, int64_t row_size) {
- const int64_t i = blockIdx.x;
- const block_iq2_tn * x = (const block_iq2_tn *) vx;
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float d = *(const float *)cx;
+ const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float));
+ int64_t i = ii - (row*n_per_row)/QK_K;
const int64_t tid = threadIdx.x;
const int64_t n = tid/32;
const int64_t l = tid - 32*n;
const uint8_t q = x[i].qs[32*n + l];
- dst_t * y = yy + i*QK_K + 128*n;
+ dst_t * y = yy + ii*QK_K + 128*n;
- float d = __half2float(x[i].d);
y[l+ 0] = d * ((q >> 0) & 3) - d;
y[l+32] = d * ((q >> 2) & 3) - d;
y[l+64] = d * ((q >> 4) & 3) - d;
@@ -743,8 +747,9 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n
template<typename dst_t>
static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
- const int nb = k / QK_K;
- dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
+ const int nb = (k + 255) / 256;
+ dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size);
}
template<typename dst_t>
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index a890f6b3..b2c32c0c 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -519,7 +519,8 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
- const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx;
+ float scale = *(const float *)vbq;
+ const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx;
const int bq8_offset = QR2_K * (iqs / QI8_1);
@@ -533,19 +534,7 @@ static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
v >>= 2;
}
- return __half2float(bq2->d) * sumf;
-
- //float sumf_d = 0;
- //float sumf_m = 0;
- //for (int i = 0; i < QR2_K; ++ i) {
- // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
- // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds);
- // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0);
- // sumf_m += d8.y;
- // v >>= 2;
- //}
- //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m);
-
+ return scale * sumf;
}
static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(