
Speeding things up: converting to VOC bounding boxes and generating the tfrecord with multithreading


Modified to generate the VOC2007 bounding boxes with multiple threads:

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 19 14:51:00 2019

@author: Andrea
"""
import os
import numpy as np
import codecs
import json
from glob import glob
import cv2
import shutil
from sklearn.model_selection import train_test_split
import threading

# 1. Annotation paths
labelme_path = "I:\\biaozhutuxiang\\fangdichan1106-banannanan"          # original labelme annotation directory
saved_path = "I:\\biaozhutuxiang\\VOC2007-fangdichan1106-banannanan\\"  # output directory

# 2. Create the required folders
if not os.path.exists(saved_path + "Annotations"):
    os.makedirs(saved_path + "Annotations")
if not os.path.exists(saved_path + "JPEGImages/"):
    os.makedirs(saved_path + "JPEGImages/")
if not os.path.exists(saved_path + "ImageSets/Main/"):
    os.makedirs(saved_path + "ImageSets/Main/")

  """重新定义带返回值的线程类----民国档案------"""

  class LoadThread(threading.Thread):

  #class LoadThread_rep:

  def __init__(self, json_file_):

  super(LoadThread, self).__init__()

  self.json_file_ = json_file_

  def run(self):

  if('.json' not in self.json_file_):

  return self.json_file_

  else:

  json_file_ = self.json_file_.split('.json')[0]

  print(json_file_)

  json_filename = os.path.join(labelme_path , json_file_ + ".json")

  print(json_filename)

  json_file = json.load(open(json_filename,"r",encoding="utf-8"))

  print(os.path.join(labelme_path , json_file_ +".jpg"))

  height, width, channels = cv2.imread(os.path.join(labelme_path , json_file_ +".jpg")).shape

  with codecs.open(saved_path + "Annotations/"+json_file_ + ".xml","w","utf-8") as xml:

  xml.write('\n')

  xml.write('\t' + 'UAV_data' + '\n')

  xml.write('\t' + json_file_ + ".jpg" + '\n')

  xml.write('\t\n')

  xml.write('\t\tThe UAV autolanding\n')

  xml.write('\t\tUAV AutoLanding\n')

  xml.write('\t\tflickr\n')

  xml.write('\t\tNULL\n')

  xml.write('\t\n')

  xml.write('\t\n')

  xml.write('\t\tNULL\n')

  xml.write('\t\tYuanyiqin\n')

  xml.write('\t\n')

  xml.write('\t\n')

  xml.write('\t\t'+ str(width) + '\n')

  xml.write('\t\t'+ str(height) + '\n')

  xml.write('\t\t' + str(channels) + '\n')

  xml.write('\t\n')

  xml.write('\t\t0\n')

  for multi in json_file["shapes"]:

  points = np.array(multi["points"])

  xmin = min(points[:,0])

  xmax = max(points[:,0])

  ymin = min(points[:,1])

  ymax = max(points[:,1])

  label = multi["label"]

  if xmax <= xmin:

  pass

  elif ymax <= ymin:

  pass

  else:

  xml.write('\t\n')

  print(json_filename,xmin,ymin,xmax,ymax,label)

  xml.write('')

  self.json_file_

  def get_result(self):

  return self.json_file_

# 3. Collect the files to process
# files = glob(labelme_path + "*.json")
# print(files)
# files = [i.split("/")[-1].split(".json")[0] for i in files]

# 4. Read the annotations and write the XML files, threadnum files at a time
threadnum = 64

if __name__ == '__main__':
    # for json_file_ in os.listdir(labelme_path):
    img_list = os.listdir(labelme_path)
    img_length = len(img_list)
    # threadnum = 4
    for i in range(0, int(img_length / threadnum) + 1):
        print('i,int(img_length/threadnum):', i, int(img_length / threadnum))
        li = []
        for j in range(i * threadnum, min(i * threadnum + threadnum, img_length)):
            json_file_ = img_list[j]
            print('json_file_:', json_file_)
            thread = LoadThread(json_file_)
            li.append(thread)
            thread.start()
        for thread in li:
            thread.join()  # must join, otherwise the main thread finishes ahead of the workers and the results are lost
            json_file_ = thread.get_result()
            print('Done json_file_:', json_file_)

    # 5. Copy the images into VOC2007/JPEGImages/
    image_files = glob(os.path.join(labelme_path, "*.jpg"))
    print("copy image files to VOC2007/JPEGImages/")
    for image in image_files:
        shutil.copy(image, saved_path + "JPEGImages/")

    # 6. Split the files into the ImageSets txt lists
    txtsavepath = saved_path + "ImageSets/Main/"
    ftrainval = open(txtsavepath + '/trainval.txt', 'w')
    ftest = open(txtsavepath + '/test.txt', 'w')
    ftrain = open(txtsavepath + '/train.txt', 'w')
    fval = open(txtsavepath + '/val.txt', 'w')
    total_files = glob(saved_path + "Annotations/*.xml")
    total_files = [os.path.splitext(os.path.basename(i))[0] for i in total_files]
    # test_filepath = ""
    for file in total_files:
        ftrainval.write(file + "\n")
    # test
    # for file in os.listdir(test_filepath):
    #     ftest.write(file.split(".jpg")[0] + "\n")
    # split
    train_files, val_files = train_test_split(total_files, test_size=0.15, random_state=42)
    # train
    for file in train_files:
        ftrain.write(file + "\n")
    # val
    for file in val_files:
        fval.write(file + "\n")
    ftrainval.close()
    ftrain.close()
    fval.close()
    # ftest.close()
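
For comparison, the same batch-and-join pattern can be written more compactly with the standard library's concurrent.futures. This is only a sketch: it assumes the per-file work inside LoadThread.run() has been pulled out into a plain function convert_one(json_file_), which is a hypothetical helper name and not part of the script above:

from concurrent.futures import ThreadPoolExecutor, as_completed

def run_pool(file_list, workers=64):
    # the executor replaces the manual "start threadnum threads, join them, repeat" batching
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(convert_one, f): f for f in file_list if f.endswith('.json')}
        for fut in as_completed(futures):
            fut.result()  # re-raises any exception raised inside a worker
            print('Done json_file_:', futures[fut])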

Modified to generate the tfrecord with multiple threads:

# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import sys
sys.path.append('../../')
import os
import xml.etree.cElementTree as ET
import numpy as np
import tensorflow as tf
import math
import glob
import cv2
from libs.label_name_dict.label_dict import *
from help_utils.tools import *
import threading
import random

tf.app.flags.DEFINE_string('VOC_dir', '/home/yuanyq/Detect_DL/FPN_Tensorflow/data/io/VOC2007/', 'Voc dir')
tf.app.flags.DEFINE_string('xml_dir', 'Annotations', 'xml dir')
tf.app.flags.DEFINE_string('image_dir', 'JPEGImages', 'image dir')
tf.app.flags.DEFINE_string('save_name', 'train', 'save name')
tf.app.flags.DEFINE_string('save_dir', '../tfrecord/', 'save dir')
tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
tf.app.flags.DEFINE_string('dataset', 'pascal', 'dataset')
FLAGS = tf.app.flags.FLAGS

threadnum = 128
global count
count = 0

class LoadThread(threading.Thread):

    def __init__(self, xml, image_path, xml_path, writer):
        super(LoadThread, self).__init__()
        self.xml = xml
        self.image_path = image_path
        self.xml_path = xml_path
        self.writer = writer

    def run(self):
        # to avoid path errors on different development platforms
        xml = self.xml.replace('\\', '/')
        img_name = xml.split('/')[-1].split('.')[0] + FLAGS.img_format
        img_path = self.image_path + '/' + img_name
        print('xml:', xml)
        if not os.path.exists(img_path):
            print('{} is not exist!'.format(img_path))
            return

        img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml)
        # img = np.array(Image.open(img_path))
        img = cv2.imread(img_path)[:, :, ::-1]  # BGR -> RGB
        feature = tf.train.Features(feature={
            # do not need encode() in linux
            'img_name': _bytes_feature(img_name.encode()),
            # 'img_name': _bytes_feature(img_name),
            'img_height': _int64_feature(img_height),
            'img_width': _int64_feature(img_width),
            'img': _bytes_feature(img.tostring()),
            'gtboxes_and_label': _bytes_feature(gtbox_label.tostring()),
            'num_objects': _int64_feature(gtbox_label.shape[0])
        })
        example = tf.train.Example(features=feature)
        self.writer.write(example.SerializeToString())
        # view_bar('Conversion progress', count + 1, len(glob.glob(self.xml_path + '/*.xml')))
        return self.xml

    def get_result(self):
        print(self.xml)
        return self.xml


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def read_xml_gtbox_and_label(xml_path):
    """
    :param xml_path: the path of voc xml
    :return: a list contains gtboxes and labels, shape is [num_of_gtboxes, 5],
             and has [xmin, ymin, xmax, ymax, label] in a per row
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        # if child_of_root.tag == 'filename':
        #     assert child_of_root.text == xml_path.split('/')[-1].split('.')[0] \
        #            + FLAGS.img_format, 'xml_name and img_name cannot match'

        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                if child_item.tag == 'height':
                    img_height = int(child_item.text)

        if child_of_root.tag == 'object':
            label = None
            for child_item in child_of_root:
                # print('child_item.tag:', child_item.tag)
                # print('child_item.text:', child_item.text)
                # print('NAME_LABEL_MAP:', NAME_LABEL_MAP)
                if child_item.tag == 'name':
                    # fold a few mislabeled class names into their canonical labels before the lookup
                    if child_item.text == '0002X':
                        child_item.text = '0002'
                    if child_item.text == 'X0002':
                        child_item.text = '0002'
                    if child_item.text == '000Z1':
                        child_item.text = '0001'
                    if child_item.text == 'A0001':
                        child_item.text = '0001'
                    if child_item.text == 'c0002':
                        child_item.text = '0002'
                    if child_item.text != '0001' and child_item.text != '0002' and child_item.text != '0003':
                        label = 1
                    else:
                        label = NAME_LABEL_MAP[child_item.text]
                if child_item.tag == 'bndbox':
                    tmp_box = []
                    for node in child_item:
                        tmp_box.append(math.ceil(float(node.text)))
                    assert label is not None, 'label is none, error'
                    tmp_box.append(label)
                    box_list.append(tmp_box)

    gtbox_label = np.array(box_list, dtype=np.int32)
    return img_height, img_width, gtbox_label

def convert_pascal_to_tfrecord():
    xml_path = FLAGS.VOC_dir + FLAGS.xml_dir
    image_path = FLAGS.VOC_dir + FLAGS.image_dir
    save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'
    mkdir(FLAGS.save_dir)

    # print('xml_path:', xml_path)
    # print('save_path:', save_path)
    # print('image_path:', image_path)

    # writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options)
    writer = tf.python_io.TFRecordWriter(path=save_path)

    img_list = os.listdir(xml_path)
    random.shuffle(img_list)
    img_length = len(img_list)
    for i in range(0, int(img_length / threadnum) + 1):
        li = []
        for j in range(i * threadnum, min(i * threadnum + threadnum, img_length)):
            xml = os.path.join(xml_path, img_list[j])
            thread = LoadThread(xml, image_path, xml_path, writer)  # all threads share the one TFRecordWriter
            thread.daemon = True
            li.append(thread)
            thread.start()
        for thread in li:
            thread.join()  # must join, otherwise the main thread finishes ahead of the workers and the results are lost
            xml = thread.get_result()
            print('img_name done:', xml)

    writer.close()
    print('\nConversion is complete!')


if __name__ == '__main__':
    # xml_path = '../data/dataset/VOCdevkit/VOC2007/Annotations/000005.xml'
    # read_xml_gtbox_and_label(xml_path)
    convert_pascal_to_tfrecord()
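
To sanity-check the finished file, the record can be read back with the same TF 1.x python_io API the writer uses. A minimal sketch, assuming the feature keys and dtypes written in run() above (uint8 HxWx3 image bytes, int32 [num, 5] boxes); check_tfrecord and the example path are illustrative names only:

import numpy as np
import tensorflow as tf

def check_tfrecord(path, max_records=3):
    for k, record in enumerate(tf.python_io.tf_record_iterator(path)):
        example = tf.train.Example()
        example.ParseFromString(record)
        feat = example.features.feature
        h = feat['img_height'].int64_list.value[0]
        w = feat['img_width'].int64_list.value[0]
        num = feat['num_objects'].int64_list.value[0]
        # 'img' was written with img.tostring() on a uint8 H x W x 3 array
        img = np.frombuffer(feat['img'].bytes_list.value[0], dtype=np.uint8).reshape(h, w, 3)
        # 'gtboxes_and_label' is an int32 array with rows [xmin, ymin, xmax, ymax, label]
        boxes = np.frombuffer(feat['gtboxes_and_label'].bytes_list.value[0], dtype=np.int32).reshape(num, 5)
        print(feat['img_name'].bytes_list.value[0].decode(), img.shape, boxes.shape)
        if k + 1 >= max_records:
            break

# check_tfrecord('../tfrecord/pascal_train.tfrecord')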

