author     Georgi Gerganov <ggerganov@gmail.com>    2024-01-31 17:30:17 +0200
committer  GitHub <noreply@github.com>              2024-01-31 17:30:17 +0200
commit     5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3 (patch)
tree       3ef8dc640d5c08466309c09a8ac2963bb760af06 /common/common.h
parent     efb7bdbbd061d087c788598b97992c653f992ddd (diff)
llama : remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD (#5240)
* llama : remove LLAMA_MAX_DEVICES from llama.h ggml-ci
* Update llama.cpp Co-authored-by: slaren <slarengh@gmail.com>
* server : remove LLAMA_MAX_DEVICES ggml-ci
* llama : remove LLAMA_SUPPORTS_GPU_OFFLOAD ggml-ci
* train : remove LLAMA_SUPPORTS_GPU_OFFLOAD
* readme : add deprecation notice
* readme : change deprecation notice to "remove" and fix url
* llama : remove gpu includes from llama.h ggml-ci
---------
Co-authored-by: slaren <slarengh@gmail.com>
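The change removes two compile-time constants from the public header: LLAMA_MAX_DEVICES, which downstream code used to size per-device arrays such as tensor_split, and LLAMA_SUPPORTS_GPU_OFFLOAD, which gated GPU-offload paths behind the preprocessor. The sketch below is a hypothetical caller-side migration, not part of this commit's diff; it assumes the runtime query functions llama_max_devices() and llama_supports_gpu_offload() exposed by llama.h.

// Hypothetical migration sketch (not from this commit): replace the
// removed compile-time macros with runtime queries from llama.h.
#include <cstdio>
#include <vector>

#include "llama.h"

int main() {
    // Before: float tensor_split[LLAMA_MAX_DEVICES] = {0};
    // After:  size the buffer from the runtime device count.
    std::vector<float> tensor_split(llama_max_devices(), 0.0f);

    // Before: #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD ... #endif
    // After:  an ordinary runtime branch.
    if (llama_supports_gpu_offload()) {
        std::printf("GPU offload available, %zu device slot(s)\n", tensor_split.size());
    } else {
        std::printf("CPU-only build, GPU layer settings will be ignored\n");
    }
    return 0;
}

Note that common.h itself takes a different route: as the diff below shows, it keeps a plain fixed-size array and simply hard-codes 128 entries in place of the removed macro.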
Diffstat (limited to 'common/common.h')
-rw-r--r--  common/common.h  68
1 file changed, 34 insertions(+), 34 deletions(-)
diff --git a/common/common.h b/common/common.h
index 214a379b..24a99d72 100644
--- a/common/common.h
+++ b/common/common.h
@@ -43,40 +43,40 @@ extern char const *LLAMA_BUILD_TARGET;
int32_t get_num_physical_cores();
struct gpt_params {
- uint32_t seed = -1; // RNG seed
-
- int32_t n_threads = get_num_physical_cores();
- int32_t n_threads_draft = -1;
- int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
- int32_t n_threads_batch_draft = -1;
- int32_t n_predict = -1; // new tokens to predict
- int32_t n_ctx = 512; // context size
- int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_draft = 8; // number of tokens to draft during speculative decoding
- int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
- int32_t n_parallel = 1; // number of parallel sequences to decode
- int32_t n_sequences = 1; // number of sequences to decode
- float p_accept = 0.5f; // speculative decoding accept probability
- float p_split = 0.1f; // speculative decoding split probability
- int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
- int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
- llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
- int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
- float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
- int32_t n_beams = 0; // if non-zero then use beam search of given width.
- int32_t grp_attn_n = 1; // group-attention factor
- int32_t grp_attn_w = 512; // group-attention width
- int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
- float rope_freq_base = 0.0f; // RoPE base frequency
- float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
- float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
- float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
- float yarn_beta_fast = 32.0f; // YaRN low correction dim
- float yarn_beta_slow = 1.0f; // YaRN high correction dim
- int32_t yarn_orig_ctx = 0; // YaRN original context length
- int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
- // pinging @cebtenzzre
+ uint32_t seed = -1; // RNG seed
+
+ int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_draft = -1;
+ int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
+ int32_t n_threads_batch_draft = -1;
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_ctx = 512; // context size
+ int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_draft = 8; // number of tokens to draft during speculative decoding
+ int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+ int32_t n_parallel = 1; // number of parallel sequences to decode
+ int32_t n_sequences = 1; // number of sequences to decode
+ float p_accept = 0.5f; // speculative decoding accept probability
+ float p_split = 0.1f; // speculative decoding split probability
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+ int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+ llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
+ int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+ float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+ int32_t n_beams = 0; // if non-zero then use beam search of given width.
+ int32_t grp_attn_n = 1; // group-attention factor
+ int32_t grp_attn_w = 512; // group-attention width
+ int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
+ float rope_freq_base = 0.0f; // RoPE base frequency
+ float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
+ float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
+ float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
+ float yarn_beta_fast = 32.0f; // YaRN low correction dim
+ float yarn_beta_slow = 1.0f; // YaRN high correction dim
+ int32_t yarn_orig_ctx = 0; // YaRN original context length
+ int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
+ // pinging @cebtenzzre
// // sampling parameters
struct llama_sampling_params sparams;
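Since the array that used to be sized with LLAMA_MAX_DEVICES is now a fixed 128-entry buffer, callers have to clamp user input against the runtime device count themselves. The following is a minimal, hypothetical sketch of filling gpt_params::tensor_split from a "--tensor-split 3,1" style argument; the helper name parse_tensor_split and the exact parsing rules are assumptions for illustration, not code from this commit (the real argument handling lives in common.cpp).

// Hypothetical sketch (not from this diff): fill the fixed-size
// tensor_split[128] buffer from a comma-separated proportion list,
// clamped to the number of devices the runtime actually supports.
#include <algorithm>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>

#include "llama.h"

static void parse_tensor_split(const std::string & arg, float (&tensor_split)[128]) {
    // zero the whole buffer so unused slots stay at 0.0f
    std::memset(tensor_split, 0, sizeof(tensor_split));

    // split "3,1,..." into individual float values
    std::vector<float> parts;
    std::stringstream ss(arg);
    std::string item;
    while (std::getline(ss, item, ',')) {
        parts.push_back(std::stof(item));
    }

    // llama_max_devices() may be smaller than the 128-entry buffer;
    // extra user-supplied values are silently dropped in this sketch.
    const size_t n = std::min(parts.size(), (size_t) llama_max_devices());
    for (size_t i = 0; i < n && i < 128; ++i) {
        tensor_split[i] = parts[i];
    }
}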