@@ -434,8 +434,8 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 3, False),
-        (64, 1, True),
+        (64, 4, False),
+        (64, 2, True),
     ]
 )
 def test_return_progress(n_batch, batch_count, reuse_cache):
@@ -462,10 +462,18 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
     res = make_cmpl_request()
     last_progress = None
     total_batch_count = 0
+
     for data in res:
         cur_progress = data.get("prompt_progress", None)
         if cur_progress is None:
             continue
+        if total_batch_count == 0:
+            # first progress report must have n_cache == n_processed
+            assert cur_progress["total"] > 0
+            assert cur_progress["cache"] == cur_progress["processed"]
+            if reuse_cache:
+                # when reusing cache, we expect some cached tokens
+                assert cur_progress["cache"] > 0
         if last_progress is not None:
             assert cur_progress["total"] == last_progress["total"]
             assert cur_progress["cache"] == last_progress["cache"]
@@ -473,6 +481,7 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
         total_batch_count += 1
         last_progress = cur_progress
 
+    # last progress should indicate completion (all tokens processed)
     assert last_progress is not None
     assert last_progress["total"] > 0
     assert last_progress["processed"] == last_progress["total"]