diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-06-04 23:34:30 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-04 23:34:30 +0300 |
commit | ecb217db4fcfa3880300ad08531a5fb6bb142d45 (patch) | |
tree | e7a1a1fee49036f2ee46b419fb032966b8e62222 /llama.cpp | |
parent | dcb2ed48268e421baf25adc00d602dad0f415564 (diff) |
llama : Metal inference (#1642)
* mtl : export the LLaMA computation graph
* ci : disable temporary
* mtl : adapt the MNIST example as starter
* mtl : no need for mtl-export tool, add cli arg for main instead
* mtl : export just a small part of the graph for now to make it easier
* mtl : move MSL code into separate file for easy editing
* mtl : initial get_rows_q4_0 kernel
* mtl : confirmed get_rows_q4_0 is working correctly
* mtl : add rms_norm kernel + confirm working
* mtl : add mul kernel + confirm working
* mtl : initial mul_mat Q4 kernel (wrong results)
* mtl : mul_mat fixes (still wrong)
* mtl : another mul_mat Q4 (still does not work)
* mtl : working mul_mat q4
* ggml : fix handling of "view" ops in ggml_graph_import()
* mtl : add rope kernel
* mtl : add reshape and transpose handling
* ggml : store offset as opt arg for ggml_view_xd() operators
* mtl : add cpy kernel + handle view ops
* mtl : confirm f16 x f32 attention mul mat
* mtl : add scale kernel
* mtl : add diag_mask_inf kernel
* mtl : fix soft_max kernel
* ggml : update ggml_nbytes() to handle non-contiguous tensors
* mtl : verify V tensor contents
* mtl : add f32 -> f32 cpy kernel
* mtl : add silu kernel
* mtl : add non-broadcast mul kernel
* mtl : full GPU inference of the computation graph
* mtl : optimize rms_norm and soft_max kernels
* mtl : add f16 mat x f32 vec multiplication kernel
* mtl : fix bug in f16 x f32 mul mat + speed-up computation
* mtl : faster mul_mat_q4_0_f32 kernel
* mtl : fix kernel signature + roll inner loop
* mtl : more threads for rms_norm + better timing
* mtl : remove printfs from inner loop
* mtl : simplify implementation
* mtl : add save/load vocab to ggml file
* mtl : plug Metal inference into llama.cpp (very quick-n-dirty)
* mtl : make it work with main example
Lots of hacks but at least now it generates text
* mtl : preparing for merge
* mtl : clean-up ggml mtl interface + support scratch / inplace
* mtl : remove temp / debug code
* metal : final refactoring and simplification
* Revert "ci : disable temporary"
This reverts commit 98c267fc77fe811082f672538fc91bcfc9072d63.
* metal : add comments
* metal : clean-up stuff, fix typos
* readme : add Metal instructions
* readme : add example for main
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 132 |
1 files changed, 104 insertions, 28 deletions
@@ -16,6 +16,10 @@ #include "ggml-opencl.h" #endif +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + #include <array> #include <ctime> #include <cinttypes> @@ -243,6 +247,10 @@ struct llama_context { llama_ctx_buffer buf_compute; llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; +#ifdef GGML_USE_METAL + ggml_metal_context * ctx_metal = NULL; +#endif + int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -1088,7 +1096,7 @@ static void llama_model_load_internal( mmapped_size - vram_total + // weights in VRAM not in memory MEM_REQ_SCRATCH0().at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL().at(model.type); + MEM_REQ_EVAL().at (model.type); // this is the memory required by one llama_state const size_t mem_required_state = @@ -1195,17 +1203,19 @@ static bool llama_model_load( // evaluate the transformer // -// - lctx: llama context -// - tokens: new batch of tokens to process -// - n_past: the context size so far -// - n_threads: number of threads to use +// - lctx: llama context +// - tokens: new batch of tokens to process +// - n_past: the context size so far +// - n_threads: number of threads to use +// - cgraph_fname: filename of the exported computation graph // static bool llama_eval_internal( - llama_context & lctx, - const llama_token * tokens, - const int n_tokens, - const int n_past, - const int n_threads) { + llama_context & lctx, + const llama_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads, + const char * cgraph_fname) { // enforce that the first token is BOS if (n_past == 0 && tokens[0] != llama_token_bos()) { @@ -1251,13 +1261,18 @@ static bool llama_eval_internal( ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); +#ifdef GGML_USE_METAL + if (lctx.ctx_metal && N == 1) { + ggml_metal_set_tensor(lctx.ctx_metal, embd); + } +#endif + + struct ggml_tensor * cur; struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, 
embd); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; - struct ggml_tensor * cur; - lctx.use_buf(ctx0, 0); // norm @@ -1271,6 +1286,7 @@ static bool llama_eval_internal( // self-attention { // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); ggml_set_name(Qcur, "Qcur"); @@ -1280,6 +1296,7 @@ static bool llama_eval_internal( { // compute the transposed [N, n_embd] V matrix struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N)); + ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, @@ -1325,7 +1342,6 @@ static bool llama_eval_internal( struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); ggml_set_name(KQ_soft_max, "KQ_soft_max"); - // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, @@ -1407,26 +1423,53 @@ static bool llama_eval_internal( // norm { + cur = ggml_rms_norm(ctx0, inpL); - inpL = ggml_rms_norm(ctx0, inpL); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.norm); - // inpL = inpL*norm(broadcasted) - inpL = ggml_mul(ctx0, inpL, model.norm); - - embeddings = inpL; + embeddings = cur; } // lm_head - inpL = ggml_mul_mat(ctx0, model.output, inpL); + cur = ggml_mul_mat(ctx0, model.output, cur); lctx.use_buf(ctx0, -1); // logits -> probs - //inpL = ggml_soft_max_inplace(ctx0, inpL); + //cur = ggml_soft_max_inplace(ctx0, cur); // run the computation - ggml_build_forward_expand(&gf, inpL); - ggml_graph_compute 
(ctx0, &gf); + ggml_build_forward_expand(&gf, cur); + +#ifdef GGML_USE_METAL + if (lctx.ctx_metal && N == 1) { + ggml_metal_graph_compute(lctx.ctx_metal, &gf); + ggml_metal_get_tensor (lctx.ctx_metal, cur); + } else { + // IMPORTANT: + // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla + // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX + // coprocessor. + // + // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch. + // But for now, we have focused only on Matrix x Vector Metal multiplication. + // + ggml_graph_compute(ctx0, &gf); + + if (lctx.ctx_metal) { + // We need to sync the CPU KV cache with the GPU KV cache + ggml_metal_set_tensor(lctx.ctx_metal, kv_self.k); + ggml_metal_set_tensor(lctx.ctx_metal, kv_self.v); + } + } +#else + ggml_graph_compute(ctx0, &gf); +#endif + + if (cgraph_fname) { + ggml_graph_export(&gf, cgraph_fname); + } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -1440,7 +1483,7 @@ static bool llama_eval_internal( //} //embd_w.resize(n_vocab*N); - //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N); // update kv token count lctx.model.kv_self.n = n_past + N; @@ -1451,11 +1494,11 @@ static bool llama_eval_internal( if (lctx.logits_all) { logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); + memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); } else { // return result for just the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); } } @@ -2251,8 +2294,8 @@ struct llama_context * llama_init_from_file( 
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type, - params.use_mmap, params.use_mlock, params.vocab_only, - params.progress_callback, params.progress_callback_user_data)) { + params.use_mmap, params.use_mlock, params.vocab_only, + params.progress_callback, params.progress_callback_user_data)) { fprintf(stderr, "%s: failed to load model\n", __func__); llama_free(ctx); return nullptr; @@ -2290,6 +2333,25 @@ struct llama_context * llama_init_from_file( ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); } +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + // this allocates all Metal resources and memory buffers + ctx->ctx_metal = ggml_metal_init(); + + if (params.use_mmap) { + ggml_metal_add_buffer(ctx->ctx_metal, "data", ctx->model.mapping->addr, ctx->model.mapping->size); + ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size); + } else { + ggml_metal_add_buffer(ctx->ctx_metal, "data", ggml_get_mem_buffer(ctx->model.ctx), ggml_get_mem_size(ctx->model.ctx)); + ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size); + } + + ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size); + ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size); + ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size); + } +#endif + return ctx; } @@ -2905,7 +2967,7 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { - if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) { + if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -2920,6 +2982,20 @@ int llama_eval( return 0; } +int llama_eval_export(struct llama_context * ctx, const char * fname) { 
+ const int n_batch = 1; + const int n_ctx = 512 - n_batch; + + const std::vector<llama_token> tmp(n_batch, llama_token_bos()); + + if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + int llama_tokenize( struct llama_context * ctx, const char * text, |