# test_ctx_shift.py (3.0 KB) — server context-shift behavior tests
  1. import pytest
  2. from utils import *
  3. server = ServerPreset.tinyllama2()
  4. SHORT_TEXT = """
  5. Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
  6. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
  7. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
  8. """.strip()
  9. LONG_TEXT = """
  10. Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
  11. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
  12. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
  13. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
  14. """.strip()
  15. @pytest.fixture(autouse=True)
  16. def create_server():
  17. global server
  18. server = ServerPreset.tinyllama2()
  19. server.n_ctx = 512
  20. server.n_slots = 2
  21. server.n_predict = 128
  22. def test_ctx_shift_enabled():
  23. # the prompt is 226 tokens
  24. # the slot context is 512/2 = 256 tokens
  25. # 96 tokens are generated thanks to shifting the context when it gets full
  26. global server
  27. server.enable_ctx_shift = True
  28. server.start()
  29. res = server.make_request("POST", "/completion", data={
  30. "n_predict": 96,
  31. "prompt": SHORT_TEXT,
  32. })
  33. assert res.status_code == 200
  34. assert res.body["timings"]["prompt_n"] == 226
  35. assert res.body["timings"]["predicted_n"] == 96
  36. assert res.body["truncated"] is True
  37. @pytest.mark.parametrize("n_predict,n_token_output,truncated", [
  38. (64, 64, False),
  39. (-1, 120, True),
  40. ])
  41. def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
  42. global server
  43. server.n_predict = -1
  44. server.start()
  45. res = server.make_request("POST", "/completion", data={
  46. "n_predict": n_predict,
  47. "prompt": "Hi how are you",
  48. })
  49. assert res.status_code == 200
  50. assert res.body["timings"]["predicted_n"] == n_token_output
  51. assert res.body["truncated"] == truncated
  52. def test_ctx_shift_disabled_long_prompt():
  53. global server
  54. server.start()
  55. res = server.make_request("POST", "/completion", data={
  56. "n_predict": 64,
  57. "prompt": LONG_TEXT,
  58. })
  59. assert res.status_code != 200
  60. assert "error" in res.body
  61. assert "exceeds the available context size" in res.body["error"]["message"]
  62. def test_ctx_shift_disabled_stream():
  63. global server
  64. server.start()
  65. res = server.make_stream_request("POST", "/v1/completions", data={
  66. "n_predict": 256,
  67. "prompt": "Once",
  68. "stream": True,
  69. })
  70. content = ""
  71. for data in res:
  72. choice = data["choices"][0]
  73. if choice["finish_reason"] == "length":
  74. assert len(content) > 0
  75. else:
  76. assert choice["finish_reason"] is None
  77. content += choice["text"]