Browse Source

fix(data): fix contiguous id

GitOrigin-RevId: 7f79cda0b5
release-0.6
Megvii Engine Team 4 years ago
parent
commit
0df74604bd
3 changed files with 17 additions and 27 deletions
  1. +1
    -1
      python_module/megengine/data/dataset/vision/coco.py
  2. +1
    -1
      python_module/megengine/data/dataset/vision/objects365.py
  3. +15
    -25
      python_module/megengine/data/dataset/vision/voc.py

+ 1
- 1
python_module/megengine/data/dataset/vision/coco.py View File

@@ -118,7 +118,7 @@ class COCO(VisionDataset):
self.ids = ids self.ids = ids


self.json_category_id_to_contiguous_id = { self.json_category_id_to_contiguous_id = {
v: i + 1 for i, v in enumerate(self.cats.keys())
v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
} }


self.contiguous_category_id_to_json_id = { self.contiguous_category_id_to_json_id = {


+ 1
- 1
python_module/megengine/data/dataset/vision/objects365.py View File

@@ -81,7 +81,7 @@ class Objects365(VisionDataset):
self.ids = ids self.ids = ids


self.json_category_id_to_contiguous_id = { self.json_category_id_to_contiguous_id = {
v: i + 1 for i, v in enumerate(self.cats.keys())
v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
} }


self.contiguous_category_id_to_json_id = { self.contiguous_category_id_to_json_id = {


+ 15
- 25
python_module/megengine/data/dataset/vision/voc.py View File

@@ -75,6 +75,8 @@ class PascalVOC(VisionDataset):
else: else:
raise NotImplementedError raise NotImplementedError


self.img_infos = dict()

def __getitem__(self, index): def __getitem__(self, index):
target = [] target = []
for k in self.order: for k in self.order:
@@ -107,9 +109,8 @@ class PascalVOC(VisionDataset):
mask = mask[:, :, np.newaxis] mask = mask[:, :, np.newaxis]
target.append(mask) target.append(mask)
elif k == "info": elif k == "info":
if image is None:
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
info = [image.shape[0], image.shape[1], self.file_names[index]]
info = self.get_img_info(index, image)
info = [info["height"], info["width"], info["file_name"]]
target.append(info) target.append(info)
else: else:
raise NotImplementedError raise NotImplementedError
@@ -119,6 +120,17 @@ class PascalVOC(VisionDataset):
def __len__(self): def __len__(self):
return len(self.images) return len(self.images)


def get_img_info(self, index, image=None):
if index not in self.img_infos:
if image is None:
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
self.img_infos[index] = dict(
height=image.shape[0],
width=image.shape[1],
file_name=self.file_names[index],
)
return self.img_infos[index]

def _trans_mask(self, mask): def _trans_mask(self, mask):
label = np.ones(mask.shape[:2]) * 255 label = np.ones(mask.shape[:2]) * 255
for i in range(len(self.class_colors)): for i in range(len(self.class_colors)):
@@ -171,25 +183,3 @@ class PascalVOC(VisionDataset):
"train", "train",
"tvmonitor", "tvmonitor",
) )
class_colors = [
[0, 0, 128],
[0, 128, 0],
[0, 128, 128],
[128, 0, 0],
[128, 0, 128],
[128, 128, 0],
[128, 128, 128],
[0, 0, 64],
[0, 0, 192],
[0, 128, 64],
[0, 128, 192],
[128, 0, 64],
[128, 0, 192],
[128, 128, 64],
[128, 128, 192],
[0, 64, 0],
[0, 64, 128],
[0, 192, 0],
[0, 192, 128],
[128, 64, 0],
]

Loading…
Cancel
Save