forked from tensorlayer/HyperPose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrain.py
49 lines (46 loc) · 1.79 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python3
import os
import cv2
import sys
import math
import json
import time
import argparse
import matplotlib
import multiprocessing
import numpy as np
import tensorflow as tf
import tensorlayer as tl
from hyperpose import Config,Model,Dataset
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='FastPose.')
parser.add_argument("--model_type",
type=str,
default="Openpose",
help="human pose estimation model type, available options: Openpose, LightweightOpenpose ,MobilenetThinOpenpose, PoseProposal")
parser.add_argument("--model_backbone",
type=str,
default="Default",
help="model backbone, available options: Mobilenet, Vggtiny, Vgg19, Resnet18, Resnet50")
parser.add_argument("--model_name",
type=str,
default="default_name",
help="model name,to distinguish model and determine model dir")
parser.add_argument("--dataset_path",
type=str,
default="./data",
help="dataset path,to determine the path to load the dataset")
args=parser.parse_args()
#config model
Config.set_model_name(args.model_name)
Config.set_model_type(Config.MODEL[args.model_type])
Config.set_model_backbone(Config.BACKBONE[args.model_backbone])
Config.set_pretrain(True)
#config dataset
Config.set_pretrain_dataset_path(args.dataset_path)
config=Config.get_config()
#train
model=Model.get_model(config)
pretrain=Model.get_pretrain(config)
dataset=Dataset.get_pretrain_dataset(config)
pretrain(model,dataset)