@@ -1,3 +1,4 @@ | |||||
from __future__ import print_function | |||||
import cv2 | import cv2 | ||||
import time | import time | ||||
import numpy as np | import numpy as np | ||||
@@ -6,7 +7,6 @@ from torch.autograd.variable import Variable | |||||
from dface.core.models import PNet,RNet,ONet | from dface.core.models import PNet,RNet,ONet | ||||
import dface.core.utils as utils | import dface.core.utils as utils | ||||
import dface.core.image_tools as image_tools | import dface.core.image_tools as image_tools | ||||
from __future__ import print_function | |||||
def create_mtcnn_net(p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True): | def create_mtcnn_net(p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True): | ||||
@@ -1,6 +1,7 @@ | |||||
from __future__ import print_function | |||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
from __future__ import print_function | |||||
class ImageDB(object): | class ImageDB(object): | ||||
def __init__(self, image_annotation_file, prefix_path='', mode='train'): | def __init__(self, image_annotation_file, prefix_path='', mode='train'): | ||||
@@ -1,7 +1,8 @@ | |||||
from __future__ import print_function | |||||
import os | import os | ||||
import dface.config as config | import dface.config as config | ||||
import dface.prepare_data.assemble as assemble | import dface.prepare_data.assemble as assemble | ||||
from __future__ import print_function | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -1,7 +1,8 @@ | |||||
from __future__ import print_function | |||||
import os | import os | ||||
import dface.config as config | import dface.config as config | ||||
import dface.prepare_data.assemble as assemble | import dface.prepare_data.assemble as assemble | ||||
from __future__ import print_function | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -1,7 +1,8 @@ | |||||
from __future__ import print_function | |||||
import os | import os | ||||
import dface.config as config | import dface.config as config | ||||
import dface.prepare_data.assemble as assemble | import dface.prepare_data.assemble as assemble | ||||
from __future__ import print_function | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -1,3 +1,4 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
@@ -10,7 +11,7 @@ import cPickle | |||||
from dface.core.utils import convert_to_square,IoU | from dface.core.utils import convert_to_square,IoU | ||||
import dface.config as config | import dface.config as config | ||||
import dface.core.vision as vision | import dface.core.vision as vision | ||||
from __future__ import print_function | |||||
def gen_onet_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False): | def gen_onet_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False): | ||||
@@ -1,3 +1,4 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import numpy as np | import numpy as np | ||||
import cv2 | import cv2 | ||||
@@ -5,7 +6,7 @@ import os | |||||
import numpy.random as npr | import numpy.random as npr | ||||
from dface.core.utils import IoU | from dface.core.utils import IoU | ||||
import dface.config as config | import dface.config as config | ||||
from __future__ import print_function | |||||
def gen_pnet_data(data_dir,anno_file,prefix): | def gen_pnet_data(data_dir,anno_file,prefix): | ||||
@@ -1,3 +1,4 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
@@ -10,7 +11,7 @@ import cPickle | |||||
from dface.core.utils import convert_to_square,IoU | from dface.core.utils import convert_to_square,IoU | ||||
import dface.config as config | import dface.config as config | ||||
import dface.core.vision as vision | import dface.core.vision as vision | ||||
from __future__ import print_function | |||||
def gen_rnet_data(data_dir, anno_file, pnet_model_file, prefix_path='', use_cuda=True, vis=False): | def gen_rnet_data(data_dir, anno_file, pnet_model_file, prefix_path='', use_cuda=True, vis=False): | ||||
@@ -1,4 +1,5 @@ | |||||
# coding: utf-8 | # coding: utf-8 | ||||
from __future__ import print_function | |||||
import os | import os | ||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
@@ -7,7 +8,7 @@ import numpy.random as npr | |||||
import argparse | import argparse | ||||
import dface.config as config | import dface.config as config | ||||
import dface.core.utils as utils | import dface.core.utils as utils | ||||
from __future__ import print_function | |||||
def gen_data(anno_file, data_dir, prefix): | def gen_data(anno_file, data_dir, prefix): | ||||
@@ -1,4 +1,5 @@ | |||||
# coding: utf-8 | # coding: utf-8 | ||||
from __future__ import print_function | |||||
import os | import os | ||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
@@ -8,7 +9,7 @@ import numpy.random as npr | |||||
import argparse | import argparse | ||||
import dface.config as config | import dface.config as config | ||||
import dface.core.utils as utils | import dface.core.utils as utils | ||||
from __future__ import print_function | |||||
def gen_data(anno_file, data_dir, prefix): | def gen_data(anno_file, data_dir, prefix): | ||||
@@ -1,4 +1,5 @@ | |||||
# coding: utf-8 | # coding: utf-8 | ||||
from __future__ import print_function | |||||
import os | import os | ||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
@@ -8,7 +9,7 @@ import numpy.random as npr | |||||
import argparse | import argparse | ||||
import dface.config as config | import dface.config as config | ||||
import dface.core.utils as utils | import dface.core.utils as utils | ||||
from __future__ import print_function | |||||
def gen_data(anno_file, data_dir, prefix): | def gen_data(anno_file, data_dir, prefix): | ||||
@@ -1,5 +1,5 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
from dface.core.detect import MtcnnDetector,create_mtcnn_net | from dface.core.detect import MtcnnDetector,create_mtcnn_net | ||||
@@ -11,7 +11,7 @@ import cPickle | |||||
from dface.core.utils import convert_to_square,IoU | from dface.core.utils import convert_to_square,IoU | ||||
import dface.config as config | import dface.config as config | ||||
import dface.core.vision as vision | import dface.core.vision as vision | ||||
from __future__ import print_function | |||||
def gen_landmark48_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False): | def gen_landmark48_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False): | ||||
@@ -1,3 +1,4 @@ | |||||
from __future__ import print_function | |||||
from dface.core.image_reader import TrainImageReader | from dface.core.image_reader import TrainImageReader | ||||
import datetime | import datetime | ||||
import os | import os | ||||
@@ -5,7 +6,7 @@ from dface.core.models import PNet,RNet,ONet,LossFn | |||||
import torch | import torch | ||||
from torch.autograd import Variable | from torch.autograd import Variable | ||||
import dface.core.image_tools as image_tools | import dface.core.image_tools as image_tools | ||||
from __future__ import print_function | |||||
def compute_accuracy(prob_cls, gt_cls): | def compute_accuracy(prob_cls, gt_cls): | ||||
@@ -1,10 +1,11 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import sys | import sys | ||||
from dface.core.imagedb import ImageDB | from dface.core.imagedb import ImageDB | ||||
import dface.train_net.train as train | import dface.train_net.train as train | ||||
import dface.config as config | import dface.config as config | ||||
import os | import os | ||||
from __future__ import print_function | |||||
def train_net(annotation_file, model_store_path, | def train_net(annotation_file, model_store_path, | ||||
@@ -1,10 +1,11 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import sys | import sys | ||||
from dface.core.imagedb import ImageDB | from dface.core.imagedb import ImageDB | ||||
from dface.train_net.train import train_pnet | from dface.train_net.train import train_pnet | ||||
import dface.config as config | import dface.config as config | ||||
import os | import os | ||||
from __future__ import print_function | |||||
def train_net(annotation_file, model_store_path, | def train_net(annotation_file, model_store_path, | ||||
@@ -1,10 +1,11 @@ | |||||
from __future__ import print_function | |||||
import argparse | import argparse | ||||
import sys | import sys | ||||
from dface.core.imagedb import ImageDB | from dface.core.imagedb import ImageDB | ||||
import dface.train_net.train as train | import dface.train_net.train as train | ||||
import dface.config as config | import dface.config as config | ||||
import os | import os | ||||
from __future__ import print_function | |||||
def train_net(annotation_file, model_store_path, | def train_net(annotation_file, model_store_path, | ||||
@@ -10,7 +10,8 @@ if __name__ == '__main__': | |||||
r_model = "./model_store/rnet_epoch.pt" | r_model = "./model_store/rnet_epoch.pt" | ||||
o_model = "./model_store/onet_epoch.pt" | o_model = "./model_store/onet_epoch.pt" | ||||
pnet, rnet, onet = create_mtcnn_net(p_model_path=p_model, r_model_path=r_model, o_model_path=o_model, use_cuda=True) | |||||
#use cpu version set use_cuda=False, if you want to use gpu version set use_cuda=True | |||||
pnet, rnet, onet = create_mtcnn_net(p_model_path=p_model, r_model_path=r_model, o_model_path=o_model, use_cuda=False) | |||||
mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet, min_face_size=24) | mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet, min_face_size=24) | ||||
img = cv2.imread("./test.jpg") | img = cv2.imread("./test.jpg") | ||||
@@ -20,4 +21,4 @@ if __name__ == '__main__': | |||||
bboxs, landmarks = mtcnn_detector.detect_face(img) | bboxs, landmarks = mtcnn_detector.detect_face(img) | ||||
# print box_align | # print box_align | ||||
vision.vis_face(img2,bboxs,landmarks) | |||||
vision.vis_face(img2,bboxs,landmarks) |