From 415e99fec27be5a2e4283f1937afd17eb33fbd66 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Fri, 4 Aug 2023 19:29:52 +0800 Subject: Stream save llama context data to file instead of allocating entire buffer upfront (#2488) * added stream saving context data to file to avoid allocating unnecessary amounts of memory * generalised copying state data to file or buffer * added comments explaining how copy_state_data works * fixed trailing whitespaces * fixed save load state example * updated save load state to use public function in llama.cpp * - restored breakage of the llama_copy_state_data API - moved new logic for copying llama state data to internal function * fixed function declaration order * restored save load state example * fixed whitepace * removed unused llama-util.h include * Apply suggestions from code review Co-authored-by: slaren * Apply code review suggestions Co-authored-by: slaren --------- Co-authored-by: slaren --- llama-util.h | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) (limited to 'llama-util.h') diff --git a/llama-util.h b/llama-util.h index 042ebe43..3fc03ce2 100644 --- a/llama-util.h +++ b/llama-util.h @@ -149,6 +149,46 @@ struct llama_file { } }; +// llama_context_data +struct llama_data_context { + virtual void write(const void * src, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t* ptr; + size_t size_written = 0; + + llama_data_buffer_context(uint8_t * p) : ptr(p) {} + + void write(const void * src, size_t size) override { + memcpy(ptr, src, size); + ptr += size; + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_file_context : llama_data_context { + llama_file* file; + size_t size_written = 0; + + llama_data_file_context(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { LPSTR buf; -- cgit v1.2.3