| author | Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> | 2024-04-09 09:16:13 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-09 11:16:13 +0300 |
| commit | 5dc9dd7152dedc6046b646855585bd070c91e8c8 (patch) | |
| tree | d2bae3652d91cdd9327e28fa85d167a67e050c53 /ggml-cuda/dequantize.cuh | |
| parent | e11a8999b5690f810c2c99c14347f0834e68c524 (diff) | |
llama : add Command R Plus support (#6491)
* Add Command R Plus GGUF
* Add Command R Plus GGUF
* Loading works up to LayerNorm2D
* Export new tensors in 1D so they are not quantized.
* Fix embedding layer based on Noeda's example
* Whitespace
* Add line
* Fix unexpected tokens on MPS. Re-add F16 fix. (Noeda)
* dranger003: Fix block index overflow in CUDA dequantizing.
* Reverted blocked multiplication code as it still has issues and could affect other Llama arches
* export norms as f32
* fix overflow issues during quant and other cleanup
* Type convention
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* dranger003: Fix more int overflow during quant.
---------
Co-authored-by: S <seast@Ss-Mac-Studio.local>
Co-authored-by: S <s@example.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml-cuda/dequantize.cuh')
| -rw-r--r-- | ggml-cuda/dequantize.cuh | 10 |
1 file changed, 5 insertions, 5 deletions
```diff
diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh
index b5440063..bd3c2d9d 100644
--- a/ggml-cuda/dequantize.cuh
+++ b/ggml-cuda/dequantize.cuh
@@ -1,6 +1,6 @@
 #include "common.cuh"
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
     const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
     const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
     const dfloat d = x[ib].d;
```
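The pattern across all five hunks is the same: the block index parameter `ib` is widened from `int` to `int64_t` so it can address blocks of tensors whose element count exceeds the 32-bit range, which is what the "block index overflow" bullets in the commit message refer to. The sketch below is a hypothetical caller, not the actual llama.cpp kernel; it assumes `dfloat2`, `block_q8_0`, `QK8_0` (the 32-element q8_0 block size), and the patched `dequantize_q8_0` are all available via `common.cuh`, and only illustrates where the 64-bit index comes from.

```cuda
// Hypothetical illustration, not the actual llama.cpp kernel: shows why the
// block index passed to dequantize_q8_0() has to be 64-bit.
// Assumes dfloat2, block_q8_0, QK8_0 (= 32) and dequantize_q8_0 come in via
// common.cuh, and that k (the element count) is a multiple of QK8_0.
#include <cstdint>
#include "common.cuh"

static __global__ void dequantize_q8_0_to_f32_example(const void * vx, float * y, const int64_t k) {
    // Each thread dequantizes two consecutive values of a tensor with k elements.
    const int64_t i = 2 * ((int64_t) blockIdx.x * blockDim.x + threadIdx.x);
    if (i >= k) {
        return;
    }

    // For k > 2^31 (tensors this large are what the overflow fix is about),
    // neither i nor a block index derived from it fits in a 32-bit int.
    const int64_t ib  = i / QK8_0;          // block index -- the widened parameter
    const int     iqs = (int) (i % QK8_0);  // index inside one block, always < 32

    dfloat2 v;
    dequantize_q8_0(vx, ib, iqs, v);        // fills v.x, v.y with two dequantized values

    y[i + 0] = v.x;
    y[i + 1] = v.y;
}
```

With a 32-bit `ib`, the division above would silently truncate once `i` passes INT32_MAX, so the kernel would read the wrong block; keeping the whole index chain in `int64_t` is the point of the patch.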