> 文档中心 > PPOCRLabel格式的数据集操作总结。

PPOCRLabel格式的数据集操作总结。


1、生成识别数据

读取PPOCRLabel格式数据集中每个文本目标的四个角点,然后使用 cv2.getPerspectiveTransform 和 cv2.warpPerspective 对该区域做透视矫正并裁剪出小图,生成识别数据集。

import json
import os
import numpy as np
import cv2


def get_rotate_crop_image(img, points):
    """Cut a four-point text region out of ``img`` and rectify it.

    Args:
        img: Source image array, as returned by ``cv2.imread``.
        points: ``np.float32`` array with one ``(x, y)`` row per corner.
            The size computation treats edges 0-1 and 2-3 as the width and
            edges 0-3 and 1-2 as the height of the region.

    Returns:
        The perspective-warped crop whose size is derived from the longest
        edge in each direction. A crop whose height is at least 1.5 times
        its width is passed through ``np.rot90`` before being returned.

    Raises:
        AssertionError: If ``points`` does not hold exactly four points.
    """
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(max(
        np.linalg.norm(points[0] - points[1]),
        np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(max(
        np.linalg.norm(points[0] - points[3]),
        np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0],
                          [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img, M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


def _parse_annotations(raw):
    """Decode the annotation column of a label line into a list of dicts."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # The original code always swapped single quotes for double quotes
        # before parsing (presumably for labels written with Python-style
        # quoting). Doing so unconditionally corrupts valid JSON that
        # contains an apostrophe, so it is now only a fallback.
        return json.loads(raw.replace("'", "\""))


def write_txt_img(src_path, label_txt, file_dir,
                  image_root='./train_data/icdar2015/text_localization/',
                  save_root='./train_data/rec/'):
    """Crop every annotated region of a detection label file into its own image.

    Each non-blank line of ``src_path`` is ``<image path>\\t<annotation list>``.
    For every annotation the region is rectified with
    ``get_rotate_crop_image``, saved under ``save_root + file_dir`` and a
    ``<file_dir><crop name>\\t<transcription>`` line is written to
    ``label_txt``.

    Args:
        src_path: Path of the detection label file (read as UTF-8).
        label_txt: Writable text file object receiving the recognition labels.
        file_dir: Sub-directory prefix (for example ``'train/'``) used both
            for the saved crops and inside the label lines.
        image_root: Prefix prepended to the image path found in each line.
        save_root: Prefix of the directory that receives the crops.
    """
    with open(src_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            print(line)
            content = line.split('\t')
            print(content[0])
            imag_name = os.path.basename(content[0])
            image_path = image_root + content[0]
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread returns None instead of raising when the file
                # is missing or unreadable.
                print('skip unreadable image: ' + image_path)
                continue
            list_dict = _parse_annotations(content[1])
            print(len(list_dict))
            for num, lin in enumerate(list_dict):
                print(lin)
                # Spaces are removed from the transcription before it is
                # written to the recognition label.
                info = lin['transcription'].replace(" ", "")
                points = np.array(lin['points'], dtype=np.float32)
                # Build the crop name from the original image name each
                # time; prefixing the previous crop name made the names
                # grow with every region (0_a.jpg, 1_0_a.jpg, ...).
                crop_name = str(num) + "_" + imag_name
                save_path = save_root + file_dir + crop_name
                dst_img = get_rotate_crop_image(img, points)
                cv2.imwrite(save_path, dst_img)
                label_txt.write(file_dir + crop_name + '\t' + info + '\n')


src_path = r"./train_data/icdar2015/text_localization/train.txt"
label_txt = r"./train_data/rec/train.txt"
src_test_path = r"./train_data/icdar2015/text_localization/val.txt"
label_test_txt = r"./train_data/rec/val.txt"

# Guarded so that importing the helpers above does not touch the filesystem;
# running the snippet as a script behaves as before.
if __name__ == "__main__":
    os.makedirs('train_data/rec/train/', exist_ok=True)
    os.makedirs('train_data/rec/val/', exist_ok=True)
    # Transcriptions may contain non-ASCII text, so the encoding is explicit
    # and matches the UTF-8 used for reading the source labels.
    with open(label_txt, 'w', encoding='utf-8') as w_label:
        write_txt_img(src_path, w_label, 'train/')
    with open(label_test_txt, 'w', encoding='utf-8') as w_label:
        write_txt_img(src_test_path, w_label, 'val/')

2、切分训练集和验证集

按照一定的比例,将数据集切分为训练集和验证集

# Build the dataset: split Label.txt into a training set and a validation set.
import os
import shutil
from sklearn.model_selection import train_test_split

label_txt = 'Label.txt'


def _export_subset(label_lines, subset):
    """Write ``<subset>.txt`` and copy the listed images into ``./<subset>/``.

    Args:
        label_lines: Label lines of the form ``<image path>\\t<annotation>``.
        subset: Name of the subset, used as directory name and as the stem of
            the label file (``'train'`` or ``'val'``).
    """
    rewritten = []
    sources = []
    for raw in label_lines:
        # Strip the line ending and add it back explicitly: a final line
        # without a trailing newline would otherwise be glued to the next
        # record once the lines have been shuffled.
        fields = raw.rstrip('\r\n').split('\t')
        if len(fields) < 2:
            continue
        image_path, annotation = fields[0], fields[1]
        file_name = os.path.basename(image_path)
        rewritten.append(subset + '/' + file_name + '\t' + annotation + '\n')
        sources.append((image_path, './' + subset + '/' + file_name))

    # UTF-8 so that the generated label files can be read by the other
    # snippets of this document, which open their label files as UTF-8.
    with open(subset + '.txt', 'w', encoding='utf-8') as out:
        out.writelines(rewritten)

    for image_path, new_path in sources:
        shutil.copy(image_path, new_path)
        print(image_path)


def split_label_file(label_path=label_txt, test_size=0.2, random_state=42):
    """Split a label file into ``train.txt``/``val.txt`` and copy the images.

    Args:
        label_path: Label file to split.
        test_size: Fraction of the lines that goes to the validation set.
        random_state: Seed passed to ``train_test_split``.
    """
    # NOTE(review): the input is read as GBK as in the original script, while
    # the other snippets of this document read Label.txt as UTF-8 — confirm
    # which encoding your file actually uses.
    with open(label_path, 'r', encoding='gbk') as f:
        txt_List = [line for line in f if line.strip()]
    trainval_files, val_files = train_test_split(
        txt_List, test_size=test_size, random_state=random_state)
    _export_subset(trainval_files, 'train')
    _export_subset(val_files, 'val')


# Guarded so that importing this snippet does not create directories or copy
# files; running it as a script behaves as before.
if __name__ == "__main__":
    os.makedirs('train', exist_ok=True)
    os.makedirs('val', exist_ok=True)
    split_label_file()

3、将数据集生成LabelImg格式

将PPOCRLabel格式的数据集转为LabelImg标注的xml格式的数据集。

import os
from collections import defaultdict
import cv2
import json
from xml.sax.saxutils import escape


def _bounding_box(points):
    """Return ``(xmin, ymin, xmax, ymax)`` of a polygon given as (x, y) pairs."""
    xs = [int(p[0]) for p in points]
    ys = [int(p[1]) for p in points]
    return min(xs), min(ys), max(xs), max(ys)


def _voc_xml(image_name, width, height, boxes):
    """Render one annotation document in the Pascal VOC layout.

    The element names were missing from the original templates (only the
    values were left, so the written files were not XML). They are restored
    here following the Pascal VOC structure that the remaining values imply:
    folder, filename, size (width/height/depth), segmented, and one object
    (name, pose, truncated, difficult, bndbox) per box.
    """
    parts = [
        "<annotation>\n"
        "  <folder>JPEGImages</folder>\n"
        f"  <filename>{escape(image_name)}</filename>\n"
        "  <size>\n"
        f"    <width>{width}</width>\n"
        f"    <height>{height}</height>\n"
        "    <depth>3</depth>\n"
        "  </size>\n"
        "  <segmented>0</segmented>\n"
    ]
    for label, x1, y1, x2, y2 in boxes:
        parts.append(
            "  <object>\n"
            f"    <name>{escape(label)}</name>\n"
            "    <pose>Unspecified</pose>\n"
            "    <truncated>0</truncated>\n"
            "    <difficult>0</difficult>\n"
            "    <bndbox>\n"
            f"      <xmin>{x1}</xmin>\n"
            f"      <ymin>{y1}</ymin>\n"
            f"      <xmax>{x2}</xmax>\n"
            f"      <ymax>{y2}</ymax>\n"
            "    </bndbox>\n"
            "  </object>\n"
        )
    parts.append("</annotation>\n")
    return "".join(parts)


def convert_label_to_voc(label_path='Label.txt', output_dir='./Annotations'):
    """Convert a PPOCRLabel label file into per-image XML annotations.

    For every line ``<image path>\\t<annotation list>`` the image is
    re-saved as ``<stem>.jpg`` in ``output_dir`` and, when the line has at
    least one annotation, ``<stem>.xml`` is written next to it with the
    axis-aligned bounding box of each annotated polygon. Every box gets the
    fixed class name ``'No'``.

    Args:
        label_path: PPOCRLabel label file (read as UTF-8).
        output_dir: Directory receiving the ``.jpg`` and ``.xml`` files.
    """
    mem = defaultdict(list)
    with open(label_path, 'r', encoding='utf8') as fp:
        for line in fp:
            if not line.strip():
                continue
            fields = line.replace('\n', '').split('\t')
            path = fields[0]
            anno = json.loads(fields[1])
            print(anno)
            filename = os.path.basename(path)
            stem = os.path.splitext(filename)[0]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread returns None instead of raising when the file
                # is missing or unreadable.
                print('skip unreadable image: ' + path)
                continue
            cv2.imwrite(os.path.join(output_dir, stem + '.jpg'), img)
            height, width = img.shape[:2]
            if not anno:
                # As before, images without annotations get no XML file.
                continue
            for item in anno:
                x1, y1, x2, y2 = _bounding_box(item['points'])
                mem[filename].append(['No', x1, y1, x2, y2])
            # Written once per image, after all of its boxes are collected
            # (the original rewrote the file after every single box).
            xml_path = os.path.join(output_dir, stem + '.xml')
            with open(xml_path, 'w', encoding='utf-8') as f:
                f.write(_voc_xml(stem + '.jpg', width, height, mem[filename]))


# Guarded so that importing this snippet does not create directories or
# write files; running it as a script behaves as before.
if __name__ == "__main__":
    os.makedirs('./Annotations', exist_ok=True)
    print('建立Annotations目录', 3)
    convert_label_to_voc()

4、将PPOCRLabel格式的数据集转为DBNet训练用的icdar2015格式的数据集

import os
import json


def json_2_icdar(js_path, ic_path):
    """Convert a PPOCRLabel label file into one ICDAR-style text file per image.

    Each non-blank line of ``js_path`` is ``<image path>\\t<JSON annotation
    list>``. For every image a ``.txt`` file with the same relative path
    (extension replaced) is written below ``ic_path``; it contains one line
    per annotation: the integer point coordinates joined by commas, followed
    by a comma and the transcription.

    Args:
        js_path: Path of the PPOCRLabel label file (read as UTF-8).
        ic_path: Output directory. An empty string writes relative to the
            current working directory.
    """
    with open(js_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            print(line)
            content = line.split('\t')
            print(content[0])
            txt_file = os.path.splitext(content[0])[0] + '.txt'
            # The original computed this path but then opened ``txt_file``
            # directly, so ``ic_path`` was silently ignored.
            dst_file = os.path.join(ic_path, txt_file)
            dst_dir = os.path.dirname(dst_file)
            if dst_dir:
                os.makedirs(dst_dir, exist_ok=True)
            list_dict = json.loads(content[1])
            print(len(list_dict))
            with open(dst_file, 'w', encoding='utf-8') as file_lineinfo:
                for lin in list_dict:
                    print(lin)
                    info = lin['transcription']
                    # Flatten [[x, y], ...] into x1,y1,x2,y2,...
                    points = [int(y) for x in lin['points'] for y in x]
                    pts = ','.join(map(str, points))
                    file_lineinfo.write(pts + ',' + info + '\n')


if __name__ == "__main__":
    src_path = r"train/Label.txt"
    dst_path = r""
    json_2_icdar(src_path, dst_path)

5、数据增强

对标注的数据集做旋转、高斯模糊、色彩饱和度、亮度等增强。

import json
import os
import cv2
import numpy as np
import torchvision.transforms as transforms
from torchtoolbox.transform import Cutout
from PIL import Image
from random import randint

# Data preprocessing.
# NOTE: ``t`` is not referenced anywhere in this snippet; it is kept because
# it is a module-level name.
t = [
    transforms.ColorJitter(brightness=0.3, contrast=0.5, saturation=0.5),
    transforms.GaussianBlur(5, sigma=(0.1, 0.5)),
]
# Colour jitter and Gaussian blur applied to every rotated copy.
transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.2),
    transforms.GaussianBlur(5, sigma=(0.1, 3.0)),
    transforms.ToTensor(),
    transforms.ToPILImage(),
])


def dumpRotateImage(img, degree):
    """Rotate ``img`` by ``degree`` degrees around its centre.

    Args:
        img: Image array as returned by ``cv2.imread``.
        degree: Rotation angle in degrees, passed to
            ``cv2.getRotationMatrix2D``.

    Returns:
        A tuple ``(rotated image, rotation matrix)``. The output keeps the
        size of the input; uncovered areas are filled with white.
    """
    height, width = img.shape[:2]
    matRotation = cv2.getRotationMatrix2D((width // 2, height // 2), degree, 1)
    imgRotation = cv2.warpAffine(img, matRotation, (width, height),
                                 borderValue=(255, 255, 255))
    return imgRotation, matRotation


src_path = "Label_new.txt"
d_path = 'dd.txt'
# Pixel offsets from which the per-coordinate jitter is drawn.
radom_p = [-3, -2, -1, 0, 1, 2, 3, 4, 5]


def _jitter():
    """Return one random offset from ``radom_p``."""
    return radom_p[randint(0, len(radom_p) - 1)]


def augment_dataset(label_path=src_path, out_label_path=d_path,
                    out_dir='train', copies=5, max_angle=15):
    """Write randomly rotated and colour-jittered copies of labelled images.

    Only the first annotation of every label line is used. For each image
    ``copies`` variants are generated; a variant is dropped when any rotated
    corner coordinate is smaller than 5. Kept variants are saved in
    ``out_dir`` and one label line per variant is written to
    ``out_label_path``.

    Args:
        label_path: Source label file (read as UTF-8).
        out_label_path: Label file to create for the augmented images.
        out_dir: Directory receiving the augmented images; also used as the
            path prefix inside the new label lines.
        copies: Number of variants attempted per source image.
        max_angle: Rotation angles are drawn from ``[-max_angle, max_angle]``.
    """
    with open(out_label_path, 'w', encoding='utf-8') as w_label, \
            open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            content = line.split('\t')
            image_path = content[0]
            imag_name = os.path.basename(image_path)
            stem, ext = os.path.splitext(imag_name)
            list_dict = json.loads(content[1])
            if not list_dict:
                # Nothing to transform; the original crashed on such lines.
                continue
            first = list_dict[0]
            info = first['transcription'].replace(" ", "")
            points = [list(x) for x in first['points']]
            print(points)
            corners = np.array(points[:4], dtype=np.float64)
            if corners.shape != (4, 2):
                continue
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread returns None instead of raising when the file
                # is missing or unreadable.
                print('skip unreadable image: ' + image_path)
                continue
            # Homogeneous coordinates so the 2x3 affine matrix can be
            # applied to all four corners at once.
            homogeneous = np.hstack([corners, np.ones((4, 1))])
            for i in range(copies):
                imgRotation, matRotation = dumpRotateImage(
                    image, randint(-max_angle, max_angle))  # rotate
                imgRotation = Image.fromarray(
                    cv2.cvtColor(imgRotation, cv2.COLOR_BGR2RGB))
                imgRotation = transform(imgRotation).convert('RGB')
                imgRotation = cv2.cvtColor(np.asarray(imgRotation),
                                           cv2.COLOR_RGB2BGR)
                # Truncate to plain Python ints. The original called int()
                # on one-element 2-D arrays, which recent NumPy versions
                # deprecate, and NumPy integers cannot be JSON-encoded.
                rotated = (homogeneous @ matRotation.T).astype(int).tolist()
                print(rotated[3][0])
                if any(coord < 5 for pt in rotated for coord in pt):
                    continue
                # NOTE(review): "difficult" is kept as the string "false",
                # exactly as the original emitted it — confirm whether the
                # consumer expects a JSON boolean instead.
                result_info = [{
                    "transcription": info,
                    "points": [[x + _jitter(), y + _jitter()]
                               for x, y in rotated],
                    "difficult": "false",
                }]
                out_name = stem + "_" + str(i) + "_0726" + ext
                # json.dumps yields the same text as the former
                # str(...).replace("'", '"') for plain transcriptions and
                # stays valid JSON when the text contains quotes.
                imag_d_path = (out_dir + "/" + out_name + "\t"
                               + json.dumps(result_info, ensure_ascii=False)
                               + '\n')
                print(imag_d_path)
                cv2.imwrite(os.path.join(out_dir, out_name), imgRotation)
                w_label.write(imag_d_path)


# Guarded so that importing this snippet does not create directories or
# write files; running it as a script behaves as before.
if __name__ == "__main__":
    os.makedirs('train', exist_ok=True)
    augment_dataset()

6、删除没有标注的图片

将图片目录中的文件列表与标注文件中出现的图片列表做差集,把没有对应标注的多余图片删除。

import os

label_txt = "Label_new.txt"

# Only files with these extensions are candidates for deletion, so label
# files, caches and sub-directories living in the image folder are left alone.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')


def remove_unlabeled_images(image_dir, label_path):
    """Delete the images in ``image_dir`` that have no entry in ``label_path``.

    Args:
        image_dir: Directory holding the images.
        label_path: Label file whose lines start with ``<image path>\\t``;
            images are matched by file name only. Read as UTF-8, like the
            other snippets that read this label file.

    Returns:
        The sorted list of file names that were removed.
    """
    labelled = set()
    with open(label_path, 'r', encoding='utf-8') as label_s:
        for line in label_s:
            key = line.split('\t')[0].strip()
            if key:
                labelled.add(os.path.basename(key))

    removed = []
    for name in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, name)
        if name in labelled or not os.path.isfile(path):
            continue
        if not name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        os.remove(path)
        removed.append(name)
    return removed


# Guarded so that merely importing this snippet can never delete files;
# running it as a script behaves as before.
if __name__ == "__main__":
    remove_unlabeled_images('train/', label_txt)