Ver Fonte

server: add --media-path for local media files (#17697)

* server: add --media-path for local media files

* remove unused fn
Xuan-Son Nguyen há 1 mês atrás
pai
commit
13628d8bdb

+ 17 - 0
common/arg.cpp

@@ -2488,12 +2488,29 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "path to save slot kv cache (default: disabled)",
         [](common_params & params, const std::string & value) {
             params.slot_save_path = value;
+            if (!fs_is_directory(params.slot_save_path)) {
+                throw std::invalid_argument("not a directory: " + value);
+            }
             // if doesn't end with DIRECTORY_SEPARATOR, add it
             if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
                 params.slot_save_path += DIRECTORY_SEPARATOR;
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--media-path"}, "PATH",
+        "directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)",
+        [](common_params & params, const std::string & value) {
+            params.media_path = value;
+            if (!fs_is_directory(params.media_path)) {
+                throw std::invalid_argument("not a directory: " + value);
+            }
+            // if doesn't end with DIRECTORY_SEPARATOR, add it
+            if (!params.media_path.empty() && params.media_path[params.media_path.size() - 1] != DIRECTORY_SEPARATOR) {
+                params.media_path += DIRECTORY_SEPARATOR;
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--models-dir"}, "PATH",
         "directory containing models for the router server (default: disabled)",

+ 11 - 2
common/common.cpp

@@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
 
 // Validate if a filename is safe to use
 // To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
-bool fs_validate_filename(const std::string & filename) {
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
     if (!filename.length()) {
         // Empty filename invalid
         return false;
@@ -754,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
             || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
             || c == 0xFFFD // Replacement Character (UTF-8)
             || c == 0xFEFF // Byte Order Mark (BOM)
-            || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+            || c == ':' || c == '*' // Illegal characters
             || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
             return false;
         }
+        if (!allow_subdirs && (c == '/' || c == '\\')) {
+            // Subdirectories not allowed, reject path separators
+            return false;
+        }
     }
 
     // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -859,6 +863,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
 #endif // _WIN32
 }
 
+bool fs_is_directory(const std::string & path) {
+    std::filesystem::path dir(path);
+    return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
+}
+
 std::string fs_get_cache_directory() {
     std::string cache_directory = "";
     auto ensure_trailing_slash = [](std::string p) {

+ 3 - 1
common/common.h

@@ -485,6 +485,7 @@ struct common_params {
     bool log_json = false;
 
     std::string slot_save_path;
+    std::string media_path; // path to directory for loading media files
 
     float slot_prompt_similarity = 0.1f;
 
@@ -635,8 +636,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
 // Filesystem utils
 //
 
-bool fs_validate_filename(const std::string & filename);
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
 bool fs_create_directory_with_parents(const std::string & path);
+bool fs_is_directory(const std::string & path);
 
 std::string fs_get_cache_directory();
 std::string fs_get_cache_file(const std::string & filename);

+ 64 - 35
tools/server/server-common.cpp

@@ -11,6 +11,7 @@
 
 #include <random>
 #include <sstream>
+#include <fstream>
 
 json format_error_response(const std::string & message, const enum error_type type) {
     std::string type_str;
@@ -774,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) {
     return llama_params;
 }
 
+// media_path always end with '/', see arg.cpp
+static void handle_media(
+        std::vector<raw_buffer> & out_files,
+        json & media_obj,
+        const std::string & media_path) {
+    std::string url = json_value(media_obj, "url", std::string());
+    if (string_starts_with(url, "http")) {
+        // download remote image
+        // TODO @ngxson : maybe make these params configurable
+        common_remote_params params;
+        params.headers.push_back("User-Agent: llama.cpp/" + build_info);
+        params.max_size = 1024 * 1024 * 10; // 10MB
+        params.timeout  = 10; // seconds
+        SRV_INF("downloading image from '%s'\n", url.c_str());
+        auto res = common_remote_get_content(url, params);
+        if (200 <= res.first && res.first < 300) {
+            SRV_INF("downloaded %ld bytes\n", res.second.size());
+            raw_buffer data;
+            data.insert(data.end(), res.second.begin(), res.second.end());
+            out_files.push_back(data);
+        } else {
+            throw std::runtime_error("Failed to download image");
+        }
+
+    } else if (string_starts_with(url, "file://")) {
+        if (media_path.empty()) {
+            throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified");
+        }
+        // load local image file
+        std::string file_path = url.substr(7); // remove "file://"
+        raw_buffer data;
+        if (!fs_validate_filename(file_path, true)) {
+            throw std::invalid_argument("file path is not allowed: " + file_path);
+        }
+        SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str());
+        std::ifstream file(media_path + file_path, std::ios::binary);
+        if (!file) {
+            throw std::invalid_argument("file does not exist or cannot be opened: " + file_path);
+        }
+        data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
+        out_files.push_back(data);
+
+    } else {
+        // try to decode base64 image
+        std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
+        if (parts.size() != 2) {
+            throw std::runtime_error("Invalid url value");
+        } else if (!string_starts_with(parts[0], "data:image/")) {
+            throw std::runtime_error("Invalid url format: " + parts[0]);
+        } else if (!string_ends_with(parts[0], "base64")) {
+            throw std::runtime_error("url must be base64 encoded");
+        } else {
+            auto base64_data = parts[1];
+            auto decoded_data = base64_decode(base64_data);
+            out_files.push_back(decoded_data);
+        }
+    }
+}
+
 // used by /chat/completions endpoint
 json oaicompat_chat_params_parse(
     json & body, /* openai api json semantics */
@@ -860,41 +920,8 @@ json oaicompat_chat_params_parse(
                     throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
                 }
 
-                json image_url  = json_value(p, "image_url", json::object());
-                std::string url = json_value(image_url, "url", std::string());
-                if (string_starts_with(url, "http")) {
-                    // download remote image
-                    // TODO @ngxson : maybe make these params configurable
-                    common_remote_params params;
-                    params.headers.push_back("User-Agent: llama.cpp/" + build_info);
-                    params.max_size = 1024 * 1024 * 10; // 10MB
-                    params.timeout  = 10; // seconds
-                    SRV_INF("downloading image from '%s'\n", url.c_str());
-                    auto res = common_remote_get_content(url, params);
-                    if (200 <= res.first && res.first < 300) {
-                        SRV_INF("downloaded %ld bytes\n", res.second.size());
-                        raw_buffer data;
-                        data.insert(data.end(), res.second.begin(), res.second.end());
-                        out_files.push_back(data);
-                    } else {
-                        throw std::runtime_error("Failed to download image");
-                    }
-
-                } else {
-                    // try to decode base64 image
-                    std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
-                    if (parts.size() != 2) {
-                        throw std::invalid_argument("Invalid image_url.url value");
-                    } else if (!string_starts_with(parts[0], "data:image/")) {
-                        throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
-                    } else if (!string_ends_with(parts[0], "base64")) {
-                        throw std::invalid_argument("image_url.url must be base64 encoded");
-                    } else {
-                        auto base64_data = parts[1];
-                        auto decoded_data = base64_decode(base64_data);
-                        out_files.push_back(decoded_data);
-                    }
-                }
+                json image_url = json_value(p, "image_url", json::object());
+                handle_media(out_files, image_url, opt.media_path);
 
                 // replace this chunk with a marker
                 p["type"] = "text";
@@ -916,6 +943,8 @@ json oaicompat_chat_params_parse(
                 auto decoded_data = base64_decode(data); // expected to be base64 encoded
                 out_files.push_back(decoded_data);
 
+                // TODO: add audio_url support by reusing handle_media()
+
                 // replace this chunk with a marker
                 p["type"] = "text";
                 p["text"] = mtmd_default_marker();

+ 1 - 0
tools/server/server-common.h

@@ -284,6 +284,7 @@ struct oaicompat_parser_options {
     bool allow_image;
     bool allow_audio;
     bool enable_thinking = true;
+    std::string media_path;
 };
 
 // used by /chat/completions endpoint

+ 1 - 0
tools/server/server-context.cpp

@@ -788,6 +788,7 @@ struct server_context_impl {
             /* allow_image           */ mctx ? mtmd_support_vision(mctx) : false,
             /* allow_audio           */ mctx ? mtmd_support_audio (mctx) : false,
             /* enable_thinking       */ enable_thinking,
+            /* media_path            */ params_base.media_path,
         };
 
         // print sample chat example to make it clear which template is used

+ 2 - 0
tools/server/server.cpp

@@ -38,9 +38,11 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
         try {
             return func(req);
         } catch (const std::invalid_argument & e) {
+            // treat invalid_argument as invalid request (400)
             error = ERROR_TYPE_INVALID_REQUEST;
             message = e.what();
         } catch (const std::exception & e) {
+            // treat other exceptions as server error (500)
             error = ERROR_TYPE_SERVER;
             message = e.what();
         } catch (...) {

+ 31 - 0
tools/server/tests/unit/test_security.py

@@ -94,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
     assert res.status_code == 200
     assert cors_header in res.headers
     assert res.headers[cors_header] == cors_header_value
+
+
+@pytest.mark.parametrize(
+    "media_path, image_url, success",
+    [
+        (None,             "file://mtmd/test-1.jpeg",    False), # disabled media path, should fail
+        ("../../../tools", "file://mtmd/test-1.jpeg",    True),
+        ("../../../tools", "file:////mtmd//test-1.jpeg", True),  # should be the same file as above
+        ("../../../tools", "file://mtmd/notfound.jpeg",  False), # non-existent file
+        ("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
+    ]
+)
+def test_local_media_file(media_path, image_url, success,):
+    server = ServerPreset.tinygemma3()
+    server.media_path = media_path
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 1,
+        "messages": [
+            {"role": "user", "content": [
+                {"type": "text", "text": "test"},
+                {"type": "image_url", "image_url": {
+                    "url": image_url,
+                }},
+            ]},
+        ],
+    })
+    if success:
+        assert res.status_code == 200
+    else:
+        assert res.status_code == 400

+ 3 - 0
tools/server/tests/utils.py

@@ -95,6 +95,7 @@ class ServerProcess:
     chat_template_file: str | None = None
     server_path: str | None = None
     mmproj_url: str | None = None
+    media_path: str | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -217,6 +218,8 @@ class ServerProcess:
             server_args.extend(["--chat-template-file", self.chat_template_file])
         if self.mmproj_url:
             server_args.extend(["--mmproj-url", self.mmproj_url])
+        if self.media_path:
+            server_args.extend(["--media-path", self.media_path])
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"tests: starting server with: {' '.join(args)}")