path: root/ggml-cuda.cu
author     Carolinabanana <140120812+Carolinabanana@users.noreply.github.com>  2024-04-09 09:16:13 +0100
committer  GitHub <noreply@github.com>  2024-04-09 11:16:13 +0300
commit     5dc9dd7152dedc6046b646855585bd070c91e8c8 (patch)
tree       d2bae3652d91cdd9327e28fa85d167a67e050c53 /ggml-cuda.cu
parent     e11a8999b5690f810c2c99c14347f0834e68c524 (diff)
llama : add Command R Plus support (#6491)
* Add Command R Plus GGUF
* Loading works up to LayerNorm2D
* Export new tensors in 1D so they are not quantized.
* Fix embedding layer based on Noeda's example
* Whitespace
* Add line
* Fix unexpected tokens on MPS. Re-add F16 fix. (Noeda)
* dranger003: Fix block index overflow in CUDA dequantizing.
* Reverted blocked multiplication code as it still has issues and could affect other Llama arches
* Export norms as f32
* Fix overflow issues during quant and other cleanup
* Type convention
* dranger003: Fix more int overflow during quant.

Co-authored-by: S <seast@Ss-Mac-Studio.local>
Co-authored-by: S <s@example.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
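The substance of the change shown below is widening 32-bit index and stride variables to int64_t so they cannot overflow on tensors the size of Command R Plus. A minimal host-side C++ sketch of the failure mode, using made-up dimensions rather than values taken from this commit:

#include <cstdint>
#include <cstdio>
#include <climits>

int main() {
    // Hypothetical (illustrative) shape of a very large weight matrix;
    // these are not the actual dimensions touched by this commit.
    const int64_t ne0 = 12288;   // rows
    const int64_t ne1 = 256000;  // columns

    const int64_t nelements = ne0 * ne1;  // 3,145,728,000 elements

    // A 32-bit signed index cannot represent this many elements, so any
    // block index or leading dimension (such as ldc) kept in `int` wraps.
    printf("elements    = %lld\n", (long long) nelements);
    printf("INT_MAX     = %d\n",   INT_MAX);
    printf("fits in int = %s\n",   nelements <= INT_MAX ? "yes" : "no");
    return 0;
}

Keeping the arithmetic in int64_t throughout, as the diff does for ldc, nb2, and nb3, avoids the wrap-around.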
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  6
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ce28cb55..bff8ad9d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
+    int64_t ldc = id == ctx.device ? ne0 : row_diff;

     const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];

-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
+    const int64_t nb2 = dst->nb[2];
+    const int64_t nb3 = dst->nb[3];

     GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
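In the second hunk, nb2 and nb3 are byte strides of the destination tensor, so for a large dst even a single batch offset such as i2*nb2 exceeds INT_MAX. A hedged sketch with hypothetical sizes (not taken from the commit):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical f32 destination tensor; illustrative sizes only.
    const int64_t ne0 = 12288;                     // elements per row
    const int64_t ne1 = 256000;                    // rows per matrix
    const int64_t nb1 = (int64_t) sizeof(float) * ne0;  // bytes per row
    const int64_t nb2 = nb1 * ne1;                 // bytes per matrix (dst->nb[2])

    // Offsets into the destination buffer use nb2 directly; with `int`
    // strides this value (about 12.6 GB) would already have overflowed.
    const int64_t i2 = 3;
    printf("nb2 = %lld bytes, offset for i2 = 3: %lld bytes\n",
           (long long) nb2, (long long) (i2 * nb2));
    return 0;
}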