// // Copyright (C) 2024 Iwan Kawrakow // MIT license // SPDX-License-Identifier: MIT // #pragma once #include #include #include "iqk_config.h" #ifdef __cplusplus extern "C" { #endif IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth); IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, long ne02, long ne03, long ne12, long ne13, long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth); IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); typedef void (*barrier_t) (void *); IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, int neq3, int neq2, long nbq3, long nbq2, int nek3, int nek2, long nbk3, long nbk2, int nev3, int nev2, long nbv3, long nbv2, int ne2, int ne1, long nb1, int type_k, // type of k int type_v, // type of v int Dk, // K head size int Dv, // V head size int nq, // number of columns in q int nk, // number of rows in k int stride_q, // distance between q columns in bytes int stride_k, // distance between k rows in bytes int stride_v, // distance between v rows in bytes int stride_m, // distance between mask rows (in bytes const void * q, // q matrix. const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) void * work_buffer, barrier_t barrier, void * barrier_data, int ith, int nth); #ifdef __cplusplus } #endif