Explorar el Código

server : disable context shift by default (#15416)

* server : disable context shift by default

ggml-ci

* server : make scopr of test parameters local
Georgi Gerganov hace 5 meses
padre
commit
d2fcd91cf9

+ 7 - 0
common/arg.cpp

@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ctx_shift = false;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    add_opt(common_arg(
+        {"--context-shift"},
+        string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](common_params & params) {
+            params.ctx_shift = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),

+ 1 - 1
common/common.h

@@ -375,7 +375,7 @@ struct common_params {
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn        = false; // flash attention
     bool no_perf           = false; // disable performance metrics
-    bool ctx_shift         = true;  // context shift on inifinite text generation
+    bool ctx_shift         = false;  // context shift on inifinite text generation
     bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified        = false; // enable unified KV cache
 

+ 1 - 1
tools/server/tests/unit/test_basic.py

@@ -5,7 +5,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

+ 2 - 2
tools/server/tests/unit/test_completion.py

@@ -7,7 +7,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
         "temperature": 1.0,
         "cache_prompt": False,
     })
-    assert res.status_code == 200
+    assert res.status_code == 400
 
 
 def test_completion_with_tokens_input():

+ 2 - 4
tools/server/tests/unit/test_ctx_shift.py

@@ -11,7 +11,7 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu
 Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
 """.strip()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -25,6 +25,7 @@ def test_ctx_shift_enabled():
     # the prompt is truncated to keep the last 109 tokens
     # 64 tokens are generated thanks to shifting the context when it gets full
     global server
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -42,7 +43,6 @@ def test_ctx_shift_enabled():
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
-    server.disable_ctx_shift = True
     server.n_predict = -1
     server.start()
     res = server.make_request("POST", "/completion", data={
@@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
 
 def test_ctx_shift_disabled_long_prompt():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt():
 
 def test_ctx_shift_disabled_stream():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_stream_request("POST", "/v1/completions", data={
         "n_predict": 256,

+ 1 - 1
tools/server/tests/unit/test_embedding.py

@@ -8,7 +8,7 @@ server = ServerPreset.bert_bge_small()
 
 EPSILON = 1e-3
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.bert_bge_small()

+ 1 - 1
tools/server/tests/unit/test_infill.py

@@ -3,7 +3,7 @@ from utils import *
 
 server = ServerPreset.tinyllama_infill()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama_infill()

+ 1 - 1
tools/server/tests/unit/test_lora.py

@@ -5,7 +5,7 @@ server = ServerPreset.stories15m_moe()
 
 LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()

+ 1 - 1
tools/server/tests/unit/test_rerank.py

@@ -4,7 +4,7 @@ from utils import *
 server = ServerPreset.jina_reranker_tiny()
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.jina_reranker_tiny()

+ 1 - 1
tools/server/tests/unit/test_security.py

@@ -6,7 +6,7 @@ server = ServerPreset.tinyllama2()
 
 TEST_API_KEY = "sk-this-is-the-secret-key"
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

+ 1 - 1
tools/server/tests/unit/test_slot_save.py

@@ -3,7 +3,7 @@ from utils import *
 
 server = ServerPreset.tinyllama2()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

+ 2 - 1
tools/server/tests/unit/test_speculative.py

@@ -16,7 +16,7 @@ def create_server():
     server.draft_max = 8
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def fixture_create_server():
     return create_server()
 
@@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded():
 def test_with_ctx_shift():
     global server
     server.n_ctx = 64
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "Hello " * 56,

+ 1 - 1
tools/server/tests/unit/test_tokenize.py

@@ -4,7 +4,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

+ 2 - 0
tools/server/tests/unit/test_tool_call.py

@@ -22,6 +22,8 @@ def create_server():
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
     server.n_slots = 1
+    server.n_ctx = 8192
+    server.n_batch = 2048
 
 class CompletionMode(Enum):
     NORMAL = "normal"

+ 3 - 3
tools/server/tests/utils.py

@@ -79,7 +79,7 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     lora_files: List[str] | None = None
-    disable_ctx_shift: int | None = False
+    enable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
@@ -178,8 +178,8 @@ class ServerProcess:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.disable_ctx_shift:
-            server_args.extend(["--no-context-shift"])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
         if self.draft_max:

+ 0 - 1
tools/tts/tts.cpp

@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {
 
     params.model = params.vocoder.model;
     params.embedding = true;
-    params.ctx_shift = false; // silence warning
     params.n_ubatch = params.n_batch;
 
     common_init_result llama_init_cts = common_init_from_params(params);