@@ -1,6 +1,8 @@
 import pytest
 import requests
 import time
+import random
+
 from openai import OpenAI
 from utils import *
 
@@ -564,3 +566,46 @@ def test_cancel_request():
     time.sleep(1) # wait for HTTP_POLLING_SECONDS
     res = server.make_request("GET", "/slots")
     assert res.body[0]["is_processing"] == False
+
+
+# this test exercises the host-memory prompt cache
+# ref: https://github.com/ggml-org/llama.cpp/pull/16391
+# ref: https://github.com/ggml-org/llama.cpp/pull/17078
+def test_completion_prompt_cache():
+    global server
+    server.n_slots = 2
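+    # kv_unified is assumed to map to the --kv-unified flag: a single KV cache buffer shared by all slots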
+    server.kv_unified = True
+    server.start()
+
+    for _ in range(16):
+        # generate alternating random prompts with variable lengths in order to get them in and out of the cache
+        r = random.randint(0, 4)
+        prompt = (" Hello " + str(r)) * (40 + r)
+        n_prompt = (40 + r) * 5 + 2
+        n_predict = random.randint(1, 8)
+
+        res = server.make_request(
+            "POST",
+            "/completion",
+            data={
+                "prompt": prompt,
+                "n_predict": n_predict,
+            },
+        )
+
+        assert res.status_code == 200
+        assert "content" in res.body
+        content = res.body["content"]
+        assert isinstance(content, str)
+        assert len(content) > 0
+
+        assert type(res.body["has_new_line"]) == bool
+        assert "timings" in res.body
+        timings = res.body["timings"]
+
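+        # every prompt token must be either freshly processed (prompt_n) or reused from the cache (cache_n)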
+        assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
+        assert "predicted_n" in timings and timings["predicted_n"] == n_predict
+        assert "tokens" in res.body and isinstance(res.body["tokens"], list)