Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp | 251 ++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 251 insertions(+), 0 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index cdd8bd16..db71c03b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1,6 +1,8 @@
 // Defines fileno on msys:
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
 #endif
+#include <cstdint>
+#include <cstdio>
 #include "llama_util.h"
@@ -633,6 +635,7 @@ struct llama_model_loader {
             throw format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s",
                          name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
+
         return get_tensor_for(lt);
     }
@@ -1774,6 +1777,254 @@ int llama_model_quantize(
     }
 }
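+// apply a LoRA adapter to a loaded model: reads the r/alpha hyperparameters and
+// the per-tensor loraA/loraB matrices from path_lora, and adds the scaled
+// low-rank update to each matching model tensor; returns 0 on success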
+int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+    fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
+
+    auto & model = ctx->model;
+
+    const int64_t t_start_lora_us = ggml_time_us();
+
+    auto fin = std::ifstream(path_lora, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora);
+        return 1;
+    }
+
+    // verify magic and version
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
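+        // 'ggla' is the file magic of a ggml LoRA adapter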
+        if (magic != 'ggla') {
+            fprintf(stderr, "%s: bad file magic\n", __func__);
+            return 1;
+        }
+        uint32_t format_version;
+        fin.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != 1) {
+            fprintf(stderr, "%s: unsupported file version\n", __func__);
+            return 1;
+        }
+    }
+
+    int32_t lora_r;
+    int32_t lora_alpha;
+    fin.read((char *) &lora_r, sizeof(lora_r));
+    fin.read((char *) &lora_alpha, sizeof(lora_alpha));
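+    // standard LoRA scaling: the effective update is W' = W + (alpha/r) * B*A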
+    float scaling = (float)lora_alpha / (float)lora_r;
+
+    fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+
+    // create a temporary ggml context to store the lora tensors
+    // TODO: calculate size from biggest possible tensor
+    std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
+    struct ggml_init_params params;
+    params.mem_size   = lora_buf.size();
+    params.mem_buffer = lora_buf.data();
+    params.no_alloc   = false;
+
+    ggml_context * lora_ctx = ggml_init(params);
+    std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
+
+    // create a name -> tensor map of the model to accelerate lookups
+    std::unordered_map<std::string, struct ggml_tensor *> model_tensors;
+    for (auto & kv : model.tensors_by_name) {
+        model_tensors.insert(kv);
+    }
+
+    // load base model
+    std::unique_ptr<llama_model_loader> model_loader;
+    ggml_context * base_ctx = NULL;
+    llama_buffer base_buf;
+    if (path_base_model) {
+        fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
+        model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));
+
+        size_t ctx_size, mmapped_size;
+        model_loader->calc_sizes(&ctx_size, &mmapped_size);
+        base_buf.resize(ctx_size);
+
+        ggml_init_params base_params;
+        base_params.mem_size   = base_buf.size;
+        base_params.mem_buffer = base_buf.addr;
+        base_params.no_alloc   = model_loader->use_mmap;
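+        // with mmap the tensor data stays in the mapped file, so ggml must not
+        // allocate buffers for it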
+
+        base_ctx = ggml_init(base_params);
+
+        model_loader->ggml_ctx = base_ctx;
+
+        // maybe this should be in llama_model_loader
+        if (model_loader->use_mmap) {
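+            // prefetching is skipped: only the tensors referenced by the adapter are read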
+            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false));
+        }
+    }
+
+    // read tensors and apply
+    bool warned = false;
+    int n_tensors = 0;
+    while (true) {
+        int32_t n_dims;
+        int32_t length;
+        int32_t ftype;
+
+        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+        fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+        fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
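+        // eof is only set once a read past the end fails, so test it here to end the loop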
+        if (fin.eof()) {
+            break;
+        }
+
+        int32_t ne[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+        }
+
+        std::string name(length, 0);
+        fin.read(&name[0], length);
+
+        // check for lora suffix and get the type of tensor
+        const std::string lora_suffix = ".lora";
+        size_t pos = name.rfind(lora_suffix);
+        if (pos == std::string::npos) {
+            fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+            return 1;
+        }
+
+        std::string lora_type = name.substr(pos + lora_suffix.length());
+        std::string base_name = name;
+        base_name.erase(pos);
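+        // e.g. "layers.0.attention.wq.weight.loraA" gives base_name
+        // "layers.0.attention.wq.weight" and lora_type "A"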
+        // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(), base_name.c_str(), lora_type.c_str());
+
+        if (model_tensors.find(base_name) == model_tensors.end()) {
+            fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.c_str());
+            return 1;
+        }
+
+        // create ggml tensor
+        ggml_type wtype;
+        switch (ftype) {
+            case 0: wtype = GGML_TYPE_F32; break;
+            case 1: wtype = GGML_TYPE_F16; break;
+            default:
+                {
+                    fprintf(stderr, "%s: invalid tensor data type '%d'\n",
+                            __func__, ftype);
+                    return 1;
+                }
+        }
+        ggml_tensor * lora_tensor;
+        if (n_dims == 2) {
+            lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
+        } else {
+            fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
+            return 1;
+        }
+
+        // load tensor data
+        size_t offset = fin.tellg();
+        size_t tensor_data_size = ggml_nbytes(lora_tensor);
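+        // tensor data is written at 32-byte boundaries; round the offset up before reading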
+        offset = (offset + 31) & -32;
+        fin.seekg(offset);
+        fin.read((char *) lora_tensor->data, tensor_data_size);
+
+        lora_tensors[name] = lora_tensor;
+
+        // check if we have both A and B tensors and apply
+        if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
+            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
+
+            ggml_tensor * dest_t = model_tensors[base_name];
+            ggml_tensor * base_t;
+            if (model_loader) {
+                // load from base model
+                if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) {
+                    fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
+                    return 1;
+                }
+                size_t idx = model_loader->tensors_map.name_to_idx[base_name];
+                llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
+                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+                lt.data = (uint8_t *) lt.ggml_tensor->data;
+                model_loader->load_data_for(lt);
+                lt.ggml_tensor->data = lt.data;
+            } else {
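+                // no separate base model: accumulate the update directly into the
+                // (possibly quantized) model tensor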
+                base_t = dest_t;
+            }
+
+            if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1) {
+                if (!warned) {
+                    fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, "
+                                    "use a f16 or f32 base model with --lora-base\n", __func__);
+                    warned = true;
+                }
+            }
+
+            ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
+            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+
+            if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
+                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+                                " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
+                return 1;
+            }
+
+            // w = w + BA*s
+            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+
+            if (scaling != 1.0f) {
+                ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
+                BA = ggml_scale(lora_ctx, BA, scale_tensor);
+            }
+
+            ggml_tensor * r;
+            if (base_t == dest_t) {
+                r = ggml_add_inplace(lora_ctx, dest_t, BA);
+            } else {
+                r = ggml_add(lora_ctx, base_t, BA);
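+                // ggml_cpy writes the result back into the model tensor, converting
+                // to the destination type if it differs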
+                r = ggml_cpy(lora_ctx, r, dest_t);
+            }
+
+            struct ggml_cgraph gf = ggml_build_forward(r);
+            gf.n_threads = n_threads;
+            ggml_graph_compute(lora_ctx, &gf);
+
+            // we won't need these tensors again, reset the context to save memory
+            ggml_free(lora_ctx);
+            lora_ctx = ggml_init(params);
+            lora_tensors.clear();
+
+            n_tensors++;
+            if (n_tensors % 4 == 0) {
+                fprintf(stderr, ".");
+            }
+        }
+    }
+
+    // TODO: this should be in a destructor, it will leak on failure
+    ggml_free(lora_ctx);
+    if (base_ctx) {
+        ggml_free(base_ctx);
+    }
+
+    const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
+    fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0);
+
+    return 0;
+}
+
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+    try {
+        return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
+    } catch (const std::string & err) {
+        fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str());
+        return 1;
+    }
+}
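+
+// usage sketch (hypothetical paths): apply an adapter on top of a quantized
+// model, using the original f16 model as the base to preserve quality:
+//
+//     llama_apply_lora_from_file(ctx, "lora-adapter.bin", "models/7B/ggml-model-f16.bin", 4);
+//
+// pass NULL as path_base_model to apply the adapter to the model tensors in place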
+
 // Returns the KV cache that will contain the context for the
 // ongoing prediction with the model.
 const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {