author     slaren <slarengh@gmail.com>    2023-09-28 21:42:38 +0200
committer  GitHub <noreply@github.com>    2023-09-28 22:42:38 +0300
commit     16bc66d9479edd5ee12ec734973554d4493c5dfa (patch)
tree       4cca787ebd86dd55fd176d27112117c74e9b34c6 /common
parent     0512d66670de3f650c579519833c085014b0f200 (diff)
llama.cpp : split llama_context_params into model and context params (#3301)
* llama.cpp : split llama_context_params into model and context params (ggml-ci)
* fix metal build
* fix freq_base/scale default to model value
* llama-bench : keep the same model between tests when possible
* move n_threads to llama_context_params, add n_threads_batch
* fix mpi build
* remove kv_size(), cuda scratch fixes
* remove low-vram option
* add n_threads_batch to system info, refactor to get_system_info()
* add documentation about --threads-batch to the READMEs
* llama-bench fix
* main : fix rope freq/scale warning
* llama.cpp : add llama_get_model; common : add llama_tokenize from model
* remove duplicated ctx/model functions (ggml-ci)
* cuda : print total VRAM used
Diffstat (limited to 'common')
-rw-r--r--   common/common.cpp   112
-rw-r--r--   common/common.h      12
-rw-r--r--   common/train.cpp     10
3 files changed, 87 insertions, 47 deletions
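
For readers of this commit, a minimal sketch of the call sequence the split enables. This is not part of the diff: the model path and parameter values are placeholders, and backend setup/teardown is abbreviated.

#include <cstdio>

#include "common.h"
#include "llama.h"

int main() {
    gpt_params params;
    params.model           = "models/example.gguf"; // placeholder path
    params.n_ctx           = 2048;
    params.n_threads       = 8;
    params.n_threads_batch = -1; // -1 = fall back to n_threads (see cparams below)

    // NOTE: llama_backend_init()/llama_backend_free() calls omitted for brevity.

    // Model-level parameters: GPU offload, tensor split, mmap/mlock.
    llama_model_params mparams = llama_model_params_from_gpt_params(params);
    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
    if (model == NULL) {
        return 1;
    }

    // Context-level parameters: n_ctx, n_batch, threads, RoPE, KV cache format.
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    fprintf(stderr, "%s\n", get_system_info(params).c_str());

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
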
diff --git a/common/common.cpp b/common/common.cpp
index 8764a7be..6e8c08cb 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency();
}
+ } else if (arg == "-tb" || arg == "--threads-batch") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads_batch = std::stoi(argv[i]);
+ if (params.n_threads_batch <= 0) {
+ params.n_threads_batch = std::thread::hardware_concurrency();
+ }
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
@@ -452,12 +461,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
#endif // GGML_USE_CUBLAS
- } else if (arg == "--low-vram" || arg == "-lv") {
-#ifdef GGML_USE_CUBLAS
- params.low_vram = true;
-#else
- fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
-#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--numa") {
@@ -630,7 +633,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" (can be specified more than once for multiple prompts).\n");
printf(" --color colorise output to distinguish prompt and user input from generations\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
- printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
+ printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
+ printf(" -tb N, --threads-batch N\n");
+ printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -p PROMPT, --prompt PROMPT\n");
printf(" prompt to start generation with (default: empty)\n");
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
@@ -645,7 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -f FNAME, --file FNAME\n");
printf(" prompt file to start generation.\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
- printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+ printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
@@ -705,7 +710,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
- printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
@@ -726,6 +730,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf("\n");
}
+std::string get_system_info(const gpt_params & params) {
+ std::ostringstream os;
+
+ os << "system_info: n_threads = " << params.n_threads;
+ if (params.n_threads_batch != -1) {
+ os << " (n_threads_batch = " << params.n_threads_batch << ")";
+ }
+ os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+
+ return os.str();
+}
+
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
@@ -749,40 +765,50 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
// Model utils
//
-struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
- auto lparams = llama_context_default_params();
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+ auto mparams = llama_model_default_params();
- lparams.n_ctx = params.n_ctx;
- lparams.n_batch = params.n_batch;
if (params.n_gpu_layers != -1) {
- lparams.n_gpu_layers = params.n_gpu_layers;
+ mparams.n_gpu_layers = params.n_gpu_layers;
}
- lparams.main_gpu = params.main_gpu;
- lparams.tensor_split = params.tensor_split;
- lparams.low_vram = params.low_vram;
- lparams.mul_mat_q = params.mul_mat_q;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
- lparams.logits_all = params.logits_all;
- lparams.embedding = params.embedding;
- lparams.rope_freq_base = params.rope_freq_base;
- lparams.rope_freq_scale = params.rope_freq_scale;
-
- return lparams;
+ mparams.main_gpu = params.main_gpu;
+ mparams.tensor_split = params.tensor_split;
+ mparams.use_mmap = params.use_mmap;
+ mparams.use_mlock = params.use_mlock;
+
+ return mparams;
+}
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+ auto cparams = llama_context_default_params();
+
+ cparams.n_ctx = params.n_ctx;
+ cparams.n_batch = params.n_batch;
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+ cparams.mul_mat_q = params.mul_mat_q;
+ cparams.seed = params.seed;
+ cparams.f16_kv = params.memory_f16;
+ cparams.logits_all = params.logits_all;
+ cparams.embedding = params.embedding;
+ cparams.rope_freq_base = params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale;
+
+ return cparams;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
- auto lparams = llama_context_params_from_gpt_params(params);
+ auto mparams = llama_model_params_from_gpt_params(params);
- llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
}
- llama_context * lctx = llama_new_context_with_model(model, lparams);
+ auto cparams = llama_context_params_from_gpt_params(params);
+
+ llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
@@ -815,7 +841,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
- llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+ llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
}
@@ -828,16 +854,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
//
std::vector<llama_token> llama_tokenize(
- struct llama_context * ctx,
+ const struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos) {
+ return llama_tokenize(llama_get_model(ctx), text, add_bos);
+}
+
+std::vector<llama_token> llama_tokenize(
+ const struct llama_model * model,
const std::string & text,
bool add_bos) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
- n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+ n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
- int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+ int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
@@ -847,10 +880,10 @@ std::vector<llama_token> llama_tokenize(
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
- const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
+ const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
- int check = llama_token_to_piece(ctx, token, result.data(), result.size());
+ int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
@@ -905,7 +938,7 @@ llama_token llama_sample_token(
std::vector<llama_token_data> & candidates,
int idx) {
const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
@@ -1191,7 +1224,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
#endif // NDEBUG
fprintf(stream, "model_desc: %s\n", model_desc);
- fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx));
+ fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
#ifdef __OPTIMIZE__
fprintf(stream, "optimize: true\n");
@@ -1258,7 +1291,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
- fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);
diff --git a/common/common.h b/common/common.h
index 64601f99..0e2d3fa6 100644
--- a/common/common.h
+++ b/common/common.h
@@ -36,6 +36,7 @@ int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
@@ -95,7 +96,6 @@ struct gpt_params {
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
- bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
@@ -126,6 +126,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+std::string get_system_info(const gpt_params & params);
+
std::string gpt_random_prompt(std::mt19937 & rng);
void process_escapes(std::string& input);
@@ -135,6 +137,7 @@ void process_escapes(std::string& input);
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//
@@ -144,7 +147,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
// tokenizes a string into a vector of tokens
// should work similar to Python's `tokenizer.encode`
std::vector<llama_token> llama_tokenize(
- struct llama_context * ctx,
+ const struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos);
+
+std::vector<llama_token> llama_tokenize(
+ const struct llama_model * model,
const std::string & text,
bool add_bos);
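
The two overloads declared above can be mixed; the following illustrative helper (not from the commit, names are placeholders) shows how they fit together with llama_get_model():

#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

// Illustrative only: tokenize a prompt via the new model-based overload and
// turn the tokens back into text with the context-based helper. The result
// approximates the original prompt (detokenization may differ in whitespace).
static std::string tokenize_and_detokenize(llama_context * ctx, const std::string & prompt) {
    const llama_model * model = llama_get_model(ctx);

    // New overload: tokenization no longer needs a context, only the model.
    std::vector<llama_token> tokens = llama_tokenize(model, prompt, /*add_bos=*/false);

    std::string out;
    for (llama_token tok : tokens) {
        // llama_token_to_piece() in common keeps its context parameter and
        // resolves the model internally via llama_get_model().
        out += llama_token_to_piece(ctx, tok);
    }
    return out;
}
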
diff --git a/common/train.cpp b/common/train.cpp
index 4a128096..35a4cf9e 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -858,7 +858,7 @@ size_t tokenize_file(
out_tokens.resize(buf.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(
- lctx,
+ llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
@@ -867,7 +867,7 @@ size_t tokenize_file(
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(
- lctx,
+ llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
@@ -920,7 +920,7 @@ size_t tokenize_file(
size_t found_max_sample_size = 0;
size_t max_token_text_size = 0;
- int n_vocab = llama_n_vocab(lctx);
+ int n_vocab = llama_n_vocab(llama_get_model(lctx));
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
@@ -961,7 +961,7 @@ size_t tokenize_file(
// tokenize the sample
tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
- int n_tokens = llama_tokenize(lctx,
+ int n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
@@ -969,7 +969,7 @@ size_t tokenize_file(
false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
- n_tokens = llama_tokenize(lctx,
+ n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),