author     Georgi Gerganov <ggerganov@gmail.com>   2024-02-12 09:16:06 +0200
committer  GitHub <noreply@github.com>             2024-02-12 09:16:06 +0200
commit     3b169441dfe8e420f88d1592708cc2a871daadb9 (patch)
tree       b554c9eac1b3b7dbf11e364b6a4a748605a6e949 /examples
parent     3bdc4cd0f595a6096cca4a64aa75ffa8a3503465 (diff)
sync : ggml (#5452)
* ggml-alloc : v3 (ggml/727)
* ggml-alloc v3
  ggml-ci
* fix ci
  ggml-ci
* whisper : check for backend buffer allocation failures
* whisper : avoid leaks when initialization fails
* cleanup
  ggml-ci
* style fixes
  ggml-ci
* sync : ggml
* update llama.cpp, clip.cpp, export-lora.cpp
* update finetune.cpp, train-text-from-scratch.cpp
  ggml-ci
* ggml-backend : reduce alignment to 32 to match gguf and fix mmap

---------

Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r--  examples/export-lora/export-lora.cpp                            19
-rw-r--r--  examples/finetune/finetune.cpp                                 145
-rw-r--r--  examples/llava/clip.cpp                                        152
-rw-r--r--  examples/train-text-from-scratch/train-text-from-scratch.cpp  112
4 files changed, 155 insertions, 273 deletions
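
The common thread across all four files is the migration from the ggml-alloc v2 API (ggml_allocr_*, which needed a measure pass, a manually sized byte buffer, and a second graph build) to the v3 graph allocator (ggml_gallocr_*), which plans and allocates a compute graph against a backend buffer type in a single pass. A minimal sketch of the new flow, roughly as export-lora.cpp now uses it, follows; the two input tensors are assumed to be allocated and filled elsewhere, and the mul_mat graph only stands in for the real LoRA graph.

#include <cstdint>
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Sketch of the ggml-alloc v3 flow: build the graph in a no_alloc context,
// let ggml_gallocr allocate it in one call, then run it with ggml_graph_compute.
// `a` and `b` are assumed to already hold data in host memory.
static void compute_product(struct ggml_tensor * a, struct ggml_tensor * b, int n_threads) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // the context stores tensor/graph metadata only
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_cgraph * gf  = ggml_new_graph(ctx);
    struct ggml_tensor * out = ggml_mul_mat(ctx, a, b);
    ggml_build_forward_expand(gf, out);

    // one allocator call replaces the old measure pass + manual buffer + rebuild
    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    ggml_gallocr_alloc_graph(alloc, gf);

    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
    std::vector<uint8_t> work(cplan.work_size);
    cplan.work_data = work.data();
    ggml_graph_compute(gf, &cplan);

    ggml_gallocr_free(alloc);
    ggml_free(ctx);
}
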
diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp
index 4cd5d99b..2f7be8a1 100644
--- a/examples/export-lora/export-lora.cpp
+++ b/examples/export-lora/export-lora.cpp
@@ -337,24 +337,14 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
params.mem_buffer = NULL;
params.no_alloc = true;
struct ggml_context * ctx = NULL;
- struct ggml_allocr * alloc = NULL;
- struct ggml_cgraph * gf = NULL;
+ struct ggml_gallocr * alloc = NULL;
+ struct ggml_cgraph * gf = NULL;
ctx = ggml_init(params);
- alloc = ggml_allocr_new_measure(tensor_alignment);
+ alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
- size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
- ggml_allocr_free(alloc);
- ggml_free(ctx);
-
- static std::vector<uint8_t> data_compute;
- data_compute.resize(alloc_size + tensor_alignment);
- ctx = ggml_init(params);
- alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
- gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
- ggml_allocr_alloc_graph(alloc, gf);
- ggml_allocr_free(alloc);
+ ggml_gallocr_alloc_graph(alloc, gf);
struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
static std::vector<uint8_t> data_work;
@@ -363,6 +353,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
ggml_graph_compute(gf, &cplan);
+ ggml_gallocr_free(alloc);
ggml_free(ctx);
return true;
}
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index b7e19c5f..b11c5602 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1,5 +1,6 @@
#include "ggml.h"
#include "ggml-alloc.h"
+#include "ggml-backend.h"
#include "llama.h"
#include "common.h"
#include "train.h"
@@ -13,8 +14,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-static const size_t tensor_alignment = 32;
-
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512;
@@ -128,7 +127,7 @@ struct my_llama_lora_layer {
struct my_llama_lora {
struct ggml_context * ctx = NULL;
- std::vector<uint8_t> data;
+ ggml_backend_buffer_t data;
my_llama_lora_hparams hparams;
@@ -372,63 +371,6 @@ static void set_param_lora(struct my_llama_lora * lora) {
}
}
-static void alloc_lora(struct ggml_allocr * alloc, struct my_llama_lora * lora) {
- ggml_allocr_alloc(alloc, lora->tok_embeddings_a);
- ggml_allocr_alloc(alloc, lora->tok_embeddings_b);
- ggml_allocr_alloc(alloc, lora->norm_a);
- ggml_allocr_alloc(alloc, lora->norm_b);
- ggml_allocr_alloc(alloc, lora->output_a);
- ggml_allocr_alloc(alloc, lora->output_b);
- for (uint32_t i = 0; i < lora->layers.size(); ++i) {
- auto & layer = lora->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm_a);
- ggml_allocr_alloc(alloc, layer.attention_norm_b);
- ggml_allocr_alloc(alloc, layer.wq_a);
- ggml_allocr_alloc(alloc, layer.wq_b);
- ggml_allocr_alloc(alloc, layer.wk_a);
- ggml_allocr_alloc(alloc, layer.wk_b);
- ggml_allocr_alloc(alloc, layer.wv_a);
- ggml_allocr_alloc(alloc, layer.wv_b);
- ggml_allocr_alloc(alloc, layer.wo_a);
- ggml_allocr_alloc(alloc, layer.wo_b);
- ggml_allocr_alloc(alloc, layer.ffn_norm_a);
- ggml_allocr_alloc(alloc, layer.ffn_norm_b);
- ggml_allocr_alloc(alloc, layer.w1_a);
- ggml_allocr_alloc(alloc, layer.w1_b);
- ggml_allocr_alloc(alloc, layer.w2_a);
- ggml_allocr_alloc(alloc, layer.w2_b);
- ggml_allocr_alloc(alloc, layer.w3_a);
- ggml_allocr_alloc(alloc, layer.w3_b);
- }
- ggml_allocr_alloc(alloc, lora->tok_embeddings_a->grad);
- ggml_allocr_alloc(alloc, lora->tok_embeddings_b->grad);
- ggml_allocr_alloc(alloc, lora->norm_a->grad);
- ggml_allocr_alloc(alloc, lora->norm_b->grad);
- ggml_allocr_alloc(alloc, lora->output_a->grad);
- ggml_allocr_alloc(alloc, lora->output_b->grad);
- for (uint32_t i = 0; i < lora->layers.size(); ++i) {
- auto & layer = lora->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm_a->grad);
- ggml_allocr_alloc(alloc, layer.attention_norm_b->grad);
- ggml_allocr_alloc(alloc, layer.wq_a->grad);
- ggml_allocr_alloc(alloc, layer.wq_b->grad);
- ggml_allocr_alloc(alloc, layer.wk_a->grad);
- ggml_allocr_alloc(alloc, layer.wk_b->grad);
- ggml_allocr_alloc(alloc, layer.wv_a->grad);
- ggml_allocr_alloc(alloc, layer.wv_b->grad);
- ggml_allocr_alloc(alloc, layer.wo_a->grad);
- ggml_allocr_alloc(alloc, layer.wo_b->grad);
- ggml_allocr_alloc(alloc, layer.ffn_norm_a->grad);
- ggml_allocr_alloc(alloc, layer.ffn_norm_b->grad);
- ggml_allocr_alloc(alloc, layer.w1_a->grad);
- ggml_allocr_alloc(alloc, layer.w1_b->grad);
- ggml_allocr_alloc(alloc, layer.w2_a->grad);
- ggml_allocr_alloc(alloc, layer.w2_b->grad);
- ggml_allocr_alloc(alloc, layer.w3_a->grad);
- ggml_allocr_alloc(alloc, layer.w3_b->grad);
- }
-}
-
static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
const auto & lparams = lora->hparams;
@@ -522,18 +464,8 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
set_param_lora(lora);
- // measure data size
- size_t size = 0;
- for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
- }
-
- // allocate data
- struct ggml_allocr * alloc = NULL;
- lora->data.resize(size + tensor_alignment);
- alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
- alloc_lora(alloc, lora);
- ggml_allocr_free(alloc);
+ // allocate data for lora tensors
+ lora->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
}
static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
@@ -579,7 +511,7 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
static struct ggml_tensor * llama_build_lora_finetune_graphs(
struct my_llama_model * model,
struct my_llama_lora * lora,
- struct ggml_allocr * alloc,
+ ggml_gallocr_t alloc,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
@@ -590,7 +522,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_tokens,
const int n_batch,
const bool enable_flash_attn,
- const bool enable_checkpointing) {
+ const bool enable_checkpointing,
+ const bool measure_only) {
ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0;
@@ -622,13 +555,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
- ggml_allocr_alloc(alloc, KQ_pos);
- if (!ggml_allocr_is_measure(alloc)) {
- int * data = (int *) KQ_pos->data;
- for (int i = 0; i < N; ++i) {
- data[i] = n_past + i;
- }
- }
+ ggml_set_input(KQ_pos);
// rope has so much parameters that we make a custom function for it
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
@@ -780,7 +707,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
- ggml_allocr_alloc(alloc, t36->grad);
+ ggml_set_input(t36->grad);
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
@@ -805,11 +732,23 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// note: they will be freed in reverse order
for (unsigned int i = 0; i < checkpoints.size(); ++i) {
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
- ggml_allocr_alloc(alloc, checkpoints[i]);
+ ggml_set_input(checkpoints[i]);
}
}
- ggml_allocr_alloc_graph(alloc, gb);
+ if (measure_only) {
+ ggml_gallocr_reserve(alloc, gb);
+ } else {
+ ggml_gallocr_alloc_graph(alloc, gb);
+
+ // set KQ_pos
+ {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < N; ++i) {
+ data[i] = n_past + i;
+ }
+ }
+ }
// remove the additional nodes and leafs
for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
@@ -1663,7 +1602,7 @@ int main(int argc, char ** argv) {
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
- printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
+ printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)), (float) (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)) / (1024.0f*1024.0f));
if (params.only_write_lora) {
save_train_files_data save_data;
@@ -1690,10 +1629,6 @@ int main(int argc, char ** argv) {
int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch;
-
- std::vector<uint8_t> mem_input_data;
- std::vector<uint8_t> mem_compute_data;
-
// context for input tensors without their data
struct ggml_init_params ctx_input_params = {
ggml_tensor_overhead() * 2, // mem_size
@@ -1706,18 +1641,12 @@ int main(int argc, char ** argv) {
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+ // allocate input tensors
// measure required memory for input tensors
- size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
- GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
- tensor_alignment;
+ ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
+ size_t max_input_size = ggml_backend_buffer_get_size(input_data);
printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
- // allocate input tensors
- mem_input_data.resize(max_input_size);
- ggml_allocr_t alloc_inps = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
- ggml_allocr_alloc(alloc_inps, tokens_input);
- ggml_allocr_alloc(alloc_inps, target_probs);
-
// context for compute tensors without their data
const size_t estimated_compute_size_wo_data = (
2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
@@ -1743,7 +1672,7 @@ int main(int argc, char ** argv) {
// find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
- ggml_allocr_t alloc = ggml_allocr_new_measure(tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1756,14 +1685,15 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ true
);
- size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size;
best_order = gf->order;
}
- ggml_allocr_free(alloc);
+ ggml_gallocr_free(alloc);
ggml_free(ctx_compute);
}
size_t max_compute_size = best_compute_size;
@@ -1774,9 +1704,8 @@ int main(int argc, char ** argv) {
"invalid");
// allocate compute tensors
- mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
- ggml_allocr_t alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1789,11 +1718,9 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ false
);
- ggml_allocr_free(alloc);
- ggml_allocr_free(alloc_inps);
-
// tokenize data
std::vector<llama_token> train_tokens;
@@ -1908,6 +1835,8 @@ int main(int argc, char ** argv) {
ggml_free(ctx_work);
ggml_free(ctx_compute);
ggml_free(ctx_input);
+ ggml_gallocr_free(alloc);
+
int64_t t1 = ggml_time_ms();
printf("%s: total training time: ", __func__);
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 9129052a..ccd0d85a 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -367,7 +367,7 @@ struct clip_ctx {
ggml_backend_buffer_t params_buffer = NULL;
ggml_backend_buffer_t compute_buffer = NULL;
ggml_backend_t backend = NULL;
- ggml_allocr * compute_alloc = NULL;
+ ggml_gallocr_t compute_alloc = NULL;
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
@@ -405,31 +405,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
- ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
-
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- float * data = (float *)malloc(ggml_nbytes(inp_raw));
-
- for (size_t i = 0; i < imgs->size; i++) {
- const int nx = imgs->data[i].nx;
- const int ny = imgs->data[i].ny;
- GGML_ASSERT(nx == image_size && ny == image_size);
-
- const int n = nx * ny;
-
- for (int b = 0; b < batch_size; b++) {
- for (int k = 0; k < 3; k++) {
- for (int y = 0; y < ny; y++) {
- for (int x = 0; x < nx; x++) {
- data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
- }
- }
- }
- }
- }
- ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
- free(data);
- }
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@@ -438,13 +415,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// concat class_embeddings and patch_embeddings
struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
- ggml_allocr_alloc(ctx->compute_alloc, embeddings);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- void* zero_mem = malloc(ggml_nbytes(embeddings));
- memset(zero_mem, 0, ggml_nbytes(embeddings));
- ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
- free(zero_mem);
- }
+ ggml_set_name(embeddings, "embeddings");
+ ggml_set_input(embeddings);
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@@ -453,15 +425,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
- ggml_allocr_alloc(ctx->compute_alloc, positions);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- int* positions_data = (int*)malloc(ggml_nbytes(positions));
- for (int i = 0; i < num_positions; i++) {
- positions_data[i] = i;
- }
- ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
- free(positions_data);
- }
+ ggml_set_name(positions, "positions");
+ ggml_set_input(positions);
embeddings =
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
@@ -560,15 +525,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
- ggml_allocr_alloc(ctx->compute_alloc, patches);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- int* patches_data = (int*)malloc(ggml_nbytes(patches));
- for (int i = 0; i < num_patches; i++) {
- patches_data[i] = i + 1;
- }
- ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
- free(patches_data);
- }
+ ggml_set_name(patches, "patches");
+ ggml_set_input(patches);
// shape [1, 576, 1024]
// ne is whcn, ne = [1024, 576, 1, 1]
@@ -809,7 +767,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
// data
- size_t buffer_size = 0;
+ size_t model_size = 0;
{
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
@@ -817,7 +775,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
enum ggml_type type = gguf_get_tensor_type(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(meta, name);
size_t tensor_size = ggml_nbytes(cur);
- buffer_size += tensor_size;
+ model_size += tensor_size;
if (verbosity >= 3) {
printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
__func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
@@ -825,8 +783,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
}
- buffer_size += n_tensors * 128 /* CLIP PADDING */;
-
clip_ctx * new_clip = new clip_ctx;
// update projector type
@@ -886,12 +842,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
- printf("%s: model size: %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
+ printf("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
}
}
- printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
+ printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
// load tensors
{
@@ -925,12 +881,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
// alloc memory and offload data
- new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
- ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+ new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
- ggml_allocr_alloc(alloc, cur);
const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
fin.seekg(offset, std::ios::beg);
if (!fin) {
@@ -949,7 +903,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
}
}
- ggml_allocr_free(alloc);
fin.close();
}
@@ -1077,15 +1030,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// measure mem requirement and allocate
{
new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
- new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
+ new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
clip_image_f32_batch batch;
batch.size = 1;
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
- size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
- ggml_allocr_free(new_clip->compute_alloc);
- new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
- new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
-
+ ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+ size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
}
@@ -1267,12 +1217,72 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
}
- // reset alloc buffer to clean the memory from previous invocations
- ggml_allocr_reset(ctx->compute_alloc);
-
// build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
- ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+ ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+ // set inputs
+ const auto & model = ctx->vision_model;
+ const auto & hparams = model.hparams;
+ const int image_size = hparams.image_size;
+ const int patch_size = hparams.patch_size;
+ const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
+ const int num_positions = num_patches + 1;
+
+ {
+ struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+ float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+ for (size_t i = 0; i < imgs->size; i++) {
+ const int nx = imgs->data[i].nx;
+ const int ny = imgs->data[i].ny;
+ GGML_ASSERT(nx == image_size && ny == image_size);
+
+ const int n = nx * ny;
+
+ for (int b = 0; b < batch_size; b++) {
+ for (int k = 0; k < 3; k++) {
+ for (int y = 0; y < ny; y++) {
+ for (int x = 0; x < nx; x++) {
+ data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+ }
+ }
+ }
+ }
+ }
+ ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+ free(data);
+ }
+
+ {
+ struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+ void* zero_mem = malloc(ggml_nbytes(embeddings));
+ memset(zero_mem, 0, ggml_nbytes(embeddings));
+ ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+ free(zero_mem);
+ }
+
+ {
+ struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+ int* positions_data = (int*)malloc(ggml_nbytes(positions));
+ for (int i = 0; i < num_positions; i++) {
+ positions_data[i] = i;
+ }
+ ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+ free(positions_data);
+ }
+
+ {
+ struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+ int* patches_data = (int*)malloc(ggml_nbytes(patches));
+ for (int i = 0; i < num_patches; i++) {
+ patches_data[i] = i + 1;
+ }
+ ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+ free(patches_data);
+ }
if (ggml_backend_is_cpu(ctx->backend)) {
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
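
clip.cpp can no longer write its inputs while the graph is being built, since the measure/real distinction is gone. Input tensors are instead tagged with ggml_set_input and named at build time, and clip_image_batch_encode fills them by name after ggml_gallocr_alloc_graph has given them backing memory. The sketch below shows that pattern for the positions input; the helper names are hypothetical.

#include <cstdint>
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Build time: create the input with no data, name it so it can be found later,
// and mark it as a graph input for the allocator.
static struct ggml_tensor * add_positions_input(struct ggml_context * ctx0, int num_positions) {
    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);
    return positions;
}

// After ggml_gallocr_alloc_graph(compute_alloc, gf): look the tensor up in the
// graph and copy the data in through the backend, which also works for GPU buffers.
static void set_positions_input(struct ggml_cgraph * gf, int num_positions) {
    struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
    std::vector<int32_t> data(num_positions);
    for (int i = 0; i < num_positions; ++i) {
        data[i] = i;
    }
    ggml_backend_tensor_set(positions, data.data(), 0, ggml_nbytes(positions));
}
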
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index eee9d4de..2e2a8ce0 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1,5 +1,6 @@
#include "ggml.h"
#include "ggml-alloc.h"
+#include "ggml-backend.h"
#include "common.h"
#include "train.h"
#include "llama.h"
@@ -19,8 +20,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-static const size_t tensor_alignment = 32;
-
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512;
@@ -58,7 +57,7 @@ struct my_llama_layer {
struct my_llama_model {
struct ggml_context * ctx = NULL;
- std::vector<uint8_t> data;
+ ggml_backend_buffer_t data = NULL;
my_llama_hparams hparams;
@@ -147,39 +146,6 @@ static void set_param_model(struct my_llama_model * model) {
}
}
-static void alloc_model(struct ggml_allocr * alloc, struct my_llama_model * model) {
- ggml_allocr_alloc(alloc, model->tok_embeddings);
- ggml_allocr_alloc(alloc, model->norm);
- ggml_allocr_alloc(alloc, model->output);
- for (uint32_t i = 0; i < model->layers.size(); ++i) {
- auto & layer = model->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm);
- ggml_allocr_alloc(alloc, layer.wq);
- ggml_allocr_alloc(alloc, layer.wk);
- ggml_allocr_alloc(alloc, layer.wv);
- ggml_allocr_alloc(alloc, layer.wo);
- ggml_allocr_alloc(alloc, layer.ffn_norm);
- ggml_allocr_alloc(alloc, layer.w1);
- ggml_allocr_alloc(alloc, layer.w2);
- ggml_allocr_alloc(alloc, layer.w3);
- }
- ggml_allocr_alloc(alloc, model->tok_embeddings->grad);
- ggml_allocr_alloc(alloc, model->norm->grad);
- ggml_allocr_alloc(alloc, model->output->grad);
- for (uint32_t i = 0; i < model->layers.size(); ++i) {
- auto & layer = model->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm->grad);
- ggml_allocr_alloc(alloc, layer.wq->grad);
- ggml_allocr_alloc(alloc, layer.wk->grad);
- ggml_allocr_alloc(alloc, layer.wv->grad);
- ggml_allocr_alloc(alloc, layer.wo->grad);
- ggml_allocr_alloc(alloc, layer.ffn_norm->grad);
- ggml_allocr_alloc(alloc, layer.w1->grad);
- ggml_allocr_alloc(alloc, layer.w2->grad);
- ggml_allocr_alloc(alloc, layer.w3->grad);
- }
-}
-
static void init_model(struct my_llama_model * model) {
const auto & hparams = model->hparams;
@@ -252,17 +218,8 @@ static void init_model(struct my_llama_model * model) {
set_param_model(model);
- // measure data size
- size_t size = 0;
- for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
- }
-
// allocate data
- struct ggml_allocr * alloc = NULL;
- model->data.resize(size + tensor_alignment);
- alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
- alloc_model(alloc, model);
+ model->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
}
static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
@@ -297,7 +254,7 @@ static void randomize_model(struct my_llama_model * model, int seed, float mean,
static struct ggml_tensor * llama_build_train_graphs(
struct my_llama_model * model,
- struct ggml_allocr * alloc,
+ ggml_gallocr_t alloc,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
@@ -308,7 +265,8 @@ static struct ggml_tensor * llama_build_train_graphs(
const int n_tokens,
const int n_batch,
const bool enable_flash_attn,
- const bool enable_checkpointing) {
+ const bool enable_checkpointing,
+ const bool measure_only) {
ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0;
@@ -334,13 +292,7 @@ static struct ggml_tensor * llama_build_train_graphs(
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
- ggml_allocr_alloc(alloc, KQ_pos);
- if (!ggml_allocr_is_measure(alloc)) {
- int * data = (int *) KQ_pos->data;
- for (int i = 0; i < N; ++i) {
- data[i] = n_past + i;
- }
- }
+ ggml_set_input(KQ_pos);
// rope has so much parameters that we make a custom function for it
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
@@ -448,21 +400,31 @@ static struct ggml_tensor * llama_build_train_graphs(
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
-
- ggml_allocr_alloc(alloc, t36->grad);
+ ggml_set_input(t36->grad);
// allocating checkpoints in one block to reduce memory fragmentation
// note: they will be freed in reverse order
for (int i = 0; i < (int) checkpoints.size(); ++i) {
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
- ggml_allocr_alloc(alloc, checkpoints[i]);
+ ggml_set_input(checkpoints[i]);
}
}
//int n_leafs_after = gb->n_leafs;
//int n_nodes_after = gb->n_nodes;
- ggml_allocr_alloc_graph(alloc, gb);
+ if (measure_only) {
+ // FIXME: will still allocate
+ ggml_gallocr_reserve(alloc, gb);
+ } else {
+ ggml_gallocr_alloc_graph(alloc, gb);
+ if (!measure_only) {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < N; ++i) {
+ data[i] = n_past + i;
+ }
+ }
+ }
// remove the additional nodes and leafs
for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
@@ -1046,7 +1008,7 @@ int main(int argc, char ** argv) {
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
- printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f));
+ printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)), (float) (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)) / (1024.0f*1024.0f));
if (params.only_write_model) {
save_train_files_data save_data;
@@ -1073,11 +1035,6 @@ int main(int argc, char ** argv) {
int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch;
- std::vector<uint8_t> mem_input_data;
- std::vector<uint8_t> mem_compute_data;
-
- ggml_allocr * alloc = NULL;
-
// context for input tensors without their data
struct ggml_init_params ctx_input_params = {
ggml_tensor_overhead() * 2, // mem_size
@@ -1091,16 +1048,10 @@ int main(int argc, char ** argv) {
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
// measure required memory for input tensors
- size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
- GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
- tensor_alignment;
- printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
-
// allocate input tensors
- mem_input_data.resize(max_input_size);
- alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
- ggml_allocr_alloc(alloc, tokens_input);
- ggml_allocr_alloc(alloc, target_probs);
+ ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
+ size_t max_input_size = ggml_backend_buffer_get_size(input_data);
+ printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
// context for compute tensors without their data
const size_t estimated_compute_size_wo_data = (
@@ -1127,7 +1078,7 @@ int main(int argc, char ** argv) {
// find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
- alloc = ggml_allocr_new_measure(tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1140,9 +1091,10 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ true
);
- size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size;
best_order = gf->order;
@@ -1157,9 +1109,8 @@ int main(int argc, char ** argv) {
"invalid");
// allocate compute tensors
- mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
- alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1172,7 +1123,8 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ false
);
std::vector<llama_token> train_tokens;
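
The weight and input contexts in both training examples follow the same backend-buffer approach: instead of summing GGML_PAD(ggml_nbytes(t), tensor_alignment) over every tensor, resizing a std::vector, and calling ggml_allocr_alloc per tensor (the deleted alloc_model/alloc_lora helpers), a single ggml_backend_alloc_ctx_tensors_from_buft call places every tensor of a no_alloc context into one backend buffer. A small sketch with two placeholder weight tensors and a hypothetical helper name:

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Create tensors in a metadata-only context, then back them all with a single
// CPU buffer. The tensor shapes here are placeholders, not the model's real ones.
static ggml_backend_buffer_t alloc_example_weights(struct ggml_context ** out_ctx) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 2*ggml_tensor_overhead(),   // room for two tensor headers
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32000);   // e.g. token embeddings
    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);          // e.g. output norm

    // replaces the measure loop + ggml_allocr_new + per-tensor ggml_allocr_alloc
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());

    *out_ctx = ctx;
    return buf;
}

Cleanup becomes ggml_backend_buffer_free(buf) plus ggml_free(ctx) rather than letting a std::vector go out of scope, and the reported model size comes from ggml_backend_buffer_get_size, as in the updated printf calls above.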