Signed-off-by: Jie Fu <jiefu@tencent.com>
@@ -193,7 +193,7 @@ print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
with torch.no_grad():
- outputs = model(input_ids)
+ outputs = model(input_ids.to(model.device))
logits = outputs.logits
# Extract logits for the last token (next token prediction)
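
For reference, a minimal self-contained sketch of the surrounding snippet with this fix applied, so the forward pass works whether the model sits on CPU or GPU. The model name ("gpt2"), the prompt text, and the final `next_token_logits` line are illustrative here, not taken from the original doc:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model; substitute the model used in the documentation example.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompt = "Hello, world"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

with torch.no_grad():
    # Tokenizer output lives on CPU; move it to the model's device
    # so the call also works when the model has been placed on a GPU.
    outputs = model(input_ids.to(model.device))

logits = outputs.logits
# Extract logits for the last token (next token prediction)
next_token_logits = logits[0, -1, :]
```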