summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp69
1 files changed, 50 insertions, 19 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 4db12ee1..4b983e5f 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,8 +1,9 @@
#define LLAMA_API_INTERNAL
#include "sampling.h"
+#include "llama-vocab.h"
#include <random>
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
result->params = params;
@@ -36,13 +37,32 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
}
result->grammar = grammar;
}
-
result->prev.resize(params.n_prev);
result->n_valid = 0;
+ // init DRY
+ for (const auto& cnstr : params.samplers_sequence)
+ {
+ switch (cnstr)
+ {
+ case llama_sampler_type::DRY:
+ {
+ std::vector<const char*> c_breakers;
+ c_breakers.reserve(params.dry_sequence_breakers.size());
+ for (const auto& str : params.dry_sequence_breakers)
+ {
+ c_breakers.push_back(str.c_str());
+ }
+ result->smpl=llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size());
+
+ break;
+ }
+ default:
+ break;
+ }
+ }
llama_sampling_set_rng_seed(result, params.seed);
-
return result;
}
@@ -50,7 +70,8 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}
-
+ if (ctx->smpl !=NULL)
+ llama_sampler_dry_free(ctx->smpl);
delete ctx;
}
@@ -75,6 +96,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_valid = 0;
+ llama_sampler_dry_reset(ctx->smpl);
}
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -95,6 +117,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
}
dst->prev = src->prev;
+ dst->smpl = llama_sampler_dry_clone(src->smpl);
}
llama_token llama_sampling_last(llama_sampling_context * ctx) {
@@ -149,6 +172,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
switch (sampler_type) {
+ case llama_sampler_type::DRY: return "dry";
case llama_sampler_type::TOP_K: return "top_k";
case llama_sampler_type::TFS_Z: return "tfs_z";
case llama_sampler_type::TYPICAL_P: return "typical_p";
@@ -163,6 +187,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
+ {"dry", llama_sampler_type::DRY},
{"top_k", llama_sampler_type::TOP_K},
{"top_p", llama_sampler_type::TOP_P},
{"typical_p", llama_sampler_type::TYPICAL_P},
@@ -176,6 +201,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
// since samplers names are written multiple ways
// make it ready for both system names and input names
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
+ {"dry", llama_sampler_type::DRY},
{"top-k", llama_sampler_type::TOP_K},
{"top-p", llama_sampler_type::TOP_P},
{"nucleus", llama_sampler_type::TOP_P},
@@ -215,6 +241,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
std::unordered_map<char, llama_sampler_type> sampler_name_map {
+ {'d', llama_sampler_type::DRY},
{'k', llama_sampler_type::TOP_K},
{'p', llama_sampler_type::TOP_P},
{'y', llama_sampler_type::TYPICAL_P},
@@ -238,25 +265,28 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
// no reasons to expose this function in header
static void sampler_queue(
- struct llama_context * ctx_main,
- const llama_sampling_params & params,
- llama_token_data_array & cur_p,
- size_t min_keep) {
- const float temp = params.temp;
- const float dynatemp_range = params.dynatemp_range;
+ struct llama_context* ctx_main,
+ const llama_sampling_params& params,
+ llama_sampling_context * ctx_sampling,
+ llama_token_data_array& cur_p,
+ size_t min_keep) {
+ const float temp = params.temp;
+ const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent;
- const int32_t top_k = params.top_k;
- const float top_p = params.top_p;
- const float min_p = params.min_p;
- const float tfs_z = params.tfs_z;
- const float typical_p = params.typical_p;
- const float xtc_probability = params.xtc_probability;
- const float xtc_threshold = params.xtc_threshold;
- const float top_n_sigma = params.top_n_sigma;
+ const int32_t top_k = params.top_k;
+ const float top_p = params.top_p;
+ const float min_p = params.min_p;
+ const float tfs_z = params.tfs_z;
+ const float typical_p = params.typical_p;
+ const float xtc_probability = params.xtc_probability;
+ const float xtc_threshold = params.xtc_threshold;
+ const float top_n_sigma = params.top_n_sigma;
+
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
for (auto sampler_type : samplers_sequence) {
switch (sampler_type) {
+ case llama_sampler_type::DRY : llama_sample_dry (ctx_main, ctx_sampling->smpl, &cur_p); break;
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
@@ -317,7 +347,7 @@ static llama_token llama_sampling_sample_impl(
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);
- sampler_queue(ctx_main, params, cur_p, min_keep);
+ sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep);
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
@@ -472,4 +502,5 @@ void llama_sampling_accept(
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
}
+ llama_sampler_dry_accept(ctx_sampling->smpl, id);
}