|
@@ -193,7 +193,7 @@ print(f"Input text: {repr(prompt)}")
|
|
|
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
|
|
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
- outputs = model(input_ids)
|
|
|
|
|
|
|
+ outputs = model(input_ids.to(model.device))
|
|
|
logits = outputs.logits
|
|
logits = outputs.logits
|
|
|
|
|
|
|
|
# Extract logits for the last token (next token prediction)
|
|
# Extract logits for the last token (next token prediction)
|