"""Encoder-decoder Transformer built from PyTorch primitives, plus a small demo."""
import torch
import torch.nn as nn
import math
import copy


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Q, K and V are each projected by their own linear layer. The result is split
    into ``num_heads`` heads of width ``d_k = d_model // num_heads``. Each head
    attends independently, then the heads are merged and passed through an
    output projection.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Raise instead of `assert` so the check is not stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError("d_model 必须能被 num_heads 整除")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V and for the merged output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

        Args:
            Q, K, V: Tensors whose last two dims are (seq_len, d_k). When called
                from ``forward`` they are (batch, num_heads, seq_len, d_k).
            mask: Optional tensor broadcastable to the score shape. Positions
                where ``mask == 0`` (or False) are excluded from attention.
        """
        # 1. Attention scores, scaled by sqrt(d_k).
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 2. Apply the mask, if one is given.
        if mask is not None:
            # Fill masked positions with the most negative finite value of the
            # score dtype so softmax assigns them ~0 weight. A fixed -1e9
            # overflows float16, and masked_fill rejects values that do not fit
            # the tensor's dtype.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        # 3. Attention weights.
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # 4. Weighted sum of values.
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project Q/K/V, attend per head, merge heads and project the output."""
        # 1. Linear projections, then split into heads.
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # 2. Scaled dot-product attention.
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # 3. Merge heads and apply the output projection.
        return self.W_o(self.combine_heads(attn_output))


class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # The last dim of x is d_model; the output has the same shape as x.
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings, then apply dropout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute encodings for every position up to max_len.
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        # pe (positional encoding) has shape (max_len, d_model).
        pe = torch.zeros(max_len, d_model)
        # Even columns use sin, odd columns use cos. With an odd d_model there
        # is one fewer odd column than even column, so trim div_term for cos.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Register as a buffer: not a trainable parameter, but it follows the
        # module across .to(device) calls.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first ``x.size(1)`` positions to ``x``.

        Raises:
            ValueError: if ``x.size(1)`` exceeds the ``max_len`` used at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward, each with residual + LayerNorm (post-norm)."""

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # 1. Multi-head self-attention.
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # 2. Feed-forward network.
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over encoder output, feed-forward."""

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # 1. Masked multi-head self-attention over the target sequence.
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # 2. Cross-attention: queries from the decoder, keys/values from the encoder.
        cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        # 3. Feed-forward network.
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class Encoder(nn.Module):
    """Embedding + positional encoding, a stack of EncoderLayers, then a final LayerNorm."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class Decoder(nn.Module):
    """Embedding + positional encoding, a stack of DecoderLayers, then a final LayerNorm."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


class Transformer(nn.Module):
    """Encoder-decoder Transformer that outputs target-vocabulary logits.

    Args:
        pad_idx: Token id treated as padding when building attention masks
            (default 0, matching the previous hard-coded behavior).
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout,
                 max_len=5000, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
        self.final_linear = nn.Linear(d_model, tgt_vocab_size)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks (True = may attend).

        Returns:
            src_mask: (batch, 1, 1, src_len), False at padding keys.
            tgt_mask: (batch, 1, tgt_len, tgt_len), the key-padding mask AND-ed
                with a lower-triangular causal mask.
        """
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        tgt_pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, tgt_len)
        tgt_len = tgt.size(1)
        # Lower-triangular matrix so position i cannot see positions > i.
        # Build it on tgt's device, since it is combined with tgt_pad_mask.
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()  # (tgt_len, tgt_len)
        tgt_mask = tgt_pad_mask & tgt_sub_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
        return self.final_linear(decoder_output)


# --- Demo: how to use the model ---
if __name__ == "__main__":
    # 1. Hyperparameters.
    src_vocab_size = 5000
    tgt_vocab_size = 5000
    d_model = 512
    num_layers = 6
    num_heads = 8
    d_ff = 2048
    dropout = 0.1
    max_len = 100

    # 2. Build the model.
    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)

    # 3. Random input data: batch_size=2, src_seq_len=10, tgt_seq_len=12.
    #    Token ids start at 1 so that nothing is treated as padding (id 0).
    src = torch.randint(1, src_vocab_size, (2, 10))  # (batch_size, seq_length)
    tgt = torch.randint(1, tgt_vocab_size, (2, 12))  # (batch_size, seq_length)

    # 4. Forward pass.
    output = model(src, tgt)

    # 5. Print the output shape.
    print("模型输出的形状:", output.shape)
    # Expected: torch.Size([2, 12, 5000]) -> (batch_size, tgt_seq_len, tgt_vocab_size)