1
0

test_slot_save.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
"""Tests for saving, restoring and erasing server slot state via ``/slots/<id>?action=...``."""
import pytest
from utils import *  # supplies ServerPreset and match_regex used in this module (possibly more)

# Module-level placeholder; the autouse ``create_server`` fixture rebinds this
# global to a freshly configured instance before each test runs.
server = ServerPreset.tinyllama2()


  4. @pytest.fixture(autouse=True)
  5. def create_server():
  6. global server
  7. server = ServerPreset.tinyllama2()
  8. server.slot_save_path = "./tmp"
  9. server.temperature = 0.0
  10. def test_slot_save_restore():
  11. global server
  12. server.start()
  13. # First prompt in slot 1 should be fully processed
  14. res = server.make_request("POST", "/completion", data={
  15. "prompt": "What is the capital of France?",
  16. "id_slot": 1,
  17. "cache_prompt": True,
  18. })
  19. assert res.status_code == 200
  20. assert match_regex("(Whiskers|Flana)+", res.body["content"])
  21. assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
  22. # Save state of slot 1
  23. res = server.make_request("POST", "/slots/1?action=save", data={
  24. "filename": "slot1.bin",
  25. })
  26. assert res.status_code == 200
  27. assert res.body["n_saved"] == 84
  28. # Since we have cache, this should only process the last tokens
  29. res = server.make_request("POST", "/completion", data={
  30. "prompt": "What is the capital of Germany?",
  31. "id_slot": 1,
  32. "cache_prompt": True,
  33. })
  34. assert res.status_code == 200
  35. assert match_regex("(Jack|said)+", res.body["content"])
  36. assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
  37. # Loading the saved cache into slot 0
  38. res = server.make_request("POST", "/slots/0?action=restore", data={
  39. "filename": "slot1.bin",
  40. })
  41. assert res.status_code == 200
  42. assert res.body["n_restored"] == 84
  43. # Since we have cache, slot 0 should only process the last tokens
  44. res = server.make_request("POST", "/completion", data={
  45. "prompt": "What is the capital of Germany?",
  46. "id_slot": 0,
  47. "cache_prompt": True,
  48. })
  49. assert res.status_code == 200
  50. assert match_regex("(Jack|said)+", res.body["content"])
  51. assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
  52. # For verification that slot 1 was not corrupted during slot 0 load, same thing should work
  53. res = server.make_request("POST", "/completion", data={
  54. "prompt": "What is the capital of Germany?",
  55. "id_slot": 1,
  56. "cache_prompt": True,
  57. })
  58. assert res.status_code == 200
  59. assert match_regex("(Jack|said)+", res.body["content"])
  60. assert res.body["timings"]["prompt_n"] == 1
  61. def test_slot_erase():
  62. global server
  63. server.start()
  64. res = server.make_request("POST", "/completion", data={
  65. "prompt": "What is the capital of France?",
  66. "id_slot": 1,
  67. "cache_prompt": True,
  68. })
  69. assert res.status_code == 200
  70. assert match_regex("(Whiskers|Flana)+", res.body["content"])
  71. assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
  72. # erase slot 1
  73. res = server.make_request("POST", "/slots/1?action=erase")
  74. assert res.status_code == 200
  75. # re-run the same prompt, it should process all tokens again
  76. res = server.make_request("POST", "/completion", data={
  77. "prompt": "What is the capital of France?",
  78. "id_slot": 1,
  79. "cache_prompt": True,
  80. })
  81. assert res.status_code == 200
  82. assert match_regex("(Whiskers|Flana)+", res.body["content"])
  83. assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed