You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.
import torch
from transformers import AutoModelForCausalLM , AutoTokenizer
# Hugging Face Hub ID of the chat model to load.
model_id = "Qwen/Qwen1.5-0.5B-Chat"
# Select the compute device, preferring a GPU whenever CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the tokenizer that belongs to the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the model weights and move them onto the selected device.
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
print("模型和分词器加载完成!")
# Conversation to send to the model: a system prompt followed by one user turn.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你好,请介绍你自己。"},
]
# Render the conversation with the tokenizer's chat template.
# tokenize=False returns the prompt as a string.
# add_generation_prompt=True appends the marker that opens the assistant's turn.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
# Tokenize the prompt (as a batch of one) into PyTorch tensors on the model's device.
model_inputs = tokenizer([text], return_tensors="pt").to(device)
print("编码后的输入文本:")
print(model_inputs)
# Generate the model's reply.
# Unpacking model_inputs passes both input_ids and attention_mask to generate().
# Passing input_ids alone makes generate() infer the mask and emit a warning
# about possibly unreliable results.
# max_new_tokens caps how many new tokens the model may produce.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
)
# generate() returns prompt + continuation for each sequence.
# Slice off the prompt tokens so that only the newly generated part is decoded.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
# Decode the new tokens. The batch holds a single prompt, so take entry 0.
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("\n模型的回答:")
print(response)