You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

249 lines
9.5 KiB
Python

6 months ago
import torch
import torch.nn as nn
import math
import copy
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through their own linear
    projection, split into ``num_heads`` heads of width ``d_k``, attended
    per head, merged back together and passed through an output projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projections for Q, K, V and for the merged output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return ``softmax(Q K^T / sqrt(d_k)) V``, optionally masked.

        Args:
            Q, K, V: Per-head tensors whose last dimension is ``d_k``.
            mask: Optional tensor broadcastable to the score tensor; score
                positions where ``mask == 0`` are suppressed before softmax.

        Returns:
            The attention-weighted sum of ``V``.
        """
        # 1. Attention scores: QK^T scaled by sqrt(d_k).
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # 2. Apply the mask, if one was given.
        if mask is not None:
            # Fill with the most negative finite value of the score dtype
            # instead of a hard-coded -1e9: -1e9 is not representable in
            # float16, while a finite fill (unlike -inf) keeps fully masked
            # rows from turning into NaN after softmax.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        # 3. Normalise the scores into attention weights.
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # 4. Weighted sum of the values.
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch_size, seq_length, d_model) to (batch_size, num_heads, seq_length, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch_size, num_heads, seq_length, d_k) back to (batch_size, seq_length, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view, so make it contiguous
        # before view() flattens the head dimensions.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q, K, V: Tensors of shape (batch_size, seq_length, d_model); the
                sequence length of ``Q`` may differ from that of ``K``/``V``.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch_size, query_seq_length, d_model).
        """
        # 1. Project the inputs and split them into heads.
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        # 2. Scaled dot-product attention per head.
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # 3. Merge the heads and apply the output projection.
        return self.W_o(self.combine_heads(attn_output))
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network.

    Applies Linear(d_model -> d_ff), ReLU, Dropout and Linear(d_ff -> d_model)
    to the last dimension of the input.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Submodules are registered in the same order as before so that
        # parameter initialisation and state_dict ordering are unaffected.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (batch_size, seq_len, d_model) to a tensor of the same shape."""
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Args:
        d_model: Embedding width (even or odd).
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed in the encoding table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Column vector of positions 0 .. max_len - 1.
        position = torch.arange(max_len).unsqueeze(1)
        # One frequency per even feature index: 10000^(-2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        # pe has shape (max_len, d_model): sin on even columns, cos on odd ones.
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine term is trimmed to fit; for even d_model this slice
        # keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Register as a buffer: not a trainable parameter, but it follows the
        # module across devices (e.g. ``to(device)``).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encodings for the first ``x.size(1)`` positions, then apply dropout.

        Args:
            x: Tensor whose dimension 1 is the sequence length and whose last
                dimension is ``d_model``; the sequence length should not
                exceed ``max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class EncoderLayer(nn.Module):
    """Single encoder layer: self-attention then feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, mask):
        """Apply the layer to ``x``; ``mask`` is forwarded to self-attention."""
        # 1. Multi-head self-attention sub-layer.
        attended = self._add_and_norm(self.norm1, x, self.self_attn(x, x, x, mask))
        # 2. Position-wise feed-forward sub-layer.
        return self._add_and_norm(self.norm2, attended, self.feed_forward(attended))
class DecoderLayer(nn.Module):
    """Single decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply the layer.

        Args:
            x: Decoder-side input.
            encoder_output: Encoder result used as keys and values in
                cross-attention.
            src_mask: Mask passed to cross-attention.
            tgt_mask: Mask passed to self-attention.
        """
        # 1. Masked multi-head self-attention over the decoder input.
        x = self._add_and_norm(self.norm1, x, self.self_attn(x, x, x, tgt_mask))
        # 2. Cross-attention: queries from the decoder, keys/values from the encoder.
        memory = encoder_output
        x = self._add_and_norm(self.norm2, x, self.cross_attn(x, memory, memory, src_mask))
        # 3. Position-wise feed-forward sub-layer.
        return self._add_and_norm(self.norm3, x, self.feed_forward(x))
class Encoder(nn.Module):
    """Token embedding, positional encoding and a stack of encoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Feature width.
        num_layers: Number of ``EncoderLayer`` instances in the stack.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability.
        max_len: Number of positions handed to ``PositionalEncoding``.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        """Encode token ids ``x``; ``mask`` is passed to every layer."""
        hidden = self.pos_encoder(self.embedding(x))
        for block in self.layers:
            hidden = block(hidden, mask)
        # Final layer normalisation over the output of the stack.
        return self.norm(hidden)
class Decoder(nn.Module):
    """Token embedding, positional encoding and a stack of decoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Feature width.
        num_layers: Number of ``DecoderLayer`` instances in the stack.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability.
        max_len: Number of positions handed to ``PositionalEncoding``.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Decode token ids ``x`` against ``encoder_output``.

        ``src_mask`` and ``tgt_mask`` are passed unchanged to every layer.
        """
        hidden = self.pos_encoder(self.embedding(x))
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        # Final layer normalisation over the output of the stack.
        return self.norm(hidden)
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a linear projection to the target vocabulary.

    Args:
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size (also the output width).
        d_model: Feature width.
        num_layers: Number of layers in both the encoder and the decoder.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability.
        max_len: Number of positions handed to the positional encodings.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len=5000):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
        self.final_linear = nn.Linear(d_model, tgt_vocab_size)

    def generate_mask(self, src, tgt):
        """Build the source padding mask and the target padding + causal mask.

        Token id 0 is treated as padding.

        Args:
            src: Source token ids of shape (batch_size, src_len).
            tgt: Target token ids of shape (batch_size, tgt_len).

        Returns:
            Tuple ``(src_mask, tgt_mask)`` of boolean tensors with shapes
            (batch_size, 1, 1, src_len) and (batch_size, 1, tgt_len, tgt_len);
            ``True`` marks positions that may be attended to.
        """
        # src_mask: (batch_size, 1, 1, src_len)
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        # tgt_pad_mask: (batch_size, 1, 1, tgt_len)
        tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        tgt_len = tgt.size(1)
        # Lower-triangular matrix that hides future tokens. It is created on
        # tgt's device (not src's) because it is combined with tgt_pad_mask.
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
        # Broadcasts to (batch_size, 1, tgt_len, tgt_len).
        tgt_mask = tgt_pad_mask & tgt_sub_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Return logits of shape (batch_size, tgt_len, tgt_vocab_size)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
        return self.final_linear(decoder_output)
# --- Demo: build the model and run a single forward pass ---
if __name__ == "__main__":
    # 1. Hyperparameters, keyed by the Transformer constructor's parameter names.
    hparams = {
        "src_vocab_size": 5000,
        "tgt_vocab_size": 5000,
        "d_model": 512,
        "num_layers": 6,
        "num_heads": 8,
        "d_ff": 2048,
        "dropout": 0.1,
        "max_len": 100,
    }

    # 2. Instantiate the model.
    model = Transformer(**hparams)

    # 3. Random token ids as stand-in inputs. Ids start at 1, so none of them
    #    is the padding id 0.
    batch_size, src_seq_len, tgt_seq_len = 2, 10, 12
    src = torch.randint(1, hparams["src_vocab_size"], (batch_size, src_seq_len))
    tgt = torch.randint(1, hparams["tgt_vocab_size"], (batch_size, tgt_seq_len))

    # 4. Forward pass.
    output = model(src, tgt)

    # 5. Report the output shape.
    #    Expected: torch.Size([2, 12, 5000]) -> (batch_size, tgt_seq_len, tgt_vocab_size)
    print("模型输出的形状:", output.shape)