【map计算】自定义map计算
任务说明
我现在需要评估模型的预测效果,选择使用mAP作为评估指标。
我的输入包含两个文件夹,一个是标准答案,里面有xml。
另外一个是你的答案,里面有你预测的xml。
脚本会逐个object地进行对比。你可以把标准答案复制一份作为你的答案进行测试,此时得分应为1.0000(前提是所选目标类别都在标准答案中出现);做了修改后,得分会下降。
还可以指定只针对哪些类别计算mAP。
我来帮你创建一个计算mAP的脚本,用于遍历文件夹a和b,并计算目标检测的综合得分。我创建了一个完整的mAP计算脚本,主要功能包括:
核心功能
- 文件夹遍历: 自动遍历文件夹a(预测结果)和文件夹b(真值标注)
- 多IoU阈值计算: 支持mAP@0.5、mAP@0.75、mAP@0.9的计算
- 综合得分: 按照你指定的权重公式计算:
0.6*mAP50 + 0.2*mAP75 + 0.2*mAP90
关键特性
- IoU计算: 实现了标准的边界框IoU计算
- AP计算: 使用11点插值法计算Average Precision
- 格式支持: 支持PASCAL VOC格式的XML标注文件(预测XML可在object中附加confidence字段)
- 类别筛选: 可以指定目标类别列表进行计算
使用方法
-
修改配置:
folder_a = r"你的预测结果文件夹路径"
folder_b = r"你的真值标注文件夹路径"
target_classes = ["person", "car"]  # 修改为你的目标类别名称(需与XML中<name>的内容一致)
-
文件格式要求:
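- 使用PASCAL VOC风格的XML格式:每个<object>包含<name>(类别名称)和<bndbox>
- 预测结果建议在每个<object>中包含<confidence>置信度分数(缺省时按1.0处理)
- 边界框格式:<bndbox>中的xmin、ymin、xmax、ymax,即[xmin, ymin, xmax, ymax]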
-
运行脚本:
python map_calculator.py
代码
"""Compute mAP between two folders of PASCAL VOC-style XML annotations.

Folder a holds the predictions and folder b the ground truth; files are
paired by base name.  For each target class, AP is computed with 11-point
interpolation at several IoU thresholds, and a weighted overall score
(0.6*mAP50 + 0.2*mAP75 + 0.2*mAP90) is reported and saved as JSON.
"""
import glob
import json
import os
import xml.etree.ElementTree as ET
from collections import defaultdict

import numpy as np

# Weights of the overall score; keys are the result keys they apply to.
SCORE_WEIGHTS = {'mAP50': 0.6, 'mAP75': 0.2, 'mAP90': 0.2}

# Child tags of <bndbox>, in the [xmin, ymin, xmax, ymax] order used everywhere.
_BBOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')


def calculate_iou(box1, box2):
    """Return the IoU of two boxes given as [xmin, ymin, xmax, ymax].

    Boxes that do not overlap (or only touch along an edge) give 0.0, as
    does a degenerate pair whose union area is not positive.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    if x2 <= x1 or y2 <= y1:
        return 0.0
    intersection = (x2 - x1) * (y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection
    return intersection / union if union > 0 else 0.0


def parse_xml_annotation(xml_path):
    """Parse one PASCAL VOC-style XML annotation file.

    Returns ``{'filename': str, 'objects': [{'name': str, 'bbox': [xmin,
    ymin, xmax, ymax], 'confidence': float}, ...]}``, or ``None`` (after
    printing the error) when the file cannot be read or is not well-formed
    XML.

    Each <object> needs a non-empty <name> (surrounding whitespace is
    stripped) and a <bndbox> with numeric xmin/ymin/xmax/ymax.  Objects
    lacking them are skipped individually instead of discarding the whole
    file.  <confidence> is optional and defaults to 1.0, which is what
    ground-truth files rely on.
    """
    try:
        root = ET.parse(xml_path).getroot()
    except (ET.ParseError, OSError) as e:
        print(f"解析XML文件失败 {xml_path}: {e}")
        return None

    filename_elem = root.find('filename')
    if filename_elem is not None:
        filename = filename_elem.text
    else:
        # No <filename> tag: fall back to the XML file's own base name.
        filename = os.path.splitext(os.path.basename(xml_path))[0]

    objects = []
    for obj in root.findall('object'):
        name = (obj.findtext('name') or '').strip()
        bndbox = obj.find('bndbox')
        if not name or bndbox is None:
            continue
        try:
            bbox = [float(bndbox.findtext(tag)) for tag in _BBOX_TAGS]
            conf_text = obj.findtext('confidence')
            confidence = 1.0 if conf_text is None else float(conf_text)
        except (TypeError, ValueError):
            # float(None) (missing coordinate) raises TypeError;
            # empty or non-numeric text raises ValueError.
            print(f"跳过无效对象 {xml_path}: {name}")
            continue
        objects.append({'name': name, 'bbox': bbox, 'confidence': confidence})

    return {'filename': filename, 'objects': objects}


def calculate_ap_at_iou(predictions, ground_truths, iou_threshold=0.5):
    """Compute the 11-point interpolated AP of one class at one IoU threshold.

    ``predictions`` are dicts with 'bbox' and 'confidence'; ``ground_truths``
    are dicts with 'bbox'.  Either may carry an optional 'image_id'.  A
    prediction is only matched against ground truths with the same
    'image_id'.  Objects without one share a single implicit image, so
    direct single-image calls keep working.

    Matching is greedy in descending confidence order.  Each prediction
    takes the still-unmatched GT with the highest IoU, and is a true
    positive if that IoU >= ``iou_threshold``; otherwise it is a false
    positive.

    Returns 0.0 when either list is empty.
    """
    if not predictions or not ground_truths:
        return 0.0

    # sorted() is stable (also with reverse=True): ties keep input order.
    predictions = sorted(predictions, key=lambda p: p['confidence'], reverse=True)

    gt_boxes_by_image = defaultdict(list)
    for gt in ground_truths:
        gt_boxes_by_image[gt.get('image_id')].append(gt['bbox'])
    used = {image: [False] * len(boxes) for image, boxes in gt_boxes_by_image.items()}

    tp = np.zeros(len(predictions))
    fp = np.zeros(len(predictions))
    for i, pred in enumerate(predictions):
        image = pred.get('image_id')
        best_iou, best_idx = 0.0, -1
        for j, gt_box in enumerate(gt_boxes_by_image.get(image, [])):
            if used[image][j]:
                continue
            iou = calculate_iou(pred['bbox'], gt_box)
            if iou > best_iou:
                best_iou, best_idx = iou, j
        if best_idx != -1 and best_iou >= iou_threshold:
            tp[i] = 1
            used[image][best_idx] = True
        else:
            fp[i] = 1

    tp_cumsum = np.cumsum(tp)
    fp_cumsum = np.cumsum(fp)
    recalls = tp_cumsum / len(ground_truths)
    # Every prediction is either TP or FP, so the denominator is i + 1 >= 1.
    precisions = tp_cumsum / (tp_cumsum + fp_cumsum)

    # 11-point interpolation.  Thresholds are k / 10 rather than
    # np.arange(0, 1.1, 0.1): arange yields 0.30000000000000004 (and similar
    # for 0.6 / 0.7), which wrongly excludes points whose recall is exactly
    # 0.3, 0.6 or 0.7.
    ap = 0.0
    for k in range(11):
        mask = recalls >= k / 10
        if mask.any():
            ap += precisions[mask].max()
    return float(ap / 11)


def load_xml_annotations(folder_path):
    """Parse every ``*.xml`` file directly inside ``folder_path``.

    Returns a dict keyed by the XML file's base name (without extension).
    Files that fail to parse are left out.
    """
    annotations = {}
    # glob.escape: folder names containing [, ] or * must match literally.
    pattern = os.path.join(glob.escape(folder_path), "*.xml")
    for xml_file in sorted(glob.glob(pattern)):
        annotation = parse_xml_annotation(xml_file)
        if annotation is not None:
            base_name = os.path.splitext(os.path.basename(xml_file))[0]
            annotations[base_name] = annotation
    return annotations


def calculate_map_for_folders(pred_folder, gt_folder, target_classes,
                              iou_thresholds=(0.5, 0.75, 0.9)):
    """Compute per-class AP and mAP for two folders of XML annotations.

    Every ground-truth file is evaluated.  If it has no prediction file,
    its objects all count as misses.  Prediction files without a
    ground-truth counterpart are reported and ignored.  Each prediction is
    matched only against GT from its own file.

    Classes with neither predictions nor ground truth have no defined AP.
    They are left out of the mean instead of counting as 0.

    Returns a dict with, for each threshold t (key suffix
    ``round(t * 100)``), ``'mAP<pct>'`` -> float and
    ``'class_aps_<pct>'`` -> {class_name: AP}.
    """
    print(f"加载预测结果文件夹: {pred_folder}")
    pred_annotations = load_xml_annotations(pred_folder)
    print(f"加载真值标注文件夹: {gt_folder}")
    gt_annotations = load_xml_annotations(gt_folder)
    print(f"找到预测文件: {len(pred_annotations)} 个")
    print(f"找到真值文件: {len(gt_annotations)} 个")

    gt_keys = sorted(gt_annotations)
    missing = [key for key in gt_keys if key not in pred_annotations]
    extra = [key for key in pred_annotations if key not in gt_annotations]
    if missing:
        print(f"警告: {len(missing)} 个真值文件没有对应的预测文件, 其真值目标全部计为漏检")
    if extra:
        print(f"警告: {len(extra)} 个预测文件没有对应的真值文件, 已忽略")

    # Group objects by class once.  Each object is tagged (as a copy) with
    # its file key so predictions are matched only within the same image.
    preds_by_class = defaultdict(list)
    gts_by_class = defaultdict(list)
    for key in gt_keys:
        for obj in gt_annotations[key]['objects']:
            gts_by_class[obj['name']].append({**obj, 'image_id': key})
        pred_data = pred_annotations.get(key)
        if pred_data is not None:
            for obj in pred_data['objects']:
                preds_by_class[obj['name']].append({**obj, 'image_id': key})

    results = {}
    for iou_thresh in iou_thresholds:
        print(f"\n计算 mAP@{iou_thresh}...")
        class_aps = {}
        for class_name in target_classes:
            preds = preds_by_class.get(class_name, [])
            gts = gts_by_class.get(class_name, [])
            if not preds and not gts:
                print(f"  {class_name}: 无预测也无真值, 不计入mAP")
                continue
            ap = calculate_ap_at_iou(preds, gts, iou_thresh)
            class_aps[class_name] = ap
            print(f"  {class_name}: AP = {ap:.4f} (预测:{len(preds)}, 真值:{len(gts)})")

        mean_ap = float(np.mean(list(class_aps.values()))) if class_aps else 0.0
        # round(), not int(): int(0.57 * 100) == 56.
        pct = round(iou_thresh * 100)
        results[f'mAP{pct}'] = mean_ap
        results[f'class_aps_{pct}'] = class_aps
        print(f"  mAP@{iou_thresh}: {mean_ap:.4f}")

    return results


def _weighted_score(score50, score75, score90):
    """Combine per-threshold scores using SCORE_WEIGHTS."""
    return (SCORE_WEIGHTS['mAP50'] * score50
            + SCORE_WEIGHTS['mAP75'] * score75
            + SCORE_WEIGHTS['mAP90'] * score90)


def main():
    """Evaluate folder_a (predictions) against folder_b (ground truth).

    Prints a summary and writes it to xml_map_results.json.
    """
    # Configuration
    folder_a = r"E:\评分\我的答案"  # prediction XML folder
    folder_b = r"E:\评分\标准答案"  # ground-truth XML folder

    # Target class names; must match <name> in the XML files. Adjust per dataset.
    target_classes = [
        "person", "car", "bicycle", "motorcycle", "bus", "truck",
        "traffic_light", "stop_sign", "dog", "cat", "021_tdhj_xxshywyh_sh/yw_yw",
    ]

    print("=== XML格式mAP计算工具 ===")
    print(f"预测文件夹: {folder_a}")
    print(f"真值文件夹: {folder_b}")
    print(f"目标类别: {target_classes}")

    if not os.path.isdir(folder_a):
        print(f"错误: 预测文件夹 '{folder_a}' 不存在!")
        return
    if not os.path.isdir(folder_b):
        print(f"错误: 真值文件夹 '{folder_b}' 不存在!")
        return

    results = calculate_map_for_folders(folder_a, folder_b, target_classes)

    mAP50 = results.get('mAP50', 0.0)
    mAP75 = results.get('mAP75', 0.0)
    mAP90 = results.get('mAP90', 0.0)
    comprehensive_score = _weighted_score(mAP50, mAP75, mAP90)

    # Sum (not mean) of the per-class APs at each threshold.
    total_score_50 = sum(results.get('class_aps_50', {}).values())
    total_score_75 = sum(results.get('class_aps_75', {}).values())
    total_score_90 = sum(results.get('class_aps_90', {}).values())
    weighted_total_score = _weighted_score(total_score_50, total_score_75, total_score_90)

    print("\n" + "=" * 60)
    print("最终结果汇总")
    print("=" * 60)
    print(f"mAP@0.5: {mAP50:.4f}")
    print(f"mAP@0.75: {mAP75:.4f}")
    print(f"mAP@0.9: {mAP90:.4f}")
    print("-" * 40)
    print(f"综合得分: {comprehensive_score:.4f}")
    print("权重配置: 0.6*mAP50 + 0.2*mAP75 + 0.2*mAP90")
    print("-" * 40)
    print("所有分类总得分:")
    print(f"  总AP@0.5: {total_score_50:.4f}")
    print(f"  总AP@0.75: {total_score_75:.4f}")
    print(f"  总AP@0.9: {total_score_90:.4f}")
    print(f"  加权总得分: {weighted_total_score:.4f}")
    print("=" * 60)

    output_results = {
        'mAP50': mAP50,
        'mAP75': mAP75,
        'mAP90': mAP90,
        'comprehensive_score': comprehensive_score,
        'total_scores': {
            'total_AP50': total_score_50,
            'total_AP75': total_score_75,
            'total_AP90': total_score_90,
            'weighted_total_score': weighted_total_score,
        },
        'target_classes': target_classes,
        'weights': dict(SCORE_WEIGHTS),
        'class_details': {
            'mAP50_classes': results.get('class_aps_50', {}),
            'mAP75_classes': results.get('class_aps_75', {}),
            'mAP90_classes': results.get('class_aps_90', {}),
        },
    }
    with open('xml_map_results.json', 'w', encoding='utf-8') as f:
        json.dump(output_results, f, indent=2, ensure_ascii=False)
    print("\n详细结果已保存到: xml_map_results.json")


if __name__ == "__main__":
    main()