|
@@ -11,7 +11,7 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu
|
|
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
|
|
""".strip()
|
|
""".strip()
|
|
|
|
|
|
|
|
-@pytest.fixture(scope="module", autouse=True)
|
|
|
|
|
|
|
+@pytest.fixture(autouse=True)
|
|
|
def create_server():
|
|
def create_server():
|
|
|
global server
|
|
global server
|
|
|
server = ServerPreset.tinyllama2()
|
|
server = ServerPreset.tinyllama2()
|
|
@@ -25,6 +25,7 @@ def test_ctx_shift_enabled():
|
|
|
# the prompt is truncated to keep the last 109 tokens
|
|
# the prompt is truncated to keep the last 109 tokens
|
|
|
# 64 tokens are generated thanks to shifting the context when it gets full
|
|
# 64 tokens are generated thanks to shifting the context when it gets full
|
|
|
global server
|
|
global server
|
|
|
|
|
+ server.enable_ctx_shift = True
|
|
|
server.start()
|
|
server.start()
|
|
|
res = server.make_request("POST", "/completion", data={
|
|
res = server.make_request("POST", "/completion", data={
|
|
|
"n_predict": 64,
|
|
"n_predict": 64,
|
|
@@ -42,7 +43,6 @@ def test_ctx_shift_enabled():
|
|
|
])
|
|
])
|
|
|
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
|
|
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
|
|
|
global server
|
|
global server
|
|
|
- server.disable_ctx_shift = True
|
|
|
|
|
server.n_predict = -1
|
|
server.n_predict = -1
|
|
|
server.start()
|
|
server.start()
|
|
|
res = server.make_request("POST", "/completion", data={
|
|
res = server.make_request("POST", "/completion", data={
|
|
@@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
|
|
|
|
|
|
|
|
def test_ctx_shift_disabled_long_prompt():
|
|
def test_ctx_shift_disabled_long_prompt():
|
|
|
global server
|
|
global server
|
|
|
- server.disable_ctx_shift = True
|
|
|
|
|
server.start()
|
|
server.start()
|
|
|
res = server.make_request("POST", "/completion", data={
|
|
res = server.make_request("POST", "/completion", data={
|
|
|
"n_predict": 64,
|
|
"n_predict": 64,
|
|
@@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt():
|
|
|
|
|
|
|
|
def test_ctx_shift_disabled_stream():
|
|
def test_ctx_shift_disabled_stream():
|
|
|
global server
|
|
global server
|
|
|
- server.disable_ctx_shift = True
|
|
|
|
|
server.start()
|
|
server.start()
|
|
|
res = server.make_stream_request("POST", "/v1/completions", data={
|
|
res = server.make_stream_request("POST", "/v1/completions", data={
|
|
|
"n_predict": 256,
|
|
"n_predict": 256,
|