Browse Source

llama: fix integer type consistency in split helpers (#18894)

* llama: fix integer type consistency in split helpers

* llama: apply minor style fixes

* llama: remove trailing whitespace
Jakkala Mahesh 5 days ago
parent
commit
24bc238303
2 changed files with 46 additions and 16 deletions
  1. 2 2
      include/llama.h
  2. 44 14
      src/llama.cpp

+ 2 - 2
include/llama.h

@@ -1476,12 +1476,12 @@ extern "C" {
     /// @details Build a split GGUF final path for this chunk.
     /// @details Build a split GGUF final path for this chunk.
     ///          llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
     ///          llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
     //  Returns the split_path length.
     //  Returns the split_path length.
-    LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count);
+    LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count);
 
 
     /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
     /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
     ///          llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
     ///          llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
     //  Returns the split_prefix length.
     //  Returns the split_prefix length.
-    LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
+    LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count);
 
 
     // Print system information
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);
     LLAMA_API const char * llama_print_system_info(void);

+ 44 - 14
src/llama.cpp

@@ -1095,25 +1095,55 @@ int32_t llama_chat_apply_template(
 // model split
 // model split
 //
 //
 
 
-int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+int32_t llama_split_path(
+    char * split_path,
+    size_t maxlen,
+    const char * path_prefix,
+    int32_t split_no,
+    int32_t split_count) {
+
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
-    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
-        return strlen(split_path);
+
+    const int written = snprintf(
+        split_path,
+        maxlen,
+        SPLIT_PATH_FORMAT,
+        path_prefix,
+        split_no + 1,
+        split_count
+    );
+
+    if (written < 0 || (size_t) written >= maxlen) {
+        return 0;
     }
     }
-    return 0;
+
+    return (int32_t) written;
 }
 }
 
 
-int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
-    std::string str_split_path(split_path);
+int32_t llama_split_prefix(
+    char * split_prefix,
+    size_t maxlen,
+    const char * split_path,
+    int32_t split_no,
+    int32_t split_count) {
+
+    const std::string str_split_path(split_path);
+
     char postfix[32];
     char postfix[32];
-    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
-    std::string str_postfix(postfix);
-
-    // check if split_prefix ends with postfix
-    int size_prefix = str_split_path.size() - str_postfix.size();
-    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
-        return size_prefix;
+    snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count);
+
+    const std::string str_postfix(postfix);
+    if (str_split_path.size() <= str_postfix.size()) {
+        return 0;
+    }
+
+    const size_t size_prefix = str_split_path.size() - str_postfix.size();
+
+    if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) {
+        const size_t copy_len = std::min(size_prefix + 1, maxlen);
+        snprintf(split_prefix, copy_len, "%s", split_path);
+
+        return (int32_t) size_prefix;
     }
     }
 
 
     return 0;
     return 0;