Przeglądaj źródła

llama: introduce support for model-embedded sampling parameters (#17120)

Aaron Teo 1 miesiąc temu
rodzic
commit
877566d512

+ 12 - 0
common/arg.cpp

@@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             const auto sampler_names = string_split<std::string>(value, ';');
             params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1261,6 +1262,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.sampling.temp = std::stof(value);
             params.sampling.temp = std::max(params.sampling.temp, 0.0f);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1268,6 +1270,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
         [](common_params & params, int value) {
             params.sampling.top_k = value;
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1275,6 +1278,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
         [](common_params & params, const std::string & value) {
             params.sampling.top_p = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1282,6 +1286,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
         [](common_params & params, const std::string & value) {
             params.sampling.min_p = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1296,6 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
         [](common_params & params, const std::string & value) {
             params.sampling.xtc_probability = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1303,6 +1309,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
         [](common_params & params, const std::string & value) {
             params.sampling.xtc_threshold = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1321,6 +1328,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
             params.sampling.penalty_last_n = value;
             params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1328,6 +1336,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
         [](common_params & params, const std::string & value) {
             params.sampling.penalty_repeat = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1425,6 +1434,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
         [](common_params & params, int value) {
             params.sampling.mirostat = value;
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1432,6 +1442,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
         [](common_params & params, const std::string & value) {
             params.sampling.mirostat_eta = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1439,6 +1450,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
         [](common_params & params, const std::string & value) {
             params.sampling.mirostat_tau = std::stof(value);
+            params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
         }
     ).set_sparam());
     add_opt(common_arg(

+ 55 - 0
common/common.cpp

@@ -8,6 +8,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "sampling.h"
 
 #include <algorithm>
 #include <cinttypes>
@@ -949,6 +950,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
 // Model utils
 //
 
+// Seed sparams with sampling defaults embedded in the model's GGUF metadata
+// ("general.sampling.*" keys). A value is taken from the model only when the
+// user did not set that parameter explicitly on the command line, as tracked
+// by the bit flags in sparams.user_sampling_config.
+static inline void common_init_sampler_from_model(
+    const llama_model * model,
+    common_params_sampling & sparams) {
+
+    const uint64_t config = sparams.user_sampling_config;
+
+    // Parse an integer metadata value for `key` into `dst`, unless the user
+    // already set it (bit `user_config` is set in `config`).
+    // NOTE(review): strtol returns long and is narrowed to int32_t without a
+    // range/errno check — out-of-range metadata would silently wrap; confirm
+    // whether that is acceptable for these keys.
+    auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
+        if (config & user_config) return;
+
+        char buf[64] = {0};
+        if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
+            char * end = nullptr;
+            int32_t v = strtol(buf, &end, 10);
+            // only accept the value if at least one character was consumed
+            if (end && end != buf) dst = v;
+        }
+    };
+
+    // Same as get_int32, but for float-valued metadata keys.
+    auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
+        if (config & user_config) return;
+
+        char buf[128] = {0};
+        if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
+            char * end = nullptr;
+            float v = strtof(buf, &end);
+            // only accept the value if at least one character was consumed
+            if (end && end != buf) dst = v;
+        }
+    };
+
+    // Sampling sequence
+    // The sequence is stored as a ';'-separated list of sampler names, mirroring
+    // the CLI "--samplers" argument format.
+    if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
+        char buf[512] = {0};
+        if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
+            const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
+            if (!sampler_names.empty()) {
+                sparams.samplers = common_sampler_types_from_names(sampler_names, true);
+            }
+        }
+    }
+
+    get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K),           sparams.top_k,           common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P),           sparams.top_p,           common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P),           sparams.min_p,           common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD),   sparams.xtc_threshold,   common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP),            sparams.temp,            common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
+    get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N),  sparams.penalty_last_n,  common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT),  sparams.penalty_repeat,  common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
+    get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT),        sparams.mirostat,        common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU),    sparams.mirostat_tau,    common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA),    sparams.mirostat_eta,    common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
+}
+
 struct common_init_result common_init_from_params(common_params & params) {
     common_init_result iparams;
     auto mparams = common_model_params_to_llama(params);
@@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
+    common_init_sampler_from_model(model, params.sampling);
+
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     auto cparams = common_context_params_to_llama(params);

+ 18 - 0
common/common.h

@@ -140,6 +140,22 @@ struct common_grammar_trigger {
     llama_token token = LLAMA_TOKEN_NULL;
 };
 
+// Bit flags recording which sampling parameters the user set explicitly on the
+// command line. common_init_sampler_from_model() skips any parameter whose bit
+// is set, so CLI arguments always win over model-embedded metadata.
+// NOTE(review): the underlying type is uint64_t but the shifts use plain int
+// `1` — switch to `1ull << n` before adding flags beyond bit 30.
+enum common_params_sampling_config : uint64_t {
+    COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS        = 1 << 0,
+    COMMON_PARAMS_SAMPLING_CONFIG_TOP_K           = 1 << 1,
+    COMMON_PARAMS_SAMPLING_CONFIG_TOP_P           = 1 << 2,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIN_P           = 1 << 3,
+    COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
+    COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD   = 1 << 5,
+    COMMON_PARAMS_SAMPLING_CONFIG_TEMP            = 1 << 6,
+    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N  = 1 << 7,
+    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT  = 1 << 8,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT        = 1 << 9,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU    = 1 << 10,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA    = 1 << 11,
+};
+
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -172,6 +188,8 @@ struct common_params_sampling {
     bool    no_perf            = false; // disable performance metrics
     bool    timing_per_token   = false;
 
+    uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
+
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"};     // default sequence breakers for DRY
 
 

+ 14 - 0
gguf-py/gguf/constants.py

@@ -25,6 +25,20 @@ class Keys:
         ALIGNMENT                  = "general.alignment"
         FILE_TYPE                  = "general.file_type"
 
+        # Recommended Sampler Parameters
+        SAMPLING_SEQUENCE           = "general.sampling.sequence"
+        SAMPLING_TOP_K              = "general.sampling.top_k"
+        SAMPLING_TOP_P              = "general.sampling.top_p"
+        SAMPLING_MIN_P              = "general.sampling.min_p"
+        SAMPLING_XTC_PROBABILITY    = "general.sampling.xtc_probability"
+        SAMPLING_XTC_THRESHOLD      = "general.sampling.xtc_threshold"
+        SAMPLING_TEMP               = "general.sampling.temp"
+        SAMPLING_PENALTY_LAST_N     = "general.sampling.penalty_last_n"
+        SAMPLING_PENALTY_REPEAT     = "general.sampling.penalty_repeat"
+        SAMPLING_MIROSTAT           = "general.sampling.mirostat"
+        SAMPLING_MIROSTAT_TAU       = "general.sampling.mirostat_tau"
+        SAMPLING_MIROSTAT_ETA       = "general.sampling.mirostat_eta"
+
         # Authorship Metadata
         NAME                       = "general.name"
         AUTHOR                     = "general.author"

+ 36 - 0
gguf-py/gguf/gguf_writer.py

@@ -496,6 +496,42 @@ class GGUFWriter:
     def add_file_type(self, ftype: int) -> None:
         self.add_uint32(Keys.General.FILE_TYPE, ftype)
 
+    # Recommended sampler parameters, written under the "general.sampling.*"
+    # GGUF keys (Keys.General.SAMPLING_* in constants.py). `sequence` is a
+    # ';'-separated list of sampler names.
+    def add_sampling_sequence(self, sequence: str) -> None:
+        self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
+
+    def add_sampling_top_k(self, top_k: int) -> None:
+        self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)
+
+    def add_sampling_top_p(self, top_p: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)
+
+    def add_sampling_min_p(self, min_p: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)
+
+    def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
+
+    def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
+
+    def add_sampling_temp(self, temp: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_TEMP, temp)
+
+    def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
+        self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
+
+    def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
+
+    def add_sampling_mirostat(self, mirostat: int) -> None:
+        self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)
+
+    def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
+
+    def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
+        self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
+
     def add_name(self, name: str) -> None:
         self.add_string(Keys.General.NAME, name)
 

+ 85 - 0
gguf-py/gguf/metadata.py

@@ -17,6 +17,20 @@ logger = logging.getLogger("metadata")
 
 @dataclass
 class Metadata:
+    # Recommended Sampler Parameters to be written to GGUF KV Store
+    sampling_sequence: Optional[str] = None
+    sampling_top_k: Optional[int] = None
+    sampling_top_p: Optional[float] = None
+    sampling_min_p: Optional[float] = None
+    sampling_xtc_probability: Optional[float] = None
+    sampling_xtc_threshold: Optional[float] = None
+    sampling_temp: Optional[float] = None
+    sampling_penalty_last_n: Optional[int] = None
+    sampling_penalty_repeat: Optional[float] = None
+    sampling_mirostat: Optional[int] = None
+    sampling_mirostat_tau: Optional[float] = None
+    sampling_mirostat_eta: Optional[float] = None
+
     # Authorship Metadata to be written to GGUF KV Store
     name: Optional[str] = None
     author: Optional[str] = None
@@ -54,15 +68,43 @@ class Metadata:
 
         model_card = Metadata.load_model_card(model_path)
         hf_params = Metadata.load_hf_parameters(model_path)
+        gen_config = Metadata.load_generation_config(model_path)
         # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
 
         # heuristics
         metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
 
+        if gen_config:
+            metadata.sampling_sequence        = gen_config.get("sequence",        metadata.sampling_sequence)
+            metadata.sampling_top_k           = gen_config.get("top_k",           metadata.sampling_top_k)
+            metadata.sampling_top_p           = gen_config.get("top_p",           metadata.sampling_top_p)
+            metadata.sampling_min_p           = gen_config.get("min_p",           metadata.sampling_min_p)
+            metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
+            metadata.sampling_xtc_threshold   = gen_config.get("xtc_threshold",   metadata.sampling_xtc_threshold)
+            metadata.sampling_temp            = gen_config.get("temperature",     metadata.sampling_temp)
+            metadata.sampling_penalty_last_n  = gen_config.get("penalty_last_n",  metadata.sampling_penalty_last_n)
+            metadata.sampling_penalty_repeat  = gen_config.get("penalty_repeat",  metadata.sampling_penalty_repeat)
+            metadata.sampling_mirostat        = gen_config.get("mirostat",        metadata.sampling_mirostat)
+            metadata.sampling_mirostat_tau    = gen_config.get("mirostat_tau",    metadata.sampling_mirostat_tau)
+            metadata.sampling_mirostat_eta    = gen_config.get("mirostat_eta",    metadata.sampling_mirostat_eta)
+
         # Metadata Override File Provided
         # This is based on LLM_KV_NAMES mapping in llama.cpp
         metadata_override = Metadata.load_metadata_override(metadata_override_path)
 
+        metadata.sampling_sequence        = metadata_override.get(Keys.General.SAMPLING_SEQUENCE,        metadata.sampling_sequence)
+        metadata.sampling_top_k           = metadata_override.get(Keys.General.SAMPLING_TOP_K,           metadata.sampling_top_k)
+        metadata.sampling_top_p           = metadata_override.get(Keys.General.SAMPLING_TOP_P,           metadata.sampling_top_p)
+        metadata.sampling_min_p           = metadata_override.get(Keys.General.SAMPLING_MIN_P,           metadata.sampling_min_p)
+        metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
+        metadata.sampling_xtc_threshold   = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD,   metadata.sampling_xtc_threshold)
+        metadata.sampling_temp            = metadata_override.get(Keys.General.SAMPLING_TEMP,            metadata.sampling_temp)
+        metadata.sampling_penalty_last_n  = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N,  metadata.sampling_penalty_last_n)
+        metadata.sampling_penalty_repeat  = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT,  metadata.sampling_penalty_repeat)
+        metadata.sampling_mirostat        = metadata_override.get(Keys.General.SAMPLING_MIROSTAT,        metadata.sampling_mirostat)
+        metadata.sampling_mirostat_tau    = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU,    metadata.sampling_mirostat_tau)
+        metadata.sampling_mirostat_eta    = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA,    metadata.sampling_mirostat_eta)
+
         metadata.name            = metadata_override.get(Keys.General.NAME,            metadata.name)
         metadata.author          = metadata_override.get(Keys.General.AUTHOR,          metadata.author)
         metadata.version         = metadata_override.get(Keys.General.VERSION,         metadata.version)
@@ -172,6 +214,23 @@ class Metadata:
         with open(config_path, "r", encoding="utf-8") as f:
             return json.load(f)
 
+    @staticmethod
+    def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
+        """Load <model_path>/generation_config.json as a dict.
+
+        Returns {} when model_path is missing/not a directory, when the file
+        does not exist, or when it cannot be read or parsed as JSON.
+        """
+        if model_path is None or not model_path.is_dir():
+            return {}
+
+        generation_config_path = model_path / "generation_config.json"
+
+        if not generation_config_path.is_file():
+            return {}
+
+        try:
+            with open(generation_config_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        except (json.JSONDecodeError, IOError):
+            # not all models have valid generation_config.json
+            return {}
+
     @staticmethod
     def id_to_title(string):
         # Convert capitalization into title form unless acronym or version number
@@ -546,6 +605,32 @@ class Metadata:
 
     def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
         assert self.name is not None
+
+        if self.sampling_sequence is not None:
+            gguf_writer.add_sampling_sequence(self.sampling_sequence)
+        if self.sampling_top_k is not None:
+            gguf_writer.add_sampling_top_k(self.sampling_top_k)
+        if self.sampling_top_p is not None:
+            gguf_writer.add_sampling_top_p(self.sampling_top_p)
+        if self.sampling_min_p is not None:
+            gguf_writer.add_sampling_min_p(self.sampling_min_p)
+        if self.sampling_xtc_probability is not None:
+            gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
+        if self.sampling_xtc_threshold is not None:
+            gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
+        if self.sampling_temp is not None:
+            gguf_writer.add_sampling_temp(self.sampling_temp)
+        if self.sampling_penalty_last_n is not None:
+            gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
+        if self.sampling_penalty_repeat is not None:
+            gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
+        if self.sampling_mirostat is not None:
+            gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
+        if self.sampling_mirostat_tau is not None:
+            gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
+        if self.sampling_mirostat_eta is not None:
+            gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
+
         gguf_writer.add_name(self.name)
 
         if self.author is not None:

+ 18 - 0
include/llama.h

@@ -246,6 +246,21 @@ extern "C" {
         LLAMA_KV_OVERRIDE_TYPE_STR,
     };
 
+    // Keys for sampling parameters that may be embedded in a model's GGUF
+    // metadata ("general.sampling.*"). Convert to the string key name with
+    // llama_model_meta_key_str(), then read via llama_model_meta_val_str().
+    enum llama_model_meta_key {
+        LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
+        LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
+        LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
+        LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
+        LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
+        LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
+        LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
+        LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
+    };
+
     struct llama_model_kv_override {
         enum llama_model_kv_override_type tag;
 
@@ -518,6 +533,9 @@ extern "C" {
     // Get the number of metadata key/value pairs
     LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
 
+    // Get sampling metadata key name. Returns nullptr if the key is invalid
+    LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
+
     // Get metadata key name by index
     LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
 

+ 25 - 13
src/llama-arch.cpp

@@ -114,19 +114,31 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 };
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
-    { LLM_KV_GENERAL_TYPE,                 "general.type"                          },
-    { LLM_KV_GENERAL_ARCHITECTURE,         "general.architecture"                  },
-    { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version"          },
-    { LLM_KV_GENERAL_ALIGNMENT,            "general.alignment"                     },
-    { LLM_KV_GENERAL_FILE_TYPE,            "general.file_type"                     },
-    { LLM_KV_GENERAL_NAME,                 "general.name"                          },
-    { LLM_KV_GENERAL_AUTHOR,               "general.author"                        },
-    { LLM_KV_GENERAL_VERSION,              "general.version"                       },
-    { LLM_KV_GENERAL_URL,                  "general.url"                           },
-    { LLM_KV_GENERAL_DESCRIPTION,          "general.description"                   },
-    { LLM_KV_GENERAL_LICENSE,              "general.license"                       },
-    { LLM_KV_GENERAL_SOURCE_URL,           "general.source.url"                    },
-    { LLM_KV_GENERAL_SOURCE_HF_REPO,       "general.source.huggingface.repository" },
+    { LLM_KV_GENERAL_TYPE,                     "general.type"                          },
+    { LLM_KV_GENERAL_ARCHITECTURE,             "general.architecture"                  },
+    { LLM_KV_GENERAL_QUANTIZATION_VERSION,     "general.quantization_version"          },
+    { LLM_KV_GENERAL_ALIGNMENT,                "general.alignment"                     },
+    { LLM_KV_GENERAL_FILE_TYPE,                "general.file_type"                     },
+    { LLM_KV_GENERAL_SAMPLING_SEQUENCE,        "general.sampling.sequence"             },
+    { LLM_KV_GENERAL_SAMPLING_TOP_K,           "general.sampling.top_k"                },
+    { LLM_KV_GENERAL_SAMPLING_TOP_P,           "general.sampling.top_p"                },
+    { LLM_KV_GENERAL_SAMPLING_MIN_P,           "general.sampling.min_p"                },
+    { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability"      },
+    { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,   "general.sampling.xtc_threshold"        },
+    { LLM_KV_GENERAL_SAMPLING_TEMP,            "general.sampling.temp"                 },
+    { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,  "general.sampling.penalty_last_n"       },
+    { LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,  "general.sampling.penalty_repeat"       },
+    { LLM_KV_GENERAL_SAMPLING_MIROSTAT,        "general.sampling.mirostat"             },
+    { LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,    "general.sampling.mirostat_tau"         },
+    { LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,    "general.sampling.mirostat_eta"         },
+    { LLM_KV_GENERAL_NAME,                     "general.name"                          },
+    { LLM_KV_GENERAL_AUTHOR,                   "general.author"                        },
+    { LLM_KV_GENERAL_VERSION,                  "general.version"                       },
+    { LLM_KV_GENERAL_URL,                      "general.url"                           },
+    { LLM_KV_GENERAL_DESCRIPTION,              "general.description"                   },
+    { LLM_KV_GENERAL_LICENSE,                  "general.license"                       },
+    { LLM_KV_GENERAL_SOURCE_URL,               "general.source.url"                    },
+    { LLM_KV_GENERAL_SOURCE_HF_REPO,           "general.source.huggingface.repository" },
 
     { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                        },
     { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"                    },

+ 12 - 0
src/llama-arch.h

@@ -123,6 +123,18 @@ enum llm_kv {
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
     LLM_KV_GENERAL_FILE_TYPE,
+    LLM_KV_GENERAL_SAMPLING_SEQUENCE,
+    LLM_KV_GENERAL_SAMPLING_TOP_K,
+    LLM_KV_GENERAL_SAMPLING_TOP_P,
+    LLM_KV_GENERAL_SAMPLING_MIN_P,
+    LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
+    LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
+    LLM_KV_GENERAL_SAMPLING_TEMP,
+    LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
+    LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
+    LLM_KV_GENERAL_SAMPLING_MIROSTAT,
+    LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
+    LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
     LLM_KV_GENERAL_NAME,
     LLM_KV_GENERAL_AUTHOR,
     LLM_KV_GENERAL_VERSION,

+ 18 - 0
src/llama-model.cpp

@@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) {
     return (int)model->gguf_kv.size();
 }
 
+// Map a llama_model_meta_key enum value to its GGUF metadata key string
+// ("general.sampling.*"); returns nullptr for an unrecognized key.
+// NOTE(review): these literals duplicate the LLM_KV_GENERAL_SAMPLING_* entries
+// in src/llama-arch.cpp — keep the two lists in sync.
+const char * llama_model_meta_key_str(llama_model_meta_key key) {
+    switch (key) {
+        case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE:        return "general.sampling.sequence";
+        case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K:           return "general.sampling.top_k";
+        case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P:           return "general.sampling.top_p";
+        case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P:           return "general.sampling.min_p";
+        case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability";
+        case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD:   return "general.sampling.xtc_threshold";
+        case LLAMA_MODEL_META_KEY_SAMPLING_TEMP:            return "general.sampling.temp";
+        case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N:  return "general.sampling.penalty_last_n";
+        case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT:  return "general.sampling.penalty_repeat";
+        case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT:        return "general.sampling.mirostat";
+        case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU:    return "general.sampling.mirostat_tau";
+        case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA:    return "general.sampling.mirostat_eta";
+        default:                                            return nullptr;
+    }
+}
+
 int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
     if (i < 0 || i >= (int)model->gguf_kv.size()) {
         if (buf_size > 0) {