
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
"""
多头注意力机制模块
"""
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
        # Linear projection layers for Q, K, V and the output
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
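        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V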
        # 1. Compute attention scores (QK^T), scaled by sqrt(d_k)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # 2. Apply the mask (if provided)
        if mask is not None:
            # Set positions where the mask is 0 to a very large negative value,
            # so their weights become ~0 after the softmax
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        # 3. Compute attention weights (softmax over the key dimension)
attn_probs = torch.softmax(attn_scores, dim=-1)
        # 4. Weighted sum of the values (weights @ V)
output = torch.matmul(attn_probs, V)
return output
def split_heads(self, x):
        # Reshape x from (batch_size, seq_length, d_model)
        # to (batch_size, num_heads, seq_length, d_k)
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
        # Reshape x from (batch_size, num_heads, seq_length, d_k)
        # back to (batch_size, seq_length, d_model)
batch_size, num_heads, seq_length, d_k = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
def forward(self, Q, K, V, mask=None):
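        # Q, K, V: (batch_size, seq_length, d_model); mask must be broadcastable to (batch_size, num_heads, q_len, k_len)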
        # 1. Project Q, K, V and split them into heads
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
        # 2. Compute scaled dot-product attention
attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # 3. Concatenate the heads and apply the final linear projection
output = self.W_o(self.combine_heads(attn_output))
return output
class PositionWiseFeedForward(nn.Module):
"""
位置前馈网络模块
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionWiseFeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
x = self.linear1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.linear2(x)
        # Output shape: (batch_size, seq_len, d_model)
return x
class PositionalEncoding(nn.Module):
"""
为输入序列的词嵌入向量添加位置编码。
"""
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
        # Build a positional-encoding table long enough for any input sequence
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # pe (positional encoding) has shape (max_len, d_model)
pe = torch.zeros(max_len, d_model)
        # Even dimensions use sin, odd dimensions use cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
        # Register pe as a buffer so it is not treated as a model parameter,
        # but still moves with the model (e.g. via .to(device))
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.size(1) is the sequence length of the current input;
        # add the positional encodings to the input embeddings
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class EncoderLayer(nn.Module):
"""
编码器核心层
"""
def __init__(self, d_model, num_heads, d_ff, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
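        # Post-norm residual connections: x = LayerNorm(x + Dropout(sublayer(x)))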
        # 1. Multi-head self-attention
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
        # 2. Feed-forward network
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
class DecoderLayer(nn.Module):
"""
解码器核心层
"""
def __init__(self, d_model, num_heads, d_ff, dropout):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask, tgt_mask):
        # 1. Masked multi-head self-attention (over the decoder input)
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
        # 2. Cross-attention (over the encoder output)
cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout(cross_attn_output))
        # 3. Feed-forward network
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x
class Encoder(nn.Module):
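    """
    Transformer encoder: token embedding + positional encoding + a stack of EncoderLayer blocks.
    """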
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
super(Encoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
self.norm = nn.LayerNorm(d_model)
def forward(self, x, mask):
x = self.embedding(x)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class Decoder(nn.Module):
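    """
    Transformer decoder: token embedding + positional encoding + a stack of DecoderLayer blocks.
    """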
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
super(Decoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
self.norm = nn.LayerNorm(d_model)
def forward(self, x, encoder_output, src_mask, tgt_mask):
x = self.embedding(x)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.norm(x)
class Transformer(nn.Module):
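    """
    Encoder-decoder Transformer with a final linear projection to target-vocabulary logits.
    """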
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len=5000):
super(Transformer, self).__init__()
self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
self.final_linear = nn.Linear(d_model, tgt_vocab_size)
def generate_mask(self, src, tgt):
# src_mask: (batch_size, 1, 1, src_len)
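        # Token id 0 is assumed to be the padding index; padded positions are masked out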
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
# tgt_mask: (batch_size, 1, tgt_len, tgt_len)
tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, tgt_len)
tgt_len = tgt.size(1)
        # Lower-triangular matrix that prevents attending to future tokens
tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=src.device)).bool() # (tgt_len, tgt_len)
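        # Broadcasting (batch_size, 1, 1, tgt_len) & (tgt_len, tgt_len) -> (batch_size, 1, tgt_len, tgt_len)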
tgt_mask = tgt_pad_mask & tgt_sub_mask
return src_mask, tgt_mask
def forward(self, src, tgt):
src_mask, tgt_mask = self.generate_mask(src, tgt)
encoder_output = self.encoder(src, src_mask)
decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
output = self.final_linear(decoder_output)
return output
# --- Demo: how to use the model ---
if __name__ == "__main__":
    # 1. Define the hyperparameters
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_len = 100
    # 2. Instantiate the model
model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
    # 3. Create dummy input data
    # Assume batch_size=2, src_seq_len=10, tgt_seq_len=12
src = torch.randint(1, src_vocab_size, (2, 10)) # (batch_size, seq_length)
tgt = torch.randint(1, tgt_vocab_size, (2, 12)) # (batch_size, seq_length)
    # 4. Run a forward pass
output = model(src, tgt)
    # 5. Print the output shape
    print("Model output shape:", output.shape)
    # Expected output: torch.Size([2, 12, 5000]) -> (batch_size, tgt_seq_len, tgt_vocab_size)
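    # A minimal, hypothetical training-step sketch (not part of the original demo),
    # assuming pad id 0 and teacher forcing with a one-token shift of the target:
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding positions (id 0)
    logits = model(src, tgt[:, :-1])                 # (batch_size, tgt_seq_len - 1, tgt_vocab_size)
    loss = criterion(logits.reshape(-1, tgt_vocab_size), tgt[:, 1:].reshape(-1))
    print("Example cross-entropy loss:", loss.item())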