You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prepare_caltech101.py 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
import argparse
import os
import random
import shutil
import sys
from glob import glob

from PIL import Image
  6. if __name__=='__main__':
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('--data_root', type=str, default='../101_ObjectCategories')
  9. parser.add_argument('--test_split', type=float, default=0.3)
  10. args = parser.parse_args()
  11. SAVE_DIR = os.path.join( os.path.dirname(args.data_root), 'caltech101_data' )
  12. if not os.path.exists(SAVE_DIR):
  13. os.mkdir(SAVE_DIR)
  14. # Train
  15. TRAIN_DIR = os.path.join( SAVE_DIR, 'train' )
  16. if not os.path.exists(TRAIN_DIR):
  17. os.mkdir(TRAIN_DIR)
  18. # Test
  19. TEST_DIR = os.path.join( SAVE_DIR, 'test' )
  20. if not os.path.exists(TEST_DIR):
  21. os.mkdir(TEST_DIR)
  22. img_folders = os.listdir(args.data_root)
  23. img_folders.sort()
  24. for folder in img_folders:
  25. if folder=='Faces':
  26. continue
  27. print('Processing %s'%(folder))
  28. img_paths = glob(os.path.join( args.data_root, folder, '*.jpg') )
  29. img_name = [os.path.split(p)[-1] for p in img_paths]
  30. random.shuffle(img_name)
  31. img_n = len(img_name)
  32. test_n = int(args.test_split * img_n)
  33. test_set = img_name[:test_n]
  34. train_set = img_name[test_n:]
  35. # test
  36. dst_path = os.path.join(TEST_DIR, folder)
  37. if not os.path.exists(dst_path):
  38. os.mkdir(dst_path)
  39. for test_name in test_set:
  40. img = Image.open(os.path.join( args.data_root, folder, test_name ))
  41. img.save( os.path.join(dst_path, test_name ) )
  42. # train
  43. dst_path = os.path.join(TRAIN_DIR, folder)
  44. if not os.path.exists(dst_path):
  45. os.mkdir(dst_path)
  46. for train_name in train_set:
  47. img = Image.open(os.path.join( args.data_root, folder, train_name ))
  48. img.save( os.path.join(dst_path, train_name ) )

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)