author     cpumaxx <163466046+cpumaxx@users.noreply.github.com>  2024-04-29 07:34:24 -0700
committer  GitHub <noreply@github.com>                           2024-04-29 17:34:24 +0300
commit     ffe666572f98a686b17a2cd1dbf4c0a982e5ac0a (patch)
tree       062ed2b2706163cdb2006b0204c4589e7da4f75a /examples/llava/llava-cli.cpp
parent     24affa7db3c9db148854b0ab4fd63de8bca7d898 (diff)
llava-cli : multiple images (#6969)
Co-authored-by: root <root@nenya.lothlorien.ca>
Diffstat (limited to 'examples/llava/llava-cli.cpp')
-rw-r--r--  examples/llava/llava-cli.cpp  65
1 file changed, 37 insertions(+), 28 deletions(-)
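
This change splits the old monolithic llava_init() into a one-time model load (llava_init) plus a per-image context setup (llava_init_context), and has main() loop over every --image argument so one invocation can describe several images. As a reading aid, not part of the patch, here is a condensed sketch of the resulting control flow; all names are taken from the diff below, with argument parsing and error handling elided:

    // Condensed sketch of the reworked main() flow (error handling elided).
    gpt_params params;                                        // parsed from argv beforehand

    auto model = llava_init(&params);                         // load the LLaMA model once

    for (auto & image : params.image) {                       // one pass per --image argument
        auto ctx_llava = llava_init_context(&params, model);  // fresh clip + llama context

        auto image_embed = load_image(ctx_llava, &params, image);
        process_prompt(ctx_llava, image_embed, &params, params.prompt);

        llava_image_embed_free(image_embed);
        ctx_llava->model = NULL;                              // model is shared across iterations,
        llava_free(ctx_llava);                                // so llava_free() must not release it
    }
    llama_free_model(model);                                  // single teardown after the loop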
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index a44c6cd7..157a680b 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -113,11 +113,11 @@ struct llava_context {
 };
 
 static void show_additional_info(int /*argc*/, char ** argv) {
-    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
     LOG_TEE("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }
 
-static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
+static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) {
 
     // load and preprocess the image
     llava_image_embed * embed = NULL;
@@ -133,9 +133,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
         }
         params->prompt = remove_image_from_prompt(prompt);
     } else {
-        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
+        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str());
         if (!embed) {
-            LOG_TEE("%s: is %s really an image file?\n", __func__, params->image.c_str());
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
             return NULL;
         }
     }
@@ -207,17 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     printf("\n");
 }
 
-
-static struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
+static struct llama_model * llava_init(gpt_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);
 
@@ -228,6 +218,19 @@ static struct llava_context * llava_init(gpt_params * params) {
         LOG_TEE("%s: error: unable to load model\n" , __func__);
         return NULL;
     }
+    return model;
+}
+
+static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
     ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
 
@@ -286,24 +289,30 @@ int main(int argc, char ** argv) {
         show_additional_info(argc, argv);
         return 1;
     }
-
-    auto ctx_llava = llava_init(&params);
-    if (ctx_llava == NULL) {
-        LOG_TEE("%s: error: failed to init llava\n", __func__);
+    auto model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
         return 1;
     }
 
-    auto image_embed = load_image(ctx_llava, &params);
-    if (!image_embed) {
-        return 1;
-    }
+    for (auto & image : params.image) {
+        auto ctx_llava = llava_init_context(&params, model);
 
-    // process the prompt
-    process_prompt(ctx_llava, image_embed, &params, params.prompt);
+        auto image_embed = load_image(ctx_llava, &params, image);
+        if (!image_embed) {
+            std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
+            return 1;
+        }
+
+        // process the prompt
+        process_prompt(ctx_llava, image_embed, &params, params.prompt);
 
-    llama_print_timings(ctx_llava->ctx_llama);
+        llama_print_timings(ctx_llava->ctx_llama);
+        llava_image_embed_free(image_embed);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+    }
+    llama_free_model(model);
 
-    llava_image_embed_free(image_embed);
-    llava_free(ctx_llava);
     return 0;
 }
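
Design note: the llama_model is loaded once and shared, while a fresh llava_context is built for every image, so per-image state (clip context, KV cache) stays isolated between descriptions without paying the model-load cost per image. Clearing ctx_llava->model before llava_free(ctx_llava) keeps the per-image teardown from also releasing the shared model, which would otherwise appear to be owned by each context; the model is then freed exactly once via llama_free_model(model) after the loop.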