summaryrefslogtreecommitdiff
path: root/examples/llava/clip.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/llava/clip.cpp')
-rw-r--r--examples/llava/clip.cpp152
1 files changed, 81 insertions, 71 deletions
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 9129052a..ccd0d85a 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -367,7 +367,7 @@ struct clip_ctx {
ggml_backend_buffer_t params_buffer = NULL;
ggml_backend_buffer_t compute_buffer = NULL;
ggml_backend_t backend = NULL;
- ggml_allocr * compute_alloc = NULL;
+ ggml_gallocr_t compute_alloc = NULL;
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
@@ -405,31 +405,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
- ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
-
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- float * data = (float *)malloc(ggml_nbytes(inp_raw));
-
- for (size_t i = 0; i < imgs->size; i++) {
- const int nx = imgs->data[i].nx;
- const int ny = imgs->data[i].ny;
- GGML_ASSERT(nx == image_size && ny == image_size);
-
- const int n = nx * ny;
-
- for (int b = 0; b < batch_size; b++) {
- for (int k = 0; k < 3; k++) {
- for (int y = 0; y < ny; y++) {
- for (int x = 0; x < nx; x++) {
- data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
- }
- }
- }
- }
- }
- ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
- free(data);
- }
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@@ -438,13 +415,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// concat class_embeddings and patch_embeddings
struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
- ggml_allocr_alloc(ctx->compute_alloc, embeddings);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- void* zero_mem = malloc(ggml_nbytes(embeddings));
- memset(zero_mem, 0, ggml_nbytes(embeddings));
- ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
- free(zero_mem);
- }
+ ggml_set_name(embeddings, "embeddings");
+ ggml_set_input(embeddings);
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@@ -453,15 +425,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
- ggml_allocr_alloc(ctx->compute_alloc, positions);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- int* positions_data = (int*)malloc(ggml_nbytes(positions));
- for (int i = 0; i < num_positions; i++) {
- positions_data[i] = i;
- }
- ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
- free(positions_data);
- }
+ ggml_set_name(positions, "positions");
+ ggml_set_input(positions);
embeddings =
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
@@ -560,15 +525,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
- ggml_allocr_alloc(ctx->compute_alloc, patches);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- int* patches_data = (int*)malloc(ggml_nbytes(patches));
- for (int i = 0; i < num_patches; i++) {
- patches_data[i] = i + 1;
- }
- ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
- free(patches_data);
- }
+ ggml_set_name(patches, "patches");
+ ggml_set_input(patches);
// shape [1, 576, 1024]
// ne is whcn, ne = [1024, 576, 1, 1]
@@ -809,7 +767,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
// data
- size_t buffer_size = 0;
+ size_t model_size = 0;
{
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
@@ -817,7 +775,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
enum ggml_type type = gguf_get_tensor_type(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(meta, name);
size_t tensor_size = ggml_nbytes(cur);
- buffer_size += tensor_size;
+ model_size += tensor_size;
if (verbosity >= 3) {
printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
__func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
@@ -825,8 +783,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
}
- buffer_size += n_tensors * 128 /* CLIP PADDING */;
-
clip_ctx * new_clip = new clip_ctx;
// update projector type
@@ -886,12 +842,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
- printf("%s: model size: %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
+ printf("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
}
}
- printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
+ printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
// load tensors
{
@@ -925,12 +881,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
// alloc memory and offload data
- new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
- ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+ new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
- ggml_allocr_alloc(alloc, cur);
const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
fin.seekg(offset, std::ios::beg);
if (!fin) {
@@ -949,7 +903,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
}
}
- ggml_allocr_free(alloc);
fin.close();
}
@@ -1077,15 +1030,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// measure mem requirement and allocate
{
new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
- new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
+ new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
clip_image_f32_batch batch;
batch.size = 1;
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
- size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
- ggml_allocr_free(new_clip->compute_alloc);
- new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
- new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
-
+ ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+ size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
}
@@ -1267,12 +1217,72 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
}
- // reset alloc buffer to clean the memory from previous invocations
- ggml_allocr_reset(ctx->compute_alloc);
-
// build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
- ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+ ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+ // set inputs
+ const auto & model = ctx->vision_model;
+ const auto & hparams = model.hparams;
+ const int image_size = hparams.image_size;
+ const int patch_size = hparams.patch_size;
+ const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
+ const int num_positions = num_patches + 1;
+
+ {
+ struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+ float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+ for (size_t i = 0; i < imgs->size; i++) {
+ const int nx = imgs->data[i].nx;
+ const int ny = imgs->data[i].ny;
+ GGML_ASSERT(nx == image_size && ny == image_size);
+
+ const int n = nx * ny;
+
+ for (int b = 0; b < batch_size; b++) {
+ for (int k = 0; k < 3; k++) {
+ for (int y = 0; y < ny; y++) {
+ for (int x = 0; x < nx; x++) {
+ data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+ }
+ }
+ }
+ }
+ }
+ ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+ free(data);
+ }
+
+ {
+ struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+ void* zero_mem = malloc(ggml_nbytes(embeddings));
+ memset(zero_mem, 0, ggml_nbytes(embeddings));
+ ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+ free(zero_mem);
+ }
+
+ {
+ struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+ int* positions_data = (int*)malloc(ggml_nbytes(positions));
+ for (int i = 0; i < num_positions; i++) {
+ positions_data[i] = i;
+ }
+ ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+ free(positions_data);
+ }
+
+ {
+ struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+ int* patches_data = (int*)malloc(ggml_nbytes(patches));
+ for (int i = 0; i < num_patches; i++) {
+ patches_data[i] = i + 1;
+ }
+ ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+ free(patches_data);
+ }
if (ggml_backend_is_cpu(ctx->backend)) {
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);