> 技术文档 > 【第四章:大模型(LLM)】01.神经网络中的 NLP-(2)Seq2Seq 原理及代码解析

【第四章:大模型(LLM)】01.神经网络中的 NLP-(2)Seq2Seq 原理及代码解析


第四章:大模型(LLM)

第二部分:神经网络中的 NLP

第二节:Seq2Seq 原理及代码解析

1. Seq2Seq(Sequence-to-Sequence)模型原理

Seq2Seq 是一种处理序列到序列任务(如机器翻译、文本摘要、对话生成等)的深度学习架构,最早由 Google 在 2014 年提出。其核心思想是使用 编码器(Encoder) 将输入序列编码为上下文向量,再通过 解码器(Decoder) 逐步生成输出序列。

1.1 架构组成

  1. 编码器(Encoder)

    • 通常是 RNN、LSTM 或 GRU。

    • 输入:序列 x = (x_1, x_2, ..., x_T)

    • 输出:隐藏状态 h_T​,作为上下文向量。

  2. 解码器(Decoder)

    • 结构类似于编码器。

    • 输入:编码器输出的上下文向量 + 上一步预测的输出。

    • 输出:目标序列 y = (y_1, y_2, ..., y_T)

  3. 上下文向量(Context Vector)

    • 编码器最后一个隐藏状态 h_T​ 作为整个输入序列的信息摘要。


2. 数学公式

  • 编码器:

h_t = f(h_{t-1}, x_t)

  • 解码器:

s_t = f(s_{t-1}, y_{t-1}, c)
<img alt=\"P(y_t|y_{

其中 c 是上下文向量。


3. 经典 Seq2Seq 训练流程

  1. 输入序列通过编码器,生成上下文向量。

  2. 解码器利用上下文向量和前一时刻的预测结果,逐步生成输出。

  3. 使用 教师强制(Teacher Forcing) 技术,训练时将真实标签输入解码器。


4. 改进:Attention 机制

Seq2Seq 传统模型存在 长序列信息丢失 问题。
Attention 通过在每一步解码时为输入序列不同部分分配权重,解决了这个问题。
公式:

c_t = \\sum_{i=1}^{T_x} \\alpha_{t,i} h_i

其中 \\alpha_{t,i}​ 是注意力权重。


5. PyTorch 代码解析:Seq2Seq 示例

import torchimport torch.nn as nnimport torch.optim as optim# Encoderclass Encoder(nn.Module): def __init__(self, input_dim, hidden_dim, num_layers=1): super(Encoder, self).__init__() self.rnn = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True) def forward(self, x): outputs, hidden = self.rnn(x) return hidden# Decoderclass Decoder(nn.Module): def __init__(self, output_dim, hidden_dim, num_layers=1): super(Decoder, self).__init__() self.rnn = nn.GRU(output_dim, hidden_dim, num_layers, batch_first=True) self.fc = nn.Linear(hidden_dim, output_dim) def forward(self, x, hidden): output, hidden = self.rnn(x, hidden) pred = self.fc(output) return pred, hidden# Seq2Seqclass Seq2Seq(nn.Module): def __init__(self, encoder, decoder): super(Seq2Seq, self).__init__() self.encoder = encoder self.decoder = decoder def forward(self, src, trg): hidden = self.encoder(src) outputs, _ = self.decoder(trg, hidden) return outputs# Example usageinput_dim, output_dim, hidden_dim = 10, 10, 32encoder = Encoder(input_dim, hidden_dim)decoder = Decoder(output_dim, hidden_dim)model = Seq2Seq(encoder, decoder)src = torch.randn(16, 20, input_dim) # batch=16, seq_len=20trg = torch.randn(16, 20, output_dim)output = model(src, trg)print(output.shape) # [16, 20, 10]

6. 应用场景

  • 机器翻译(Google Translate)

  • 文本摘要(新闻摘要生成)

  • 对话系统(聊天机器人)

  • 语音识别(语音到文本)