summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_common.h
blob: dc3e369fdc36345c390ae6e77d2e24a71c3913e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp fenc=utf-8 :vi
//
//
// Copyright (C) 2024 Iwan Kawrakow
// MIT license
// SPDX-License-Identifier: MIT
//

#include "iqk_config.h"

#if defined IQK_IMPLEMENT

#include <cstring>
#include <type_traits>
#include <vector>

#include "ggml-impl.h"
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
#include "iqk_quantize.h"

#define GGML_COMMON_IMPL_C
#include "ggml-common.h"

#define FA_TIMING 0

#include <utility>
#include <array>
#if FA_TIMING
#include <chrono>
#include <mutex>
// Lightweight wall-clock profiler (this struct sits inside `#if FA_TIMING`).
// Elapsed times are accumulated, in milliseconds, into five numbered slots;
// when `report` is true the destructor prints every non-zero slot together
// with its percentage of the total.
struct Perf {
    using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;

    std::array<double, 5> times = {};  // accumulated milliseconds per slot
    std::mutex mutex;                  // taken by accum() and add(); not by accum_nolock()
    bool report;                       // print a summary from the destructor

    // Reads the high-resolution clock.
    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }

    // Adds the time elapsed since `t1` to slot `what`. The elapsed time is
    // measured before the lock is taken, so time spent waiting on the mutex
    // is not counted.
    inline void accum(int what, const TimePoint& t1) {
        const double elapsed = delta(t1, cur_time());
        std::lock_guard<std::mutex> guard(mutex);
        times[what] += elapsed;
    }

    // Unlocked variant of accum(); the caller must rule out concurrent access.
    inline void accum_nolock(int what, const TimePoint& t1) {
        const double elapsed = delta(t1, cur_time());
        times[what] += elapsed;
    }

    // Folds the totals of `other` into this object while holding this
    // object's lock. `other.times` is read without taking `other.mutex`.
    inline void add(const Perf& other) {
        std::lock_guard<std::mutex> guard(mutex);
        for (size_t slot = 0; slot < times.size(); ++slot) {
            times[slot] += other.times[slot];
        }
    }

    Perf(bool r) : report(r) {}

    // Prints the timing summary when reporting is enabled and anything was recorded.
    ~Perf() {
        if (!report) return;
        double total = 0;
        for (double t : times) total += t;
        if (!total) return;
        printf("======================= Timing: %g ms in total\n", total);
        for (int slot = 0; slot < int(times.size()); ++slot) {
            const double t = times[slot];
            if (!t) continue;
            printf("%d:  %g ms -> %g%c\n", slot, t, 100*t/total, '%');
        }
    }

    // Process-wide instance with reporting enabled, constructed on first use.
    static Perf& instance() {
        static Perf shared_perf(true);
        return shared_perf;
    }

    // Returns t2 - t1 in milliseconds, computed from a nanosecond count.
    static double delta(const TimePoint& t1, const TimePoint& t2) {
        const auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count();
        return 1e-6*nanos;
    }
};
#endif

#ifdef __AVX2__
// MM256_SET_M128I(a, b): builds a 256-bit integer vector whose low 128 bits
// are `b` and whose high 128 bits are `a`. It first casts `b` to 256 bits,
// then inserts `a` into lane 1. The (high, low) argument order matches
// _mm256_set_m128i.
// NOTE(review): presumably spelled out with cast + insert for compilers that
// lack _mm256_set_m128i. The two intrinsics used are plain AVX even though the
// guard is __AVX2__; confirm before widening the guard.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif

namespace {

// One entry of an indirect row mapping. DataInfo reads `i1` and `i2` to
// locate both the src1 row and the destination row of a mapped result.
// NOTE(review): presumably mirrors ggml's mul_mat_id row mapping
// (i1 = expert slot id, i2 = token index); confirm against the code that fills it.
struct mmid_row_mapping {
    int32_t i1;
    int32_t i2;
};

// Addressing information handed to the mul_mat kernels: where the src1 rows
// live and where each result is written.
//
// Without `row_mapping`, row `cur_y + iy` is found by a fixed stride from the
// base pointers. With `row_mapping`, entry `cur_y + iy` of the mapping
// determines both the src1 row and the destination location.
//
// NOTE(review): keep member order and default initializers unchanged. Callers
// presumably build this struct with aggregate initialization.
struct DataInfo {
    float       * s;                                 // destination base (float-addressed)
    const char  * cy;                                // src1 base (byte-addressed)
    size_t        bs;                                // destination row stride, in floats
    size_t        by;                                // src1 row stride, in bytes
    int           cur_y = 0;                         // offset added to every tile-relative row index
    int           ne11;                              // mapped src1_row(): modulus for i1, multiplier for i2
    const mmid_row_mapping * row_mapping = nullptr;  // nullptr => plain strided rows
    size_t        bs2 = 0;                           // mapped dst_row(): stride (floats) per unit of i2

    // Destination pointer for tile row `iy`.
    inline float * dst_row(int iy) const {
        const int row = cur_y + iy;
        if (row_mapping == nullptr) {
            return s + row*bs;
        }
        const mmid_row_mapping& entry = row_mapping[row];
        return s + entry.i1*bs + entry.i2*bs2;
    }

    // src1 pointer for tile row `iy`.
    inline const char * src1_row(int iy) const {
        const int row = cur_y + iy;
        if (row_mapping == nullptr) {
            return cy + row*by;
        }
        const mmid_row_mapping& entry = row_mapping[row];
        return cy + (entry.i1 % ne11 + entry.i2*ne11)*by;
    }

    // Stores one float at column `ix` of destination row `iy`.
    inline void store(int ix, int iy, float result) const {
        dst_row(iy)[ix] = result;
    }
#ifdef __AVX__
    // Stores 4 consecutive floats (unaligned) starting at column `ix` of row `iy`.
    inline void store(int ix, int iy, __m128 result) const {
        float * out = dst_row(iy) + ix;
        _mm_storeu_ps(out, result);
    }
    // Stores 8 consecutive floats (unaligned) starting at column `ix` of row `iy`.
    inline void store(int ix, int iy, __m256 result) const {
        float * out = dst_row(iy) + ix;
        _mm256_storeu_ps(out, result);
    }
#endif
#ifdef __AVX512F__
    // Stores 16 consecutive floats (unaligned) starting at column `ix` of row `iy`.
    inline void store(int ix, int iy, __m512 result) const {
        float * out = dst_row(iy) + ix;
        _mm512_storeu_ps(out, result);
    }
#endif
#ifdef __ARM_NEON
    // Stores 4 consecutive floats starting at column `ix` of row `iy`.
    inline void store(int ix, int iy, float32x4_t result) const {
        float * out = dst_row(iy) + ix;
        vst1q_f32(out, result);
    }
#endif
};

// Function-pointer type shared by the matrix-multiplication kernels.
// NOTE(review): parameter meanings are set by the kernels. Presumably n = row
// length, vx/bx = left operand and its row stride in bytes, and
// nrc_x = number of rows to produce; confirm.
using mul_mat_t = void (*)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);

#endif