1
0

test_rerank.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import pytest
  2. from utils import *
  3. server = ServerPreset.jina_reranker_tiny()
  4. @pytest.fixture(autouse=True)
  5. def create_server():
  6. global server
  7. server = ServerPreset.jina_reranker_tiny()
  8. TEST_DOCUMENTS = [
  9. "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
  10. "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
  11. "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
  12. "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
  13. ]
  14. def test_rerank():
  15. global server
  16. server.start()
  17. res = server.make_request("POST", "/rerank", data={
  18. "query": "Machine learning is",
  19. "documents": TEST_DOCUMENTS,
  20. })
  21. assert res.status_code == 200
  22. assert len(res.body["results"]) == 4
  23. most_relevant = res.body["results"][0]
  24. least_relevant = res.body["results"][0]
  25. for doc in res.body["results"]:
  26. if doc["relevance_score"] > most_relevant["relevance_score"]:
  27. most_relevant = doc
  28. if doc["relevance_score"] < least_relevant["relevance_score"]:
  29. least_relevant = doc
  30. assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
  31. assert most_relevant["index"] == 2
  32. assert least_relevant["index"] == 3
  33. def test_rerank_tei_format():
  34. global server
  35. server.start()
  36. res = server.make_request("POST", "/rerank", data={
  37. "query": "Machine learning is",
  38. "texts": TEST_DOCUMENTS,
  39. })
  40. assert res.status_code == 200
  41. assert len(res.body) == 4
  42. most_relevant = res.body[0]
  43. least_relevant = res.body[0]
  44. for doc in res.body:
  45. if doc["score"] > most_relevant["score"]:
  46. most_relevant = doc
  47. if doc["score"] < least_relevant["score"]:
  48. least_relevant = doc
  49. assert most_relevant["score"] > least_relevant["score"]
  50. assert most_relevant["index"] == 2
  51. assert least_relevant["index"] == 3
  52. @pytest.mark.parametrize("documents", [
  53. [],
  54. None,
  55. 123,
  56. [1, 2, 3],
  57. ])
  58. def test_invalid_rerank_req(documents):
  59. global server
  60. server.start()
  61. res = server.make_request("POST", "/rerank", data={
  62. "query": "Machine learning is",
  63. "documents": documents,
  64. })
  65. assert res.status_code == 400
  66. assert "error" in res.body
  67. @pytest.mark.parametrize(
  68. "query,doc1,doc2,n_tokens",
  69. [
  70. ("Machine learning is", "A machine", "Learning is", 19),
  71. ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
  72. ]
  73. )
  74. def test_rerank_usage(query, doc1, doc2, n_tokens):
  75. global server
  76. server.start()
  77. res = server.make_request("POST", "/rerank", data={
  78. "query": query,
  79. "documents": [
  80. doc1,
  81. doc2,
  82. ]
  83. })
  84. assert res.status_code == 200
  85. assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
  86. assert res.body['usage']['prompt_tokens'] == n_tokens
  87. @pytest.mark.parametrize("top_n,expected_len", [
  88. (None, len(TEST_DOCUMENTS)), # no top_n parameter
  89. (2, 2),
  90. (4, 4),
  91. (99, len(TEST_DOCUMENTS)), # higher than available docs
  92. ])
  93. def test_rerank_top_n(top_n, expected_len):
  94. global server
  95. server.start()
  96. data = {
  97. "query": "Machine learning is",
  98. "documents": TEST_DOCUMENTS,
  99. }
  100. if top_n is not None:
  101. data["top_n"] = top_n
  102. res = server.make_request("POST", "/rerank", data=data)
  103. assert res.status_code == 200
  104. assert len(res.body["results"]) == expected_len
  105. @pytest.mark.parametrize("top_n,expected_len", [
  106. (None, len(TEST_DOCUMENTS)), # no top_n parameter
  107. (2, 2),
  108. (4, 4),
  109. (99, len(TEST_DOCUMENTS)), # higher than available docs
  110. ])
  111. def test_rerank_tei_top_n(top_n, expected_len):
  112. global server
  113. server.start()
  114. data = {
  115. "query": "Machine learning is",
  116. "texts": TEST_DOCUMENTS,
  117. }
  118. if top_n is not None:
  119. data["top_n"] = top_n
  120. res = server.make_request("POST", "/rerank", data=data)
  121. assert res.status_code == 200
  122. assert len(res.body) == expected_len