summaryrefslogtreecommitdiff
path: root/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py')
-rw-r--r--  examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py  12
1 file changed, 8 insertions, 4 deletions
diff --git a/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py b/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py
index a527d615..351e7bc2 100644
--- a/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py
+++ b/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py
@@ -47,10 +47,13 @@ LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys"
LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"
LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"
-LLM_KV_TRAINING_FILE_VERSION = "training.file_version"
-LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
-LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"
-LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"
+LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"
+LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"
+LLM_KV_TRAINING_TYPE = "training.type"
+LLM_KV_TRAINING_FILE_VERSION = "training.file_version"
+LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
+LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"
+LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"
class Tensor:
def __init__(self, dtype='f', ne=None):
@@ -460,6 +463,7 @@ class Checkpoint:
gguf_writer.add_file_type(gguf.GGMLQuantizationType.F32)
gguf_writer.add_layer_norm_rms_eps(1e-5)
gguf_writer.add_uint32(LLM_KV_TRAINING_FILE_VERSION, 0)
+ gguf_writer.add_string(LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL)
gguf_writer.add_uint32(LLM_KV_TRAINING_ITERATION_COUNT, self.train_its)
gguf_writer.add_uint32(LLM_KV_TRAINING_SAMPLE_COUNT, self.train_samples)
gguf_writer.add_uint32(LLM_KV_TRAINING_TOKEN_COUNT, self.train_tokens)