> 技术文档 > Task03【datawhale组队学习】JoyRL&EasyRL

Task03【datawhale组队学习】JoyRL&EasyRL


文章目录

  • 深度学习基础
  • 强化学习与深度学习的关系
  • 线性回归
  • 梯度下降
      • Adam自适应优化算法
  • 逻辑回归(解决二分类问题)
  • 全连接网络
  • 更高级的神经网络
    • 卷积神经网络(适用于处理网格结构的数据)
    • 循环神经网络(适用于处理序列数据)

教程:https://datawhalechina.github.io/joyrl-book/#/ch6/main
随机梯度下降(stochastic gradiant descent,SGD)
小批量梯度下降(mini-batch gradiant descent)
小批量随机梯度下降(mini-batch stochastic gradiant descent)
反向传播算法(Backpropagation Algorithm)
全连接网络(fully connected network),也称作多层感知机(multi-layer perceptron,MLP)
卷积神经网络(convolutional neural network,CNN)
循环神经网络(recurrent neural network,RNN)

深度学习基础

强化学习与深度学习的关系

  • 预测的主要目的是根据环境的状态和动作来预测状态价值和动作价值
  • 控制的主要目的是根据状态价值和动作价值来选择动作。
    预测主要是告诉我们当前状态下采取什么动作比较好,而控制则是按照某种方式决策。

如图所示,其中有一队叫做英国三狮,主要领队是尼尔逊和巴菲斯,巴菲斯是一个超级数据分析专家,他能在各种场景下计算对手传球、射门的概率,也包括我方进球和传球的各种收益,然后尼尔逊会根据他的数据分析结果来决定下一步行动。
尼尔逊也是一个非常有头脑的领队,他不会只依靠巴菲斯的计算结果,而是会结合自身的经验和对足球的直觉来做出数据之外的决策。这个数据之外的决策在强化学习中叫做探索
也就是说尼尔逊会根据巴菲斯的计算结果来做出决策,但是他也会根据自己的经验和直觉来做出一些不确定的决策,这样才能保证他的队伍不会被对手轻易的猜到。

Task03【datawhale组队学习】JoyRL&EasyRL
以上就是预测和控制的关系,通常在强化学习中预测和控制的部分看起来是共用一个 Q表或者神经网络的
因此读者们可能会因为主要关注价值函数的估计而忽视掉控制这层关系,控制通常在采样动作的过程中体现出来。

预测和控制的依赖:预测也相当于人的眼睛和大脑的视觉神经处理部分,而控制相当于大脑的决策神经处理部分,看似是两个独立的部分,但实际上是相互依赖的,预测的结果会影响到控制的决策,而控制的决策也会影响到预测的结果。

深度学习可以用来提高强化学习的预测效果和提高控制问题的性能
深度学习本身就是一个目前预测和分类效果俱佳的工具。比如 Q-learning的Q表就完全可以用神经网络来拟合。
深度学习只是一种非常广泛的应用,但并不是强化学习的必要条件,也可以是一些传统的预测模型,例如决策树、贝叶斯模型等等,因此读者在研究相关问题时需要充分打开思路。类似地,在控制问题中,也可以利用深度学习或者其他的方法来提高性能,例如结合进化算法来提高强化学习的探索能力。

训练模式上看,
训练:深度学习和强化学习,尤其是结合了深度学习的深度强化学习,都是基于大量的样本来对相应算法进行迭代更新并且达到最优的,这个过程我们称之为训练。
强化学习是在交互中产生样本的,是一个产生样本、算法更新、再次产生样本、再次算法更新的动态循环训练过程,而不是一个准备样本、算法更新的静态训练过程。

深度学习:主要解决 “打标签” 问题,分监督(人工打标签,如判断猫狗图 )、无监督(算法自动打标签 ),标签通常是离散值(如猫、狗分类 ) 。
强化学习:核心解决 “序列决策” 问题,可类比 “打分数”,判断状态好坏(分数体现 );虽也能解决类似 “打标签” 问题,但标签是连续值(如图片美观程度的连续评分 ) 。

如图所示,除了训练生成模型之外,强化学习相当于在深度学习的基础上增加了一条回路,即继续与环境交互产生样本。
相信学过控制系统的读者很快会意识到,这个回路就是一个典型的反馈系统机制,模型的输出一开始并不能达到预期的值,因此通过动态地不断与环境交互来产生一些反馈信息,从而训练出一个更好的模型。

Task03【datawhale组队学习】JoyRL&EasyRL

线性回归

在强化学习中,常用的深度学习模型有线性模型。
线性模型并不是深度学习模型,而是传统的机器学习模型,但它是深度学习模型的基础,在深度学习中相当于单层的神经网络。

在线性模型中,应用较为广泛的两个基础模型

  • 线性回归:主要用于解决回归问题。
  • 逻辑回归:主要用于解决分类问题,但也可用于回归问题。

以典型的房价预测问题为例,假设一套房子有m个特征,例如建造年份、房子面积等,分别记为 x 1 , x 2 ,..., x m x_1,x_2,...,x_m x1,x2,...,xm,用向量表示为式
x = [ x 1 , x 2 , … , x m ] \\boldsymbol{x} = [x_1, x_2, \\dots, x_m] x=[x1,x2,,xm]

那么房价 y y y可以表示为式
f ( x ; w , b ) = w 1x 1 + w 2x 2 + ⋯ + w mx m + b = w T x + b f(\\boldsymbol{x}; \\boldsymbol{w}, b) = w_1 x_1 + w_2 x_2 + \\cdots + w_m x_m + b = \\boldsymbol{w}^T \\boldsymbol{x} + b f(x;w,b)=w1x1+w2x2++wmxm+b=wTx+b
其中 w \\boldsymbol{w} w b b b 是模型的参数, f(x;w,b) f(\\boldsymbol{x}; \\boldsymbol{w}, b) f(x;w,b) 是模型的输出,也就是我们要预测的房价。出于简化考虑,通常我们会用一个符号 θ \\boldsymbol{\\theta} θ 来表示 w \\boldsymbol{w} w b b b,如式所示。

f θ ( x ) = θ T x f^{\\boldsymbol{\\theta}}(\\boldsymbol{x}) = \\boldsymbol{\\theta}^T \\boldsymbol{x} fθ(x)=θTx

  • 目标是求解最优参数 (\\boldsymbol{\\theta}^*),让模型能依据房屋 m m m 个特征,准确预测房价。这类问题属于拟合问题(类似用直线拟合散点,直线即模型 )。

  • 用于拟合的数据(散点)叫样本,因真实模型复杂未知,需大量样本(训练集) 训练模型。

  • 模型参数是近似求解,几乎无法找到完美拟合所有样本的模型(最优解可能不存在 )。

  • 欠拟合:模型在训练集表现就差(测试集表现可能稍好,但整体能力弱 )。 增加模型复杂度,增加特征,调整模型参数

  • 过拟合:模型在训练集表现好,但测试集表现差(学了样本“噪声”,泛化能力弱 )。 数据增强,正则化,模型选择与融合

  • 两者本质是陷入局部最优解

梯度下降

回到问题本身,这类问题的解决方法也有很多种,例如最小二乘法、牛顿法等,但目前最流行的方法还是梯度下降。其基本思想如下。

  • 初始化参数:选择一个初始点或参数的初始值。
  • 计算梯度:在当前点计算函数的梯度,即函数关于各参数的偏导数。梯度指向函数值增加最快的方向。
  • 更新参数:按照负梯度方向更新参数,这样可以减少函数值。这个过程在神经网络中一般是以反向传播算法来实现的。
  • 重复上述二三步骤,直到梯度趋近于 0 或者达到一定迭代次数。

梯度下降本质上是一种基于贪心思想的方法,它的泛化能力很强,能够基于任何可导的函数求解最优解。
假设我们要找到一个山谷中的最低点,也就是下山,那么我们可以从任意一点出发,然后沿着最陡峭的方向向下走,这样就能够找到山谷中的最低点。这里的最陡峭的方向就是梯度方向,而沿着这个方向走的步长就是学习率,这个学习率一般是一个超参数,需要我们自己来设定。
学习率:通常学习率设为较小值,可避免错过最低点,但过小将增加迭代次数才能达最低点 。
小批量样本:

  • 定义:梯度下降迭代中,用于计算梯度的小批量样本,称为一个 batch 。
  • 作用:相当于下山时的 “视野”,影响对全局情况的判断 。
  • 不同批量的影响
    • 批量太小:易因 “视野窄” 陷入局部最优(被局部山峰迷惑 )。
    • 批量太大:会消耗更多计算资源 。
      决策依据:需根据实际情况(计算资源、模型特性、数据规模等 )选择合适的 batch 大小 。

Task03【datawhale组队学习】JoyRL&EasyRL

梯度下降的分类

  • 按样本选择方式:分为单纯梯度下降和随机梯度下降(SGD)
    • 单纯梯度下降:按样本原始顺序迭代拟合参数,规则但缺乏随机性。
    • 随机梯度下降:随机抽取样本迭代,利用随机性跳出局部最优,提升收敛性与鲁棒性。
  • 按 batch 大小:分为批量梯度下降和小批量梯度下降
    • 批量梯度下降:用整个训练集迭代,优点是迭代方向准;缺点是计算开销大(batch 极大 )。
    • 小批量梯度下降:用一小部分样本迭代,优点是计算开销小;缺点是迭代方向不够准确(batch 极小 )。

通常使用小批量的随机梯度下降。这样可以兼顾到所有的优点,从而使得训练更加稳定,算法效果也会更好。

Adam自适应优化算法

不仅仅考虑了当前的梯度,还考虑了之前的梯度的平方,这样可以更加准确地估计梯度的方向,从而加快梯度下降的速度,也是目前最流行的优化器之一。
注意在做强化学习应用或研究的时候,我们并不需要太纠结于优化器的选择,因为这些优化器的效果并没有太大的差别,而且我们也不需要去了解它们的具体原理,只需要知道它们的大致作用就可以了。

逻辑回归(解决二分类问题)

虽然逻辑回归名字中带有回归,但是它是用来解决分类问题的,而不是回归问题(即预测问题)。
在分类问题中,我们的目标是预测样本的类别,而不是预测一个连续的值。
例如,我们要预测一封邮件是否是垃圾邮件,这就是一个二分类问题,通常输出 0 和 1 等离散的数字来表示对应的类别。在形式上,逻辑回归和线性回归非常相似,
如图所示(逻辑回归结构),就是在线性模型的后面增加一个sigmoid函数,我们一般称之为激活函数。
Task03【datawhale组队学习】JoyRL&EasyRL

sigmoid函数定义为式
sigmoid ( z ) = 1 1 + exp ⁡ ( − z ) \\text{sigmoid}(z) = \\frac{1}{1 + \\exp(-z)} sigmoid(z)=1+exp(z)1
如图 所示,sigmoid 函数可以将输入的任意实数映射到 (0,1) 的区间内,对其输出的值进行判断,例如小于 0.5 我们认为预测的是类别0,反之是类别 1。这样一来通过梯度下降来求解模型参数就可以用于实现二分类问题了。
注意,虽然逻辑回归只是在线性回归模型基础上增加了一个激活函数,但两个模型是完全不同的,包括损失函数等等。

  • 线性回归的损失函数是均方差损失
  • 逻辑回归模型一般是交叉熵损失
    Task03【datawhale组队学习】JoyRL&EasyRL
    逻辑回归的优点和缺点:逻辑回归的主要优点在于增加了模型的非线性能力,同时模型的参数也比较容易求解,但是它也有一些缺点,例如它的非线性能力还是比较弱的,而且它只能解决二分类问题,不能解决多分类问题。在实际应用中,我们一般会将多个二分类问题组合成一个多分类问题,例如将sigmoid函数换成softmax回归函数等。(把输出映射到多分类概率 )。

模型逻辑回归的模型结构跟生物神经网络的最小单位神经元很相似。

  1. 生物神经元:通过 “突触” 传递信号,接收多来源输入,经 “细胞核” 处理后传递。
  2. 人工神经网络(逻辑回归可看作极简版):
  • 线性加权: w T x w^Tx wTx模拟 “多输入信号的整合”;
  • 激活函数(如 sigmoid ):模拟 “神经元是否激活、传递信号” 的判断逻辑。
    Task03【datawhale组队学习】JoyRL&EasyRL
    逻辑回归这类模型的结构也比较灵活多变,可以通过横向堆叠的形式来增加模型的复杂度,例如增加隐藏层等,这样就能解决更复杂的问题,这就是接下来要讲的神经网络模型。并且,我们可以认为逻辑回归就是一个最简单的人工神经网络模型

全连接网络

如图所示,将线性层横向堆叠起来,前一层网络的所有神经元的输出都会输入到下一层的所有神经元中,这样就可以得到一个全连接网络。其中,每个线性层的输出都会经过一个激活函数(图中已略去),这样就可以增加模型的非线性能力。
Task03【datawhale组队学习】JoyRL&EasyRL
我们把这样的网络叫做全连接网络(fully connected network),也称作多层感知机(multi-layer perceptron,MLP),是最基础的深度神经网络模型。
把神经网络模型中前一层的输入向量记为 x l − 1 ∈ R d l − 1 \\boldsymbol{x}^{l-1} \\in \\mathbb{R}^{d^{l-1}} xl1Rdl1 ,其中第一层的输入也就是整个模型的输入可记为 x 0 \\boldsymbol{x}^0 x0,每一个全连接层将前一层的输入映射到 x l ∈ R d l \\boldsymbol{x}^l \\in \\mathbb{R}^{d^l} xlRdl ,也就是后一层的输入,具体定义为式。

x l = σ ( z ) , z = W x l − 1+ b = θ x l − 1 \\boldsymbol{x}^l = \\sigma(z), \\quad z = \\boldsymbol{W} \\boldsymbol{x}^{l-1} + \\boldsymbol{b} = \\boldsymbol{\\theta} \\boldsymbol{x}^{l-1} xl=σ(z),z=Wxl1+b=θxl1

其中 W∈ R d l − 1 × d l \\boldsymbol{W} \\in \\mathbb{R}^{d^{l-1} \\times d^l} WRdl1×dl 是权重矩阵, b \\boldsymbol{b} b 为偏置矩阵,与线性模型类似,这两个参数我们通常看作一个参数 θ \\boldsymbol{\\theta} θ
σ(⋅) \\sigma(\\cdot) σ() 是激活函数,除了 sigmoid 函数之外,还包括 softmax 函数、ReLU 函数和 tanh 函数等等激活函数。其中最常用的是 ReLU 函数和 tanh 函数,前者将神经元也就是线性函数的输出映射到 (0,1) (0, 1) (0,1) 之间,后者则映射到 −1 -1 1 1 1 1 之间。
激活函数的选择需要根据具体的问题来定,没有一种激活函数适用于所有的问题。
在了解到神经网络前后层的关系之后,我们就可以表示一个 l l l层的神经网络模型
第1层: x ( 1 ) = σ 1 ( W ( 1 ) x ( 0 )+ b ( 1 )) , \\boldsymbol{x}^{(1)} = \\sigma_1\\left( \\boldsymbol{W}^{(1)} \\boldsymbol{x}^{(0)} + \\boldsymbol{b}^{(1)} \\right), x(1)=σ1(W(1)x(0)+b(1)),

第2层: x ( 2 ) = σ 2 ( W ( 2 ) x ( 1 )+ b ( 2 )) , \\boldsymbol{x}^{(2)} = \\sigma_2\\left( \\boldsymbol{W}^{(2)} \\boldsymbol{x}^{(1)} + \\boldsymbol{b}^{(2)} \\right), x(2)=σ2(W(2)x(1)+b(2)),

⋮ ⋮ \\quad\\vdots\\quad\\vdots

l l l层: x ( l ) = σ l ( W ( l ) x ( l − 1 )+ b ( l )) \\boldsymbol{x}^{(l)} = \\sigma_l\\left( \\boldsymbol{W}^{(l)} \\boldsymbol{x}^{(l-1)} + \\boldsymbol{b}^{(l)} \\right) x(l)=σl(W(l)x(l1)+b(l))
从上面的式子可以看出,神经网络模型的参数包括每一层的权重矩阵和偏置矩阵
θ = { W ( 1 ) , b ( 1 ) , W ( 2 ) , b ( 2 ) , … , W ( l ) , b ( l ) } \\boldsymbol{\\theta} = \\left\\{ \\boldsymbol{W}^{(1)}, \\boldsymbol{b}^{(1)}, \\boldsymbol{W}^{(2)}, \\boldsymbol{b}^{(2)}, \\dots, \\boldsymbol{W}^{(l)}, \\boldsymbol{b}^{(l)} \\right\\} θ={W(1),b(1),W(2),b(2),,W(l),b(l)}
这些参数都是需要我们去学习的,也就是说我们需要找到一组参数使得神经网络模型的输出尽可能地接近真实值,这个过程就是神经网络的训练过程。同基础的线性模型类似,神经网络也可以通过梯度下降的方法来求解最优参数。

更高级的神经网络

**基于线性模型的神经网络已经足够适用于大部分的强化学习问题。但是对于一些更复杂更特殊的问题,我们可能需要更高级的神经网络模型来解决。这些高级的神经网络理论上能够取得更好的效果,但从实践上来看,这些模型在强化学习上的应用并不是很多,因为这些模型的训练过程往往比较复杂,需要调整的参数也比较多,而且这些模型的效果并不一定比基础的神经网络模型好很多。
因此,读者
在解决实际的强化学习问题时还是尽量简化问题,并使用基础的神经网络模型来解决。**在这里我们只是简要介绍一些常用的高级神经网络模型,感兴趣的读者可以自行深入了解。

卷积神经网络(适用于处理网格结构的数据)

适用于处理具有网格结构的数据,如图像(2D网格像素点)或时间序列数据(1D网格)等
其中图像是用得最为广泛的。比如在很多的游戏场景中,其状态输入都是以图像的形式呈现的,并且图像能够包含更多的信息,这个时候我们就可以使用卷积神经网络来处理这些图像数据。

使用卷积神经网络的时候,我们需要注意以下几个主要特点:

  • 局部感受野:传统的线性神经网络每个节点都与前一层的所有节点相连接。但在CNN中,我们使用小的局部感受野(例如3x3或5x5的尺寸),它只与前一层的一个小区域内的节点相连接。这可以减少参数数量,并使得网络能够专注于捕捉局部特征。
  • 权重共享:在同一层的不同位置,卷积核的权重是共享的,这不仅大大减少了参数数量,还能帮助网络在图像的不同位置检测同样的特征。
  • 池化层:池化层常常被插入在连续的卷积层之间,用来减少特征图的尺寸、减少参数数量并提高网络的计算效率。最常见的池化操作是最大池化( Max-Pooling),它将输入特征图划分为若干个小区域,并输出每个区域的最大值。
  • 归一化和Dropout :为了优化网络的性能和防止过拟合,可以在网络中添加归一化层(如 Batch Normalization)和 Dropout 。

循环神经网络(适用于处理序列数据)

循环神经网络适用于处理序列数据,也是最基础的一类时序网络。
在强化学习中,循环神经网络常常被用来处理序列化的状态数据,例如在Atari游戏中,我们可以将连续的四帧图像作为一个序列输入到循环神经网络中,这样一来就能够更好地捕捉到游戏中的动态信息。

但是基础的RNN结构很容易产生梯度消失或者梯度爆炸的问题,因此我们通常会使用一些改进的循环神经网络结构,例如
LSTM和GRU等。
LSTM 主要是通过引入门机制(输入门、遗忘门和输出门)来解决梯度消失的问题,它能够在长序列中维护更长的依赖关系。
GRU则是对 LSTM 的简化,它只有两个门(更新门和重置门),并且将记忆单元和隐藏状态合并为一个状态向量,性能与 LSTM相当,但通常计算效率更高。

还有一种特殊的结构,叫做 Transformer。虽然它也是为了处理序列数据而设计的,但是是一个完全不同的结构,不再依赖循环来处理序列,而是使用自注意机制 (self-attention mechanism) 来同时考虑序列中的所有元素。并且Transformer的设计特别适合并行计算,使得训练速度更快。自从被提出以后,Transformer就被广泛应用于自然语言处理领域,例如BERT以及现在特别流行的GPT等模型。


简化过程:原始线性模型是:
f ( x ; w , b ) = w T x + b f(\\boldsymbol{x}; \\boldsymbol{w}, b) = \\boldsymbol{w}^T \\boldsymbol{x} + b f(x;w,b)=wTx+b

如果我们对输入向量 x \\boldsymbol{x} x 做一个**“增广”处理**(给 x \\boldsymbol{x} x 末尾补一个固定值 1 ),定义新的增广向量:
x ~= [ x 1 ] = [ x 1 , x 2 , … , x m , 1 ] \\boldsymbol{\\tilde{x}} = \\begin{bmatrix} \\boldsymbol{x} \\\\ 1 \\end{bmatrix} = [x_1, x_2, \\dots, x_m, 1] x~=[x1]=[x1,x2,,xm,1]

同时,把参数 w \\boldsymbol{w} w b b b 也合并成一个增广参数向量
θ = [ w b ] = [ w 1 , w 2 , … , w m , b ] \\boldsymbol{\\theta} = \\begin{bmatrix} \\boldsymbol{w} \\\\ b \\end{bmatrix} = [w_1, w_2, \\dots, w_m, b] θ=[wb]=[w1,w2,,wm,b]

此时,原模型的内积 w T x+b \\boldsymbol{w}^T \\boldsymbol{x} + b wTx+b 就可以写成增广向量的内积
w T x + b = [ w 1 , w 2 , … , w m , b][ x 1 x 2 ⋮ x m 1 ] = θ T x ~ \\boldsymbol{w}^T \\boldsymbol{x} + b = \\begin{bmatrix} w_1, w_2, \\dots, w_m, b \\end{bmatrix} \\begin{bmatrix} x_1 \\\\ x_2 \\\\ \\vdots \\\\ x_m \\\\ 1 \\end{bmatrix} = \\boldsymbol{\\theta}^T \\boldsymbol{\\tilde{x}} wTx+b=[w1,w2,,wm,b] x1x2xm1 =θTx~

反向传播算法:
反向传播算法(Backpropagation Algorithm)是深度学习中训练神经网络的核心算法,通过计算损失函数对每个参数的梯度(即参数变化对损失的影响程度),再用梯度下降法更新参数(如 “沿着梯度反方向调整参数以减小损失”)。

核心思想:根据输出层得到的误差,从输出层反向传播到输入层,逐层计算误差对每个神经元参数(权重和偏置)的梯度,再利用梯度下降法更新这些参数,不断降低误差,让模型预测值与真实值更加接近。
数学基础:运用链式法则来计算复合函数的导数。在神经网络中,每个神经元的输出都是其输入的函数,整个网络可以看作是一个非常复杂的复合函数,通过链式法则能够有效计算出误差关于每个参数的梯度。

梯度的计算过程是 “链式法则” 的传递 —— 深层网络中,上层参数的梯度需要通过下层的梯度 “逐层传递”。若传递过程中梯度被不断 “缩小” 或 “放大”,就会导致梯度消失或爆炸。
梯度消失:在反向传播时,梯度值随着网络层数的增加而逐渐缩小至接近 0,导致浅层(靠近输入层)的参数几乎无法更新(梯度接近 0,参数调整幅度为 0)。
梯度爆炸:在反向传播时,梯度值随着网络层数的增加而急剧增大至极大值(如 1000 以上),导致参数更新幅度过大,模型权重变得极不稳定(甚至出现 NaN 值)。

维度 梯度消失 梯度爆炸 梯度值变化 随层数增加逐渐缩小至接近0 随层数增加急剧增大至极大值 核心原因 激活函数导数≤1 + 权重≤1的连乘 权重>1的连乘(尤其RNN时间步累积) 对参数的影响 浅层参数几乎不更新 参数更新幅度过大,权重不稳定 对训练的影响 模型收敛缓慢、精度低 损失震荡、无法收敛、计算溢出

LSTM(长短期记忆网络)和GRU(门控循环单元)是两种改进的循环神经网络(RNN),专门用于解决传统RNN在处理长序列数据时的“梯度消失”或“梯度爆炸”问题,能更好地捕捉序列中的长期依赖关系(比如文本中上下文的远距离关联、时间序列中过去事件对当前的影响等)。
传统RNN的结构是“循环”的——每一步的输出会作为下一步的输入,理论上能记住历史信息。但当序列过长(比如一篇长文章、一段持续多天的时间序列),早期的信息会在传递中逐渐“遗忘”(梯度消失),导致模型无法学习到长期依赖。
因此,LSTM和GRU通过设计“门控机制”,主动控制信息的“保留”“更新”和“遗忘”,解决了这一问题。
LSTM是1997年提出的经典改进模型,核心是通过3个“门”来管理信息,结构相对复杂但功能完整。

  • 细胞状态(Cell State):类似“传送带”,是LSTM的核心,负责长期信息的稳定传递(不受短期干扰)。
  • 遗忘门(Forget Gate):决定“哪些历史信息需要被遗忘”(比如处理句子时,前半句的主语已过时,就会被遗忘)。
    用sigmoid函数输出0-1之间的概率(0表示完全遗忘,1表示完全保留)。
  • 输入门(Input Gate):决定“哪些新信息需要被存入细胞状态”(比如新出现的关键名词)。
    同样用sigmoid控制“是否存入”,再通过tanh生成待存入的新信息。
  • 输出门(Output Gate):决定“当前细胞状态中哪些信息需要输出到下一个时间步”(比如根据当前上下文,只输出与下一个词相关的信息)。
    用sigmoid控制输出比例,再结合细胞状态的tanh变换生成最终输出。
  • 优势:对长期依赖的捕捉能力强,稳定性高。
  • 适用场景:需要精确处理长序列的任务,如机器翻译(两种语言的长句对应)、文本摘要(提炼长文核心)、语音识别(长语音转文字)。
    GRU是2014年在LSTM基础上简化的模型,保留了核心功能但结构更简单,训练速度更快。
    GRU去掉了LSTM的“细胞状态”,直接用“隐藏状态”传递信息,并用2个门替代了3个门:
  • 更新门(Update Gate):融合了LSTM的“遗忘门”和“输入门”,决定“保留多少历史信息”和“接收多少新信息”(比如同时判断是否遗忘旧主语、是否存入新主语)。
  • 重置门(Reset Gate):决定“是否忽略历史信息,只关注当前输入”(比如遇到转折词时,暂时忽略之前的信息)。
  • 优势:参数更少(比LSTM少1/3左右),训练速度更快,在数据量有限或需要高效计算时表现更好。
  • 适用场景:对速度要求较高的任务,如文本分类(短文本快速判断类别)、情感分析(实时分析用户评论情绪)、时间序列预测(如股票价格短期预测)。
维度 LSTM GRU 门控数量 3个(遗忘、输入、输出) 2个(更新、重置) 参数数量 更多(结构复杂) 更少(简化结构) 训练速度 较慢(参数多,计算成本高) 较快(参数少,计算效率高) 长期依赖捕捉 更稳定(适合极长序列) 稍弱但足够(适合中长序列) 适用场景 长序列、高精度需求(如机器翻译) 中短序列、高效需求(如情感分析)

LSTM和GRU都是为解决RNN的长期依赖问题而生,核心是“门控机制”。

  • 追求精度和长序列处理能力 → 选LSTM;
  • 追求速度和效率,且序列不算极长 → 选GRU。
    两者在NLP(自然语言处理)、时间序列分析等地方应用广泛,是深度学习处理序列数据的核心工具。