MMRotate ReDet ReFPN 报错 `assert input.type == self.in_type`
在跑实验时，使用配置文件 `configs/redet/redet_re50_refpn_1x_dota_le90.py` 进行训练，结果报错：
Traceback (most recent call last):
  File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 196, in <module>
    main()
  File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 183, in main
    train_detector(
  File "h:\workspace\deeplearning\mmrotate\mmrotate\apis\train.py", line 145, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 53, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 31, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\parallel\data_parallel.py", line 77, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 248, in train_step
    losses = self(**data)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func
    return old_func(*args, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 172, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 127, in forward_train
    x = self.extract_feat(img)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 69, in extract_feat
    x = self.neck(x)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func
    return old_func(*args, **kwargs)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 298, in forward
    laterals = [
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 299, in <listcomp>
    self.lateral_convs[i](inputs[i + self.start_level])
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 148, in forward
    x = self.conv(x)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\e2cnn\nn\modules\r2_conv\r2convolution.py", line 326, in forward
    assert input.type == self.in_type
AssertionError
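报错的原因在于：传入 ReFPN 中 `ConvModule` 的特征是普通的 `torch.Tensor`，而不是 e2cnn 的 `GeometricTensor`。`torch.Tensor.type` 是一个方法，它与 `FieldType` 比较必然不相等，于是触发了 `R2Conv` 中的断言。按照以下提示修改 `mmrotate/models/necks/re_fpn.py` 中的三处地方，其余地方保持不变。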
# 1. Import build_enn_divide_feature alongside the existing helpers.
from ..utils import (build_enn_divide_feature, build_enn_feature,
                     build_enn_norm_layer, ennConv, ennInterpolate,
                     ennMaxPool, ennReLU)


class ConvModule(enn.EquivariantModule):
    """Equivariant conv / norm / activation block used by ReFPN.

    Only the lines touched by this fix are shown; the remainder of the
    original ``__init__`` (which builds ``self.conv``, ``self.norm``,
    ``self.activate`` and sets ``self.order``, ``self.with_norm`` and
    ``self.with_activatation``) is kept unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 activation='relu',
                 inplace=False,
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        # 2. Replace build_enn_feature with build_enn_divide_feature.
        # forward() wraps raw tensors with self.in_type, so this field type
        # must describe the incoming tensor's channel layout.
        # NOTE(review): assumes the "divide" variant yields a type whose size
        # equals ``in_channels`` (as the conv's own in_type presumably does);
        # confirm against mmrotate/models/utils.
        self.in_type = build_enn_divide_feature(in_channels)
        self.out_type = build_enn_divide_feature(out_channels)
        # ... the rest of __init__ stays unchanged ...

    def forward(self, x, activate=True, norm=True):
        """Forward function of ConvModule.

        Args:
            x (enn.GeometricTensor | torch.Tensor): Input feature map. Anything
                that is not already a GeometricTensor is wrapped with
                ``self.in_type`` before the layers run.
            activate (bool): Whether to apply the activation layer.
            norm (bool): Whether to apply the norm layer.

        Returns:
            The feature map after the layers in ``self.order`` are applied.
        """
        # 3. R2Conv asserts ``input.type == self.in_type``; a plain
        # torch.Tensor's ``.type`` is a method, so the assertion always fails.
        # Wrap such inputs as GeometricTensor first. The check is made against
        # enn.GeometricTensor (already used by this module) rather than
        # torch.Tensor because re_fpn.py may not bind the top-level ``torch``
        # name (``import torch.nn as nn`` binds only ``nn``).
        if not isinstance(x, enn.GeometricTensor):
            x = enn.GeometricTensor(x, self.in_type)
        for layer in self.order:
            if layer == 'conv':
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            # NOTE: ``with_activatation`` is presumably spelled this way where
            # it is set in the unchanged part of __init__; do not rename it
            # here alone.
            elif layer == 'act' and activate and self.with_activatation:
                x = self.activate(x)
        return x