|
|
@@ -105,13 +105,12 @@ def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_siz |
|
|
|
data_set = data_set.map(input_columns="label", operations=type_cast_op) |
|
|
|
data_set = data_set.map(input_columns="image", operations=c_trans) |
|
|
|
|
|
|
|
# apply repeat operations |
|
|
|
data_set = data_set.repeat(repeat_num) |
|
|
|
|
|
|
|
# apply shuffle operations |
|
|
|
# data_set = data_set.shuffle(buffer_size=1000) |
|
|
|
data_set = data_set.shuffle(buffer_size=1000) |
|
|
|
|
|
|
|
# apply batch operations |
|
|
|
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) |
|
|
|
|
|
|
|
# apply repeat operations |
|
|
|
data_set = data_set.repeat(repeat_num) |
|
|
|
return data_set |