> 技术文档 > PyTorch中实现早停机制(EarlyStopping)附代码

PyTorch中实现早停机制(EarlyStopping)附代码

1. 核心目的

  • 当模型在验证集上的性能不再提升时,提前终止训练
  • 防止过拟合,节省计算资源

2. 实现方法

监控验证集指标(如损失、准确率),设置耐心值(Patience)

3. 代码:

class EarlyStopping: def __init__(self,patience =10,delta=0): \"\"\" Early stopping Args: patience: int, number of epochs to wait before stopping delta: float, the minimum improvements \"\"\" self.patience = patience self.delta = delta self.counter =0 self.early_stop = False self.best_loss = float(\'inf\') def __call__(self, val_loss): if val_loss < self.best_loss - self.delta: self.best_loss = val_loss self.counter =0 else: self.counter+=1 if self.counter >= self.patience: self.early_stop = True 

解释__call__ 方法的作用
在 Python 中,当一个类定义了 __call__ 方法时,这个类的实例就可以被当作函数来调用。例如:

early_stopper = EarlyStopping(patience=3) # 创建实例early_stopper(val_loss=0.5) # 调用实例,实际执行 __call__ 方法