summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/llama-bench/llama-bench.cpp106
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp494
2 files changed, 434 insertions, 166 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 75fe40d1..b46bd855 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -220,6 +220,7 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<std::pair<int, int>> n_pg;
+ std::vector<std::pair<int, int>> n_gp;
std::vector<int> n_batch;
std::vector<int> n_ubatch;
std::vector<ggml_type> type_k;
@@ -248,6 +249,7 @@ static const cmd_params cmd_params_defaults = {
/* n_prompt */ {512},
/* n_gen */ {128},
/* n_pg */ {},
+ /* n_gp */ {},
/* n_batch */ {2048},
/* n_ubatch */ {512},
/* type_k */ {GGML_TYPE_F16},
@@ -280,6 +282,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -pg <pp,tg> (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
+ printf(" -gp <pp,tg> (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_gp, pair_str), ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" -ub, --ubatch-size <n> (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
printf(" -ctk, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
@@ -393,6 +396,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.n_pg.push_back({std::stoi(p[0]), std::stoi(p[1])});
+ } else if (arg == "-gp") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = string_split<std::string>(argv[i], ',');
+ if (p.size() != 2) {
+ invalid_param = true;
+ break;
+ }
+ params.n_gp.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
@@ -596,6 +610,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; }
+ if (params.n_gp.empty()) { params.n_gp = cmd_params_defaults.n_gp; }
if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; }
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
@@ -614,7 +629,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
return params;
}
+enum test_kind_type {
+ // measure mean prompt processing rate without token generation
+ TEST_KIND_PP,
+ // measure mean token generation rate without prompt processing
+ TEST_KIND_TG,
+ // measure mean prompt processing and token generation rate
+ TEST_KIND_PG,
+ // measure mean token generation rate after processing prompt of given length
+ TEST_KIND_GP,
+};
+
struct cmd_params_instance {
+ test_kind_type test_kind;
std::string model;
int n_prompt;
int n_gen;
@@ -701,6 +728,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
+ /* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
@@ -728,6 +756,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
+ /* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
@@ -755,6 +784,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
+ /* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ n_pg.first,
/* .n_gen = */ n_pg.second,
@@ -776,6 +806,34 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
};
instances.push_back(instance);
}
+
+ for (const auto & n_gp : params.n_gp) {
+ if (n_gp.first == 0 && n_gp.second == 0) {
+ continue;
+ }
+ cmd_params_instance instance = {
+ /* .test_kind = */ TEST_KIND_GP,
+ /* .model = */ m,
+ /* .n_prompt = */ n_gp.first,
+ /* .n_gen = */ n_gp.second,
+ /* .n_batch = */ nb,
+ /* .n_ubatch = */ nub,
+ /* .type_k = */ tk,
+ /* .type_v = */ tv,
+ /* .n_threads = */ nt,
+ /* .n_gpu_layers = */ nl,
+ /* .rpc_servers = */ rpc,
+ /* .split_mode = */ sm,
+ /* .main_gpu = */ mg,
+ /* .no_kv_offload= */ nkvo,
+ /* .flash_attn = */ fa,
+ /* .tensor_split = */ ts,
+ /* .use_mmap = */ mmp,
+ /* .embeddings = */ embd,
+ /* .repack = */ params.repack,
+ };
+ instances.push_back(instance);
+ }
}
return instances;
@@ -816,6 +874,8 @@ struct test {
int n_gen;
std::string test_time;
std::vector<uint64_t> samples_ns;
+ test_kind_type test_kind;
+ std::string test_label;
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
model_filename = inst.model;
@@ -841,11 +901,32 @@ struct test {
repack = inst.repack;
n_prompt = inst.n_prompt;
n_gen = inst.n_gen;
+ test_kind = inst.test_kind;
// RFC 3339 date-time format
time_t t = time(NULL);
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;
+ // prepare test label for printing
+ switch (test_kind) {
+ case TEST_KIND_PP:
+ snprintf(buf, sizeof(buf), "pp%d", n_prompt);
+ break;
+ case TEST_KIND_TG:
+ snprintf(buf, sizeof(buf), "tg%d", n_gen);
+ break;
+ case TEST_KIND_PG:
+ snprintf(buf, sizeof(buf), "pp%d+tg%d", n_prompt, n_gen);
+ break;
+ case TEST_KIND_GP:
+ snprintf(buf, sizeof(buf), "tg%d@pp%d", n_gen, n_prompt);
+ break;
+ default:
+ snprintf(buf, sizeof(buf), "unknown");
+ break;
+ }
+ test_label = buf;
+
(void) ctx;
}
@@ -858,7 +939,7 @@ struct test {
}
std::vector<double> get_ts() const {
- int n_tokens = n_prompt + n_gen;
+ int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
std::vector<double> ts;
std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
return ts;
@@ -911,7 +992,7 @@ struct test {
"tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
- "avg_ts", "stddev_ts"
+ "avg_ts", "stddev_ts", "test",
};
return fields;
}
@@ -967,7 +1048,8 @@ struct test {
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
- std::to_string(avg_ts()), std::to_string(stdev_ts())
+ std::to_string(avg_ts()), std::to_string(stdev_ts()),
+ test_label
};
return values;
}
@@ -1269,14 +1351,15 @@ struct markdown_printer : public printer {
value += "+RPC";
}
} else if (field == "test") {
- if (t.n_prompt > 0 && t.n_gen == 0) {
- snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
- } else if (t.n_gen > 0 && t.n_prompt == 0) {
- snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
- } else {
- snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
- }
- value = buf;
+ //if (t.n_prompt > 0 && t.n_gen == 0) {
+ // snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
+ //} else if (t.n_gen > 0 && t.n_prompt == 0) {
+ // snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
+ //} else {
+ // snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
+ //}
+ //value = buf;
+ value = t.test_label;
} else if (field == "t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
value = buf;
@@ -1489,6 +1572,7 @@ int main(int argc, char ** argv) {
if (t.n_prompt > 0) {
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
+ if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns();
if (t.n_gen > 0) {
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
}
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 8d2b4090..308d0dca 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -111,6 +111,15 @@ struct Perf {
#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
#endif
+#if defined __x86_64__
+#if defined HAVE_FANCY_SIMD
+ #undef HAVE_FANCY_SIMD
+#endif
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
+ #define HAVE_FANCY_SIMD
+#endif
+#endif
+
namespace {
typedef struct {
@@ -236,6 +245,35 @@ struct MulMat {
}
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
static inline int num_rows(ggml_type type) {
+#ifdef HAVE_FANCY_SIMD
+ switch (type) {
+ case GGML_TYPE_Q2_K_R4:
+ case GGML_TYPE_Q3_K_R4:
+ case GGML_TYPE_Q6_K_R4:
+ case GGML_TYPE_IQ2_K_R4:
+ case GGML_TYPE_IQ3_K_R4:
+ case GGML_TYPE_IQ4_K_R4:
+ case GGML_TYPE_IQ5_K_R4:
+ case GGML_TYPE_IQ4_KS_R4:
+ case GGML_TYPE_IQ2_XXS_R4:
+ case GGML_TYPE_IQ2_XS_R4:
+ case GGML_TYPE_IQ2_S_R4:
+ case GGML_TYPE_IQ3_XXS_R4:
+ case GGML_TYPE_IQ3_S_R4: return 4;
+ case GGML_TYPE_IQ4_NL_R4:
+ case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
+ case GGML_TYPE_IQ2_BN_R4:
+ case GGML_TYPE_IQ4_XS_R4:
+ case GGML_TYPE_Q4_K_R4:
+ case GGML_TYPE_Q5_K_R4:
+ case GGML_TYPE_Q8_K_R8: return 8;
+ case GGML_TYPE_Q4_0_R4:
+ case GGML_TYPE_Q8_0_R4:
+ case GGML_TYPE_BF16_R16: return 16;
+ default: return 1;
+ }
+#else
switch (type) {
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K_R4:
@@ -263,6 +301,7 @@ struct MulMat {
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
+#endif
}
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
@@ -377,13 +416,6 @@ const uint64_t keven_signs[128] = {
#if defined __x86_64__
-#if defined HAVE_FANCY_SIMD
- #undef HAVE_FANCY_SIMD
-#endif
-#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
- #define HAVE_FANCY_SIMD
-#endif
-
namespace {
inline float hsum_float_4(__m128 x) {
@@ -2608,6 +2640,15 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2);
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto qy = (const block_q8_1 *)q8.y[0];
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
+ prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
+ auto sumi = accum_q4_0_quants(v, qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
+ acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2);
+ }
acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
info.store(ix, 0, acc1);
}
@@ -2645,6 +2686,18 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
+ prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = accum_q4_0_quants(v, qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
@@ -2664,9 +2717,38 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm512_set1_epi8(0xf);
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[8];
+ auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 4; ++j) {
+ auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
+ qx[j+0] = _mm512_and_si512(bits, m4);
+ qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
+ }
+ return scales;
+ };
+ auto dot = [&qx] (const int8_t * qy) {
+ auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 16) {
const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
@@ -2676,47 +2758,25 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
- auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d));
- auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d));
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
- auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1);
- auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1);
- qx[0] = _mm512_and_si512(bits1, m4);
- qx[1] = _mm512_and_si512(bits2, m4);
- qx[2] = _mm512_and_si512(bits3, m4);
- qx[3] = _mm512_and_si512(bits4, m4);
- qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4);
- qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4);
- qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4);
- qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4);
+ auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
- auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
- auto y8l = MM256_SET_M128I(y4l, y4l);
- auto y8h = MM256_SET_M128I(y4h, y4h);
- auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
- auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4l[ib], iq4h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs);
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
@@ -2981,12 +3041,56 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
#endif
#ifdef HAVE_FANCY_SIMD
+inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
+ auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ return sumi;
+}
+inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
+ auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+ auto yl = MM256_SET_M128I(y4l, y4l);
+ auto yh = MM256_SET_M128I(y4h, y4h);
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+ return sumi;
+}
+inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
+ qx[0] = _mm256_loadu_si256((const __m256i *)x+0);
+ qx[1] = _mm256_loadu_si256((const __m256i *)x+1);
+ qx[2] = _mm256_loadu_si256((const __m256i *)x+2);
+ qx[3] = _mm256_loadu_si256((const __m256i *)x+3);
+ qx[4] = _mm256_loadu_si256((const __m256i *)x+4);
+ qx[5] = _mm256_loadu_si256((const __m256i *)x+5);
+ qx[6] = _mm256_loadu_si256((const __m256i *)x+6);
+ qx[7] = _mm256_loadu_si256((const __m256i *)x+7);
+ return qx_r8_q8_dot_product(qx, y);
+}
template <int nrc_y>
static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%16 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
int nb = n / QK8_0;
- GGML_ASSERT(nb%4 == 0);
if constexpr (nrc_y == 1) {
__m256 acc[2] = {};
__m256i qx[8];
@@ -2997,32 +3101,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)));
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
- qx[0] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
- qx[1] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
- qx[2] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
- qx[3] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
- qx[4] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
- qx[5] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
- qx[6] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
- qx[7] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
- auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0);
- auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1);
- auto yl = MM256_SET_M128I(y4l, y4l);
- auto yh = MM256_SET_M128I(y4h, y4h);
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
- sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+ auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]);
}
}
+ if (4*(nb/4) < nb) {
+ auto qy = (const block_q8_1 *)q8.y[0];
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+ auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
+ acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]);
+ }
+ }
info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
acc[0] = acc[1] = _mm256_setzero_ps();
}
@@ -3046,27 +3140,29 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
}
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
- auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
- auto y8l = MM256_SET_M128I(y4l, y4l);
- auto y8h = MM256_SET_M128I(y4h, y4h);
- auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
- auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k);
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 8; ++j) {
+ qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
+ _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]);
info.store(ix, iy, sum512);
@@ -3082,9 +3178,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m1 = _mm256_set1_epi16(1);
int nb = n / QK8_0;
- GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[4*nrc_y];
+ __m256i qx[4], sx[4];
+ auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)qy);
+ auto y = MM256_SET_M128I(y128, y128);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
+ );
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
+ );
+ return _mm256_add_epi32(sumi1, sumi2);
+ };
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
@@ -3094,54 +3203,49 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
}
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
- auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
- auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
- auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
- auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
- auto s0 = _mm256_sign_epi8(q0, q0);
- auto s1 = _mm256_sign_epi8(q1, q1);
- auto s2 = _mm256_sign_epi8(q2, q2);
- auto s3 = _mm256_sign_epi8(q3, q3);
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
- auto y = MM256_SET_M128I(y128, y128);
- auto sumi1 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
- );
- auto sumi2 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
- );
- auto sumi = _mm256_add_epi32(sumi1, sumi2);
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
- q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
- q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
- q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
- q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
- s0 = _mm256_sign_epi8(q0, q0);
- s1 = _mm256_sign_epi8(q1, q1);
- s2 = _mm256_sign_epi8(q2, q2);
- s3 = _mm256_sign_epi8(q3, q3);
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
- auto y = MM256_SET_M128I(y128, y128);
- auto sumi1 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
- );
- auto sumi2 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
- );
- auto sumi = _mm256_add_epi32(sumi1, sumi2);
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs+16);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
@@ -7080,6 +7184,7 @@ struct QFBase {
static inline Acc acc_first(const Data& y, const Data& x) {
return _mm512_mul_ps(y, x);
}
+ static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
static inline float hsum(Acc acc) {
return _mm512_reduce_add_ps(acc);
}
@@ -7118,6 +7223,7 @@ struct QFBase {
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return _mm256_fmadd_ps(y, x, prev);
}
+ static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
@@ -7190,6 +7296,44 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
const Float * y[nrc];
};
+// TBD if we want this
+//template <typename Qy, typename Qx>
+//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+// static_assert(Qy::nrc == 1);
+// int nb = n/QFBase::k_step;
+// int nb4 = n/4;
+// Qy y(info);
+// Qx x(cx + ix0*bx, bx);
+// QFBase::Data xv[2*Qx::nrc];
+// QFBase::Acc acc[2*Qx::nrc];
+// auto yv1 = y.load1(0, 0);
+// auto yv2 = y.load1(0, 1);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[2*ix+0] = x.load1(ix, 0);
+// xv[2*ix+1] = x.load1(ix, 1);
+// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
+// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
+// }
+// for (int i = 1; i < nb/2; ++i) {
+// yv1 = y.load1(0, 2*i+0);
+// yv2 = y.load1(0, 2*i+1);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[2*ix+0] = x.load1(ix, 2*i+0);
+// xv[2*ix+1] = x.load1(ix, 2*i+1);
+// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
+// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
+// }
+// }
+// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
+// yv1 = y.load_tail(0, i);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[ix] = x.load_tail(ix, i);
+// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
+// }
+// }
+// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
+//}
+
template <typename Qy, typename Qx>
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
int nb = n/QFBase::k_step;
@@ -7287,12 +7431,29 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
template <int nrc_y, typename FloatX, typename FloatY>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const char * cx = (const char *)vx;
+ // TBD if we want this
+ //if constexpr (nrc_y == 1) {
+ // constexpr int k_nx = 2;
+ // for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ // mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
+ // }
+ // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
+ // int nx = nrc_x - lastx;
+ // switch (nx) {
+ // case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
+ // case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
+ // case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
+ // }
+ // //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
+ // }
+ // return;
+ //}
#ifdef __AVX512F__
constexpr int k_nx = 5;
#else
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
#endif
- const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
}
@@ -12146,7 +12307,6 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
int8x16_t qx[16];
float d8[4*nrc_y];
float32x4_t acc[2*nrc_y] = {};
@@ -12168,6 +12328,18 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = deq.prepare(ib, 0, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ auto y = vld1q_s8_x2(qy[ib].qs);
+ auto sumi1 = interleaved_dotq(qx+0, y);
+ auto sumi2 = interleaved_dotq(qx+8, y);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, deq.result(acc[2*iy+0]));
info.store(ix+4, iy, deq.result(acc[2*iy+1]));
@@ -12312,12 +12484,32 @@ struct Q6_0_R4_Dequantizer {
const int8x16_t m32 = vdupq_n_s8(-32);
};
+inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
+ auto y = vld1q_s8_x2(qy);
+ sumi1 = sumi2 = vdupq_n_s32(0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+}
+
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
- GGML_ASSERT(nb%4 == 0);
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[4*nrc_y];
@@ -12332,32 +12524,29 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+ int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi1 = vdupq_n_s32(0);
- auto sumi2 = vdupq_n_s32(0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
- sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
- sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
- sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
- sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
- sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
- sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+ qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+ int32x4_t sumi1, sumi2;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
@@ -13033,10 +13222,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
- m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
- m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
- m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
- m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
_mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
@@ -13055,10 +13244,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
- m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
- m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
- m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
- m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
_mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
@@ -13895,16 +14084,11 @@ struct FlashQKfp32 {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
- if constexpr (D >= 128) {
#ifdef HAVE_FANCY_SIMD
- MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
#else
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
- } else {
- // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
- }
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {