1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
|
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape M x K.
// B is a column-major matrix with shape K x N.
// C is a column-major matrix with shape M x N.
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
// Note that J is measured in physical 32 bit elements instead of logical elements.
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
#include "common.cuh"
#if CUDART_VERSION >= 11080
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
#ifdef INT8_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(INT8_MMA_AVAILABLE)
return ret;
}
#else
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
// Imagine transposing row-major matrix to column-major matrix.
const int src_i_low = 2 * (threadIdx.x % 4);
const int src_i_high = src_i_low + 1;
const int src_j = threadIdx.x / 4;
const int src_laneid_low = src_i_low * 4 + src_j / 2;
const int src_laneid_high = src_i_high * 4 + src_j / 2;
const int shift_low = ((src_j + 0) % 2) * 16;
const int shift_high = ((src_j + 1) % 2) * 16;
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
return ret_low | ret_high;
}
#endif // CUDART_VERSION >= 11080
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
half2 ret;
*((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
return ret;
}
namespace ggml_cuda_mma {
template <int I_, int J_, typename T>
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
T x[ne] = {0};
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 16) {
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
} else if constexpr (I == 16 && J == 16) {
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I_, int J_>
struct tile<I_, J_, half2> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
#pragma unroll
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
}
return ret;
}
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
tile<8, 8, half2> ret;
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
return ret;
}
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef INT8_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // INT8_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef INT8_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#endif // INT8_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef INT8_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // INT8_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef INT8_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(t);
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[1]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[2]), "r"(B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef INT8_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef INT8_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef INT8_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef INT8_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
}
|