diff options
author | Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> | 2024-04-09 09:16:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-09 11:16:13 +0300 |
commit | 5dc9dd7152dedc6046b646855585bd070c91e8c8 (patch) | |
tree | d2bae3652d91cdd9327e28fa85d167a67e050c53 /ggml-cuda/dmmv.cu | |
parent | e11a8999b5690f810c2c99c14347f0834e68c524 (diff) |
llama : add Command R Plus support (#6491)
* Add Command R Plus GGUF
* Add Command R Plus GGUF
* Loading works up to LayerNorm2D
* Export new tensors in 1D so they are not quantized.
* Fix embedding layer based on Noeda's example
* Whitespace
* Add line
* Fix unexpected tokens on MPS. Re-add F16 fix. ((Noeda)
* dranger003: Fix block index overflow in CUDA dequantizing.
* Reverted blocked multiplication code as it still has issues and could affect other Llama arches
* export norms as f32
* fix overflow issues during quant and other cleanup
* Type convention
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* dranger003: Fix more int overflow during quant.
---------
Co-authored-by: S <seast@Ss-Mac-Studio.local>
Co-authored-by: S <s@example.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml-cuda/dmmv.cu')
-rw-r--r-- | ggml-cuda/dmmv.cu | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 0b17e3cb..7313e3e1 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, } } -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const half * x = (const half *) vx; // automatic half -> float type cast if dfloat == float @@ -577,7 +577,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block - const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; @@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; - const int ib = (row*ncols + col)/qk; // x block index + const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index const int iqs = (col%qk)/qr; // x quant index const int iybs = col - col%qk; // y block start index |