diff options
author | l3utterfly <gc.pthzfoldr@gmail.com> | 2023-10-06 18:47:59 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-06 13:47:59 +0300 |
commit | 16820a5a0d885113f21021ce934f0b0027b9d69a (patch) | |
tree | 012fef425d65e6ac699613cd5f4167d1639fcf59 | |
parent | 04b2f4386eda0264287156104cbf9d1b87895422 (diff) |
llama : correct hparams comparison (#3446)
* fixed floating point comparison issues
* updated implementation for hparam comparison to handle inf and NaN
* fixed code review comments
* minor simplification
* rename is_float_eq -> is_float_close
---------
Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
-rw-r--r-- | llama.cpp | 40 |
1 file changed, 39 insertions, 1 deletion
```diff
@@ -125,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std::
     }
     s = std::move(result);
 }
+
+static bool is_float_close(float a, float b, float abs_tol) {
+    // Check for non-negative tolerance
+    if (abs_tol < 0.0) {
+        throw std::invalid_argument("Tolerance must be non-negative");
+    }
+
+    // Exact equality check
+    if (a == b) {
+        return true;
+    }
+
+    // Check for infinities
+    if (std::isinf(a) || std::isinf(b)) {
+        return false;
+    }
+
+    // Regular comparison using the provided absolute tolerance
+    return std::fabs(b - a) <= abs_tol;
+}
+
 #ifdef GGML_USE_CPU_HBM
 #include <hbwmalloc.h>
 #endif
@@ -969,7 +990,24 @@ struct llama_hparams {
     float rope_freq_scale_train;
 
     bool operator!=(const llama_hparams & other) const {
-        return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
+        if (this->vocab_only != other.vocab_only) return true;
+        if (this->n_vocab != other.n_vocab) return true;
+        if (this->n_ctx_train != other.n_ctx_train) return true;
+        if (this->n_embd != other.n_embd) return true;
+        if (this->n_head != other.n_head) return true;
+        if (this->n_head_kv != other.n_head_kv) return true;
+        if (this->n_layer != other.n_layer) return true;
+        if (this->n_rot != other.n_rot) return true;
+        if (this->n_ff != other.n_ff) return true;
+
+        const float EPSILON = 1e-9;
+
+        if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
+        if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
+        if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
+        if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
+
+        return false;
     }
 
     uint32_t n_gqa() const {
```