diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 383 |
1 files changed, 370 insertions, 13 deletions
@@ -188,6 +188,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, + LLM_ARCH_BLOOM, LLM_ARCH_UNKNOWN, }; @@ -201,7 +202,8 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = { { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_PERSIMMON, "persimmon" }, - { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BLOOM, "bloom" }, }; enum llm_kv { @@ -304,6 +306,7 @@ struct LLM_KV { enum llm_tensor { LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_POS_EMBD, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, @@ -467,6 +470,21 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = }, }, { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { LLM_ARCH_UNKNOWN, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, @@ -1207,6 +1225,8 @@ struct llama_model { struct ggml_tensor * tok_embeddings; struct ggml_tensor * pos_embeddings; + struct ggml_tensor * tok_norm; + struct ggml_tensor * tok_norm_b; struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; @@ -2056,13 +2076,13 @@ static void llm_load_hparams( } } break; case LLM_ARCH_PERSIMMON: - { - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - switch (hparams.n_layer) { - case 36: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + switch (hparams.n_layer) { + case 36: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_REFACT: { GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); @@ -2071,6 +2091,19 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BLOOM: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + } break; + } + } break; case LLM_ARCH_MPT: { hparams.f_clamp_kqv = 0.0f; @@ -2676,6 +2709,88 @@ static void llm_load_tensors( layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); } } break; + case LLM_ARCH_BLOOM: + { + // TODO: CPU-only for now + + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); + model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + + ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3) + + ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2); + } + } + } break; case LLM_ARCH_MPT: { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); @@ -4996,6 +5111,248 @@ static struct ggml_cgraph * llm_build_persimmon( return gf; } +static struct ggml_cgraph * llm_build_bloom( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float norm_eps = hparams.f_norm_eps; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * token; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, token); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); + } + } + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // norm + { + inpL = ggml_norm(ctx0, token, norm_eps); + inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.tok_norm), model.tok_norm_b); + } + + ggml_set_name(inpL, "inpL"); + + for (int il = 0; il < n_layer; ++il) { + { + // Norm + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + } + + { + // Self Attention + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); + + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + + struct ggml_tensor * Qcur = tmpq; + struct ggml_tensor * Kcur = tmpk; + + // store key and value to memory + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), + 0, 2, 1, 3); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ kv_head, n_head, 8); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + ggml_set_name(cur, "KQV_merged_contiguous"); + } + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // FF + { + // Norm + { + cur = ggml_norm(ctx0, inpFF, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); + } + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + // Output Norm + { + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); + } + ggml_set_name(cur, "result_norm"); + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + static struct ggml_cgraph * llm_build_mpt( llama_context & lctx, const llama_batch & batch) { @@ -5025,9 +5382,6 @@ static struct ggml_cgraph * llm_build_mpt( const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", - // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); - auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -5348,6 +5702,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_refact(lctx, batch); } break; + case LLM_ARCH_BLOOM: + { + result = llm_build_bloom(lctx, batch); + } break; case LLM_ARCH_MPT: { result = llm_build_mpt(lctx, batch); @@ -7579,8 +7937,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(meta); // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos || - name.find("attn_qkv.weight") != std::string::npos) { + if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { ++n_attention_wv; } else if (name.find("ffn_down.weight") != std::string::npos) { |