【Torch】nn.GRU算法详解_nn.gru(
1. 输入输出
-
输入张量
-
可选初始隐状态
- 形状:
(num_layers * num_directions, batch_size, hidden_size) - 默认为全零张量。如果要自定义,需提供此形状的
h0。
- 形状:
-
输出
调用output, h_n = gru(x, h0)返回两部分:output:所有时间步的隐藏状态序列- 形状:
- 默认:
(seq_len, batch_size, num_directions * hidden_size) batch_first=True:(batch_size, seq_len, num_directions * hidden_size)
- 默认:
- 含义:每个时间步的隐藏状态,可以直接接全连接或其它后续层。
- 形状:
h_n:最后一个时间步的隐藏状态- 形状:
(num_layers * num_directions, batch_size, hidden_size) - 含义:每一层(及方向)在序列末尾的隐藏状态,常用于初始化下一个序列或分类任务。
- 形状:
2. 构造函数参数详解
nn.GRU( input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False)
input_sizehidden_sizehidden_size)。num_layersbiasFalse 时,所有线性变换均无 bias。batch_firstTrue),默认序列维在最前(False)。dropoutnum_layers>1 时生效。bidirectionalTrue,则隐状态和输出维度翻倍。3. 输出含义详解
-
output- 大小:
[..., num_directions * hidden_size] - 如果
bidirectional=False,num_directions=1;否则=2。 output[t, b, :](或在batch_first模式下output[b, t, :])表示第 t 步第 b 个样本的隐藏状态。
- 大小:
-
h_n- 大小:
(num_layers * num_directions, batch_size, hidden_size) - 维度索引含义:
- 维度 0:层数 × 方向(例如 3 层双向时索引 0–5,对应层1正向、层1反向、层2正向…)
- 维度 1:批内样本索引
- 维度 2:隐藏状态向量
- 大小:
4. 使用注意事项
-
batch_first的选择- 若后续直接接全连接层、BatchNorm 等,更习惯
batch_first=True;否则可用默认格式节省一次转置。
- 若后续直接接全连接层、BatchNorm 等,更习惯
-
双向与输出维度
bidirectional=True时,output的最后一维和h_n中hidden_size均会翻倍,需要相应修改下游网络维度。
-
Dropout 的生效条件
- 只有在
num_layers > 1并且dropout > 0时,才会在各层间插入 Dropout;单层时不会应用。
- 只有在
-
初始隐状态
- 默认为零。若在两个连续序列之间保持状态(stateful RNN),可将上一次的
h_n作为下一次的h0。
- 默认为零。若在两个连续序列之间保持状态(stateful RNN),可将上一次的
-
PackedSequence
- 对变长序列,可用
torch.nn.utils.rnn.pack_padded_sequence输入,输出再用pad_packed_sequence恢复,对长短不一的序列批处理很有用。
- 对变长序列,可用
-
性能与稳定性
- GRU 相比 LSTM 参数更少、速度稍快,但有时在长期依赖或梯度流问题上略不如 LSTM。
- 可在多层 RNN 之间加 LayerNorm 或 Residual 连接,提升深度模型的收敛和稳定性。
简单示例
import torch, torch.nn as nn# 定义单层单向 GRUgru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)# 输入:batch=8, seq_len=15, features=10x = torch.randn(8, 15, 10)# 默认 h0 为零output, h_n = gru(x)print(output.shape) # (8, 15, 2*20) 双向,所以 hidden_size*2print(h_n.shape) # (2*2, 8, 20) num_layers=2, num_directions=2


