Просмотр исходного кода

server : add flag to disable the web-ui (#10762) (#10751)

Co-authored-by: eugenio.segala <esegala@deloitte.co.uk>
Yüg 1 год назад
Родитель
Сommit
a86ad841f1

+ 7 - 0
common/arg.cpp

@@ -1711,6 +1711,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.public_path = value;
             params.public_path = value;
         }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
+    add_opt(common_arg(
+        {"--no-webui"},
+        string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.webui = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_WEBUI"));
     add_opt(common_arg(
     add_opt(common_arg(
         {"--embedding", "--embeddings"},
         {"--embedding", "--embeddings"},
         string_format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
         string_format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),

+ 1 - 0
examples/server/README.md

@@ -146,6 +146,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--host HOST` | ip address to listen (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
 | `--host HOST` | ip address to listen (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
 | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
 | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
+| `--no-webui` | disable the Web UI<br/>(env: LLAMA_ARG_NO_WEBUI) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
 | `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |

+ 17 - 13
examples/server/server.cpp

@@ -3815,20 +3815,24 @@ int main(int argc, char ** argv) {
     // Router
     // Router
     //
     //
 
 
-    // register static assets routes
-    if (!params.public_path.empty()) {
-        // Set the base directory for serving static files
-        bool is_found = svr->set_mount_point("/", params.public_path);
-        if (!is_found) {
-            LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
-            return 1;
-        }
+    if (!params.webui) {
+        LOG_INF("Web UI is disabled\n");
     } else {
     } else {
-        // using embedded static index.html
-        svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char*>(index_html), index_html_len, "text/html; charset=utf-8");
-            return false;
-        });
+        // register static assets routes
+        if (!params.public_path.empty()) {
+            // Set the base directory for serving static files
+            bool is_found = svr->set_mount_point("/", params.public_path);
+            if (!is_found) {
+                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+                return 1;
+            }
+        } else {
+            // using embedded static index.html
+            svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
+                res.set_content(reinterpret_cast<const char*>(index_html), index_html_len, "text/html; charset=utf-8");
+                return false;
+            });
+        }
     }
     }
 
 
     // register API routes
     // register API routes

+ 18 - 0
examples/server/tests/unit/test_basic.py

@@ -1,4 +1,5 @@
 import pytest
 import pytest
+import requests
 from utils import *
 from utils import *
 
 
 server = ServerPreset.tinyllama2()
 server = ServerPreset.tinyllama2()
@@ -76,3 +77,20 @@ def test_load_split_model():
     })
     })
     assert res.status_code == 200
     assert res.status_code == 200
     assert match_regex("(little|girl)+", res.body["content"])
     assert match_regex("(little|girl)+", res.body["content"])
+
+
+def test_no_webui():
+    global server
+    # default: webui enabled
+    server.start()
+    url = f"http://{server.server_host}:{server.server_port}"
+    res = requests.get(url)
+    assert res.status_code == 200
+    assert "<html>" in res.text
+    server.stop()
+
+    # with --no-webui
+    server.no_webui = True
+    server.start()
+    res = requests.get(url)
+    assert res.status_code == 404

+ 3 - 0
examples/server/tests/utils.py

@@ -72,6 +72,7 @@ class ServerProcess:
     disable_ctx_shift: int | None = False
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_min: int | None = None
     draft_max: int | None = None
     draft_max: int | None = None
+    no_webui: bool | None = None
 
 
     # session variables
     # session variables
     process: subprocess.Popen | None = None
     process: subprocess.Popen | None = None
@@ -158,6 +159,8 @@ class ServerProcess:
             server_args.extend(["--draft-max", self.draft_max])
             server_args.extend(["--draft-max", self.draft_max])
         if self.draft_min:
         if self.draft_min:
             server_args.extend(["--draft-min", self.draft_min])
             server_args.extend(["--draft-min", self.draft_min])
+        if self.no_webui:
+            server_args.append("--no-webui")
 
 
         args = [str(arg) for arg in [server_path, *server_args]]
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
         print(f"bench: starting server with: {' '.join(args)}")