From bf83bff6742c0f1795b4c18695a13a34ac7adf62 Mon Sep 17 00:00:00 2001
From: Shouzheng Liu
Date: Wed, 16 Aug 2023 16:07:04 -0400
Subject: metal : matrix-matrix multiplication kernel (#2615)

* metal: matrix-matrix multiplication kernel

This commit removes MPS and uses custom matrix-matrix multiplication
kernels for all quantization types. This commit also adds grouped-query
attention to support llama2 70B.

* metal: fix performance degradation from gqa

Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.

* metal: fix bugs for GQA and perplexity test.

I mixed up ne02 and nb02 in the previous commit.
---
 llama.cpp | 18 +-----------------
 1 file changed, 1 insertion(+), 17 deletions(-)

(limited to 'llama.cpp')

diff --git a/llama.cpp b/llama.cpp
index c8ab313d..a161f156 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1845,7 +1845,7 @@ static bool llama_eval_internal(
 #endif
 
 #ifdef GGML_USE_METAL
-    if (lctx.ctx_metal && N == 1) {
+    if (lctx.ctx_metal) {
         // TODO: disabled until #2413 is resolved
         //if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
         //    ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
@@ -1857,22 +1857,6 @@
             ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
         }
     } else {
-        // IMPORTANT:
-        // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
-        // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
-        // coprocessor.
-        //
-        // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
-        // But for now, we have focused only on Matrix x Vector Metal multiplication.
-        //
-        // TODO: avoid these syncs via shared memory (ref #1696)
-        //
-        if (lctx.ctx_metal) {
-            // We need to sync the GPU KV cache with the CPU KV cache
-            ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
-            ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
-        }
-
         ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
     }
 #else
--
cgit v1.2.3
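
For readers who want a concrete picture of the fix described in the second
commit message, here is a minimal C++ sketch of the idea. The names used
here (gqa_src_offset, head, gqa_ratio, head_stride, base) are illustrative
and do not appear in the shipped Metal kernels: the GQA head mapping and the
within-matrix part of the offset are computed in 32-bit integers, and the
result is widened to 64 bits only when combined with the rest of the offset.

#include <cstdint>

// Sketch only, not the shipped kernel code: keep the divide and the
// within-matrix offset in 32-bit arithmetic (fast on the GPU), widen late.
static inline uint64_t gqa_src_offset(uint32_t head, uint32_t gqa_ratio,
                                      uint32_t head_stride, uint64_t base) {
    const uint32_t within = (head / gqa_ratio) * head_stride; // 32-bit divide
    return base + within; // 32-bit 'within' caps a single matrix at ~4 GB
}

Trading away matrices larger than ~4 GB for a 32-bit divide is the design
choice the commit message describes for recovering the ~8% of inference
performance lost to the 64-bit divide.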