diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-09-27 08:16:06 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-27 08:16:06 +0300 |
commit | 6dec4af4b6e65eb72e646a6f8b10d77c9d306281 (patch) | |
tree | b69a6dfdd024ccf6a4d7490666664cbac4bc65ce /ggml/include/ggml.h | |
parent | 546f3ef349a7082fbc349897c3c7246baed2a6c6 (diff) |
Adding ability to have meta data per tensor row (#61)
* POC: per row scale
This is a POC how to work around opinionated ggml to
have scales per row rather than per block.
Only implemened for Zen4 and only for iq2_tn.
* POC per row scale: iq2_tn on NEON
* POC per row scale: iq2_tn on Metal
* Per row scale Metal templates
* iq1_tn: shrink to 1.625 bpw (NEON and Metal)
* POC per row scale: CUDA
* POC per row scale: add CUDA TODOs
There are two places in ggml-cuda.cu left where it is assumed
that type_size * n_per_row / block_size is the way to compute
and handle row sizes. This does not affect simple usage,
but will lead to issues when tensors are split between GPUs.
* Per row scales - CUDA
The only place left where there are unnecessary assumptions being made
is in the Flash Attention code. As we are not using any quants that
use per row scales for quantized KV cache, it should be OK for now.
* Update IQ1_TN and IQ2_TN bpw shown to user
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/include/ggml.h')
-rw-r--r-- | ggml/include/ggml.h | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5b46a70d..6ac30b0f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -744,6 +744,7 @@ extern "C" { GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + // TODO: remove the following from the public API to avoid unnecessary assumptions about data layout GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row @@ -2517,6 +2518,7 @@ extern "C" { int64_t ncols; // number of columns to process simultaneously ggml_gemv_t gemv; ggml_gemm_t gemm; + int64_t row_meta_size; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); |