diff --git a/README.md b/README.md
index 46e4b0b..38b0555 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,7 @@ To test LCNN on your own images, you need download the pre-trained models and ex
 ```Bash
 python ./demo.py -d 0 config/wireframe.yaml
+python ./demo.py -d 0 config/wireframe.yaml /home/zengxh/workspace/lcnn/config/190418-201834-f8934c6-lr4d10-312k.pth.tar /home/zengxh/workspace/lcnn/data/wireframe/train/00559828_3.png
 ```
 Here, `-d 0` is specifying the GPU ID used for evaluation, and you can specify `-d ""` to force CPU inference.
@@ -200,3 +201,10 @@ If you find L-CNN useful in your research, please consider citing:
    year={2019}
  }
 ```
+
+/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/dataset/train_test_split.py
+/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/dataset/wireframe.py /home/zengxh/datasets/creepageDistance /home/zengxh/datasets/creepageDistance_wireframe
+/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/train.py -d 0 --identifier baseline config/wireframe.yaml
+/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/demo.py -d 0 config/wireframe.yaml /home/zengxh/workspace/lcnn/logs/210706-112447-88f281a-baseline/checkpoint_best.pth /home/zengxh/datasets/creepageDistance_wireframe/valid/7507639237000304_0_t_0.png
+
+
diff --git a/config/wireframe.yaml b/config/wireframe.yaml
index 1391bf4..2c12d66 100644
--- a/config/wireframe.yaml
+++ b/config/wireframe.yaml
@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: data/wireframe/
+  datadir: /home/zengxh/datasets/creepageDistance_wireframe/
   resume_from:
   num_workers: 4
   tensorboard_port: 0
diff --git a/dataset/constants.py b/dataset/constants.py
new file mode 100644
index 0000000..fd82d7d
--- /dev/null
+++ b/dataset/constants.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Author zengxiaohui
+# Datetime: 7/6/2021 10:51 AM
+# @File: constants.py
+# Width and height that inputs are normalized to
+NORMALIZATION_WIDTH = 64
+NORMALIZATION_HEIGHT = 512
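+# The 4x-downsampled heatmap size used throughout this diff is derived from
+# the two constants above. A quick sanity check (added comment; the numbers
+# assume the default values 64 and 512):
+#   heatmap_scale = (NORMALIZATION_WIDTH // 4, NORMALIZATION_HEIGHT // 4)  # (16, 128)
+#   jmap.shape == (1, heatmap_scale[1], heatmap_scale[0])                  # (1, 128, 16)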
+# Maximum pixel value is 255
+PIXS_MAX_VALUE = 255.0
+# Dataset orientation types: top-bottom and left-right
+TB_DATATYPE = "tb"
+LR_DATATYPE = "lr"
+# Pixel-distance tolerance used when computing accuracy
+ACC_PX_THRESH = 16
+# Random seed
+RANDOM_SEED = 1024
\ No newline at end of file
diff --git a/dataset/creepageDistance.yaml b/dataset/creepageDistance.yaml
new file mode 100644
index 0000000..1b0525f
--- /dev/null
+++ b/dataset/creepageDistance.yaml
@@ -0,0 +1,9 @@
+datasets:
+  lr:
+    allDatas: /home/zengxh/medias/data/ext/creepageDistance/datasets/lr/org
+    # h*w
+    img_size: [512,64]
+  tb:
+    allDatas: /home/zengxh/medias/data/ext/creepageDistance/datasets/tb/org
+    # h*w
+    img_size: [512,64]
\ No newline at end of file
diff --git a/dataset/train_test_split.py b/dataset/train_test_split.py
new file mode 100644
index 0000000..be96cde
--- /dev/null
+++ b/dataset/train_test_split.py
@@ -0,0 +1,168 @@
+import argparse
+import os
+import shutil
+import random
+
+import cv2
+import yaml
+from imutils import paths
+from sklearn.model_selection import train_test_split
+from tqdm import tqdm
+
+from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT, RANDOM_SEED
+from python_developer_tools.cv.datasets.datasets_utils import resize_image, letterbox
+from python_developer_tools.files.common import mkdir, get_filename_suf_pix
+from python_developer_tools.files.json_utils import read_json_file, save_json_file
+
+def createDatasets(datasets, dirname):
+    dataDir = os.path.join(data_dict_tmp[dirname])
+    if not os.path.exists(dataDir):
+        os.makedirs(dataDir)
+    for datapath in datasets:
+        shutil.copy(datapath.replace(".jpg", ".json"), dataDir)
+        shutil.copy(datapath, dataDir)
+
+def get_origin_image_points(imagePath):
+    img = cv2.imread(imagePath)
+    jsonfile = imagePath.replace(".jpg", ".json")
+    json_cont = read_json_file(jsonfile)
+    labels_tmp = [0, 0, 0, 0, 0, 0, 0, 0]
+    for shapes in json_cont["shapes"]:
+        label = shapes["label"]
+        points = shapes["points"]
+        if json_cont["imageHeight"] > json_cont["imageWidth"]:
+            if label == "min":
+                if points[0][1] < points[1][1]:
+                    labels_tmp[4] = points[0][0]
+                    labels_tmp[5] = points[0][1]
+                    labels_tmp[6] = points[1][0]
+                    labels_tmp[7] = points[1][1]
+                else:
+                    labels_tmp[4] = points[1][0]
+                    labels_tmp[5] = points[1][1]
+                    labels_tmp[6] = points[0][0]
+                    labels_tmp[7] = points[0][1]
+            if label == "max":
+                if points[0][1] < points[1][1]:
+                    labels_tmp[0] = points[0][0]
+                    labels_tmp[1] = points[0][1]
+                    labels_tmp[2] = points[1][0]
+                    labels_tmp[3] = points[1][1]
+                else:
+                    labels_tmp[0] = points[1][0]
+                    labels_tmp[1] = points[1][1]
+                    labels_tmp[2] = points[0][0]
+                    labels_tmp[3] = points[0][1]
+        else:
+            if label == "min":
+                if points[0][0] < points[1][0]:
+                    labels_tmp[4] = points[0][0]
+                    labels_tmp[5] = points[0][1]
+                    labels_tmp[6] = points[1][0]
+                    labels_tmp[7] = points[1][1]
+                else:
+                    labels_tmp[4] = points[1][0]
+                    labels_tmp[5] = points[1][1]
+                    labels_tmp[6] = points[0][0]
+                    labels_tmp[7] = points[0][1]
+            if label == "max":
+                if points[0][0] < points[1][0]:
+                    labels_tmp[0] = points[0][0]
+                    labels_tmp[1] = points[0][1]
+                    labels_tmp[2] = points[1][0]
+                    labels_tmp[3] = points[1][1]
+                else:
+                    labels_tmp[0] = points[1][0]
+                    labels_tmp[1] = points[1][1]
+                    labels_tmp[2] = points[0][0]
+                    labels_tmp[3] = points[0][1]
+    return img, labels_tmp
+
+def label_transpose_1(label_o, w0, h0):
+    # tb: rotate 90° clockwise
+    new_label = [0, 0, 0, 0, 0, 0, 0, 0]
+    new_label[0] = w0 - label_o[5]
+    new_label[1] = label_o[4]
+    new_label[2] = w0 - label_o[7]
+    new_label[3] = label_o[6]
+    new_label[4] = w0 - label_o[1]
+    new_label[5] = label_o[0]
+    new_label[6] = w0 - label_o[3]
+    new_label[7] = label_o[2]
+    return new_label
+
+def labels_convert_train(label, w0, h0, w1, h1, w2, h2, padw, padh):
+    new_label = [0, 0, 0, 0, 0, 0, 0, 0]
+    for i, _label in enumerate(label):
+        if i in [0, 2, 4, 6]:
+            new_label[i] = ((_label * w1 / w0) * w2) / w1 + padw
+        else:
+            new_label[i] = ((_label * h1 / h0) * h2) / h1 + padh
+    # label = [i / w0 for i in label]
+    # label = [i * w1 / w0 for i in label]
+    # label = [(i * w2 + padw) / w1 for i in label]
+    # label = [i / NORMALIZATION_WIDTH for i in label]
+    return new_label
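+
+# Added note on labels_convert_train (interpretation of the mapping above,
+# assuming letterbox() returns (img, ratio, pad) as in the YOLOv5-style
+# helper): x-coordinates are scaled from the original width w0 to the resized
+# width w1, then by the letterbox ratio w2/w1, and finally shifted by the
+# horizontal padding padw; y-coordinates use h0/h1/h2/padh the same way.
+# Net effect: new_x = x * w2 / w0 + padw, e.g. x=64, w0=128, w2=64, padw=0 -> 32.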
+def get_dict_json(imagePath):
+    # NOTE: `key` and `images_dir` are module-level globals set in the
+    # __main__ block below
+    filename, filedir, filesuffix, filenamestem = get_filename_suf_pix(imagePath)
+    img, labels_tmp = get_origin_image_points(imagePath)
+    if key == "tb":
+        img = cv2.transpose(img)
+        img = cv2.flip(img, 1)
+        h0, w0 = img.shape[:2]  # orig hw
+        labels_tmp = label_transpose_1(labels_tmp, w0, h0)
+    # _ = cv2.line(img, (int(labels_tmp[0]), int(labels_tmp[1])), (int(labels_tmp[2]), int(labels_tmp[3])),
+    #              (0, 255, 0), thickness=2)
+    # _ = cv2.line(img, (int(labels_tmp[4]), int(labels_tmp[5])), (int(labels_tmp[6]), int(labels_tmp[7])),
+    #              (0, 0, 255), thickness=2)
+    # cv2.imwrite("sdf.jpg", img)
+    h0, w0 = img.shape[:2]
+
+    img = resize_image(img, [NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH])
+    h1, w1, _ = img.shape
+    img, ratio, pad = letterbox(img, [NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH], auto=False, scaleup=True)
+    # letterbox(img, img_size, auto=False, scaleFill=True)  # pads the borders; letterbox(img, self.opt.img_size, auto=False, scaleup=False)
+    labels_tmp = labels_convert_train(labels_tmp, w0, h0, w1, h1, ratio[0] * w1, ratio[1] * h1, pad[0], pad[1])
+
+    h2, w2 = img.shape[:2]
+    dict_json = {"filename": filename,
+                 "lines": [[labels_tmp[0], labels_tmp[1], labels_tmp[2], labels_tmp[3]],
+                           [labels_tmp[4], labels_tmp[5], labels_tmp[6], labels_tmp[7]]],
+                 "height": h2, "width": w2}
+    cv2.imwrite(os.path.join(images_dir, filename), img)
+    return dict_json
+
+if __name__ == '__main__':
+    random.seed(RANDOM_SEED)
+    parser = argparse.ArgumentParser(description="Collect the large images that manual inspection in Hangzhou judged defective")
+    parser.add_argument('--data',
+                        default=r"creepageDistance.yaml",
+                        help="dataset yaml describing the unsplit data folders")
+    parser.add_argument('--datasets_path',
+                        default=r"/home/zengxh/datasets/creepageDistance",
+                        help="output folder for the split dataset")
+    opt = parser.parse_args()
+
+    images_dir = os.path.join(opt.datasets_path, "images")
+    mkdir(images_dir)
+
+    with open(opt.data, encoding="utf-8") as f:
+        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
+
+    train_json = []
+    valid_json = []
+    for (key, data_dict_tmp) in data_dict["datasets"].items():
+        nameImgs = list(paths.list_images(os.path.join(data_dict_tmp["allDatas"])))
+        X_train, X_test_val, _, _ = train_test_split(nameImgs, nameImgs, test_size=0.2, random_state=RANDOM_SEED)
+
+        for imagePath in X_train:
+            dict_json = get_dict_json(imagePath)
+            train_json.append(dict_json)
+
+        for imagePath in X_test_val:
+            dict_json = get_dict_json(imagePath)
+            valid_json.append(dict_json)
+
+    save_json_file(os.path.join(opt.datasets_path, "train.json"), train_json)
+    save_json_file(os.path.join(opt.datasets_path, "valid.json"), valid_json)
\ No newline at end of file
diff --git a/dataset/wireframe.py b/dataset/wireframe.py
index e46529b..bc272f7 100755
--- a/dataset/wireframe.py
+++ b/dataset/wireframe.py
@@ -27,6 +27,8 @@
 from docopt import docopt
 from scipy.ndimage import zoom
 
+from dataset.constants import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH
+
 try:
     sys.path.append(".")
     sys.path.append("..")
@@ -44,13 +46,13 @@ def to_int(x):
 
 
 def save_heatmap(prefix, image, lines):
-    im_rescale = (512, 512)
-    heatmap_scale = (128, 128)
+    im_rescale = (NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT)
+    heatmap_scale = (int(NORMALIZATION_WIDTH / 4), int(NORMALIZATION_HEIGHT / 4))
 
     fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
-    jmap = np.zeros((1,) + heatmap_scale, dtype=np.float32)
-    joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32)
-    lmap = np.zeros(heatmap_scale, dtype=np.float32)
+    jmap = np.zeros((1,) + (heatmap_scale[1], heatmap_scale[0]), dtype=np.float32)
+    joff = np.zeros((1, 2) + (heatmap_scale[1], heatmap_scale[0]), dtype=np.float32)
+    lmap = np.zeros((heatmap_scale[1], heatmap_scale[0]), dtype=np.float32)
 
     lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
     lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
diff --git a/dataset/york.py b/dataset/york.py
index 54d850a..5ffd9e1 100755
--- a/dataset/york.py
+++ b/dataset/york.py
@@ -30,6 +30,8 @@
 from scipy.io import loadmat
 from scipy.ndimage import zoom
 
+from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT
+
 try:
     sys.path.append(".")
     sys.path.append("..")
@@ -47,13 +49,13 @@ def to_int(x):
 
 
 def save_heatmap(prefix, image, lines):
-    im_rescale = (512, 512)
-    heatmap_scale = (128, 128)
+    im_rescale = (NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT)
+    heatmap_scale = (int(NORMALIZATION_WIDTH / 4), int(NORMALIZATION_HEIGHT / 4))
 
     fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
-    jmap = np.zeros((1,) + heatmap_scale, 
dtype=np.float32) - joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32) - lmap = np.zeros(heatmap_scale, dtype=np.float32) + jmap = np.zeros((1,) + (heatmap_scale[1],heatmap_scale[0]), dtype=np.float32) + joff = np.zeros((1, 2) + (heatmap_scale[1],heatmap_scale[0]), dtype=np.float32) + lmap = np.zeros((heatmap_scale[1],heatmap_scale[0]), dtype=np.float32) lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4) lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4) diff --git a/demo.py b/demo.py index 1f20af1..41f059c 100755 --- a/demo.py +++ b/demo.py @@ -29,6 +29,7 @@ from docopt import docopt import lcnn +from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT from lcnn.config import C, M from lcnn.models.line_vectorizer import LineVectorizer from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner @@ -89,7 +90,8 @@ def main(): if im.ndim == 2: im = np.repeat(im[:, :, None], 3, 2) im = im[:, :, :3] - im_resized = skimage.transform.resize(im, (512, 512)) * 255 + im_resized = skimage.transform.resize(im, (NORMALIZATION_HEIGHT,NORMALIZATION_WIDTH )) * 255 + # skimage.io.imsave('cat.jpg', im_resized) image = (im_resized - M.image.mean) / M.image.stddev image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float() with torch.no_grad(): @@ -104,14 +106,14 @@ def main(): } ], "target": { - "jmap": torch.zeros([1, 1, 128, 128]).to(device), - "joff": torch.zeros([1, 1, 2, 128, 128]).to(device), + "jmap": torch.zeros([1, 1, int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)]).to(device), + "joff": torch.zeros([1, 1, 2, int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)]).to(device), }, "mode": "testing", } H = model(input_dict)["preds"] - lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2] + lines = H["lines"][0].cpu().numpy() / (int(NORMALIZATION_HEIGHT / 4),int(NORMALIZATION_WIDTH / 4)) * im.shape[:2] scores = H["score"][0].cpu().numpy() for i in range(1, len(lines)): if (lines[i] == lines[0]).all(): @@ -122,23 +124,23 @@ def main(): # postprocess lines to remove overlapped lines diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False) - - for i, t in enumerate([0.94, 0.95, 0.96, 0.97, 0.98, 0.99]): - plt.gca().set_axis_off() - plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) - plt.margins(0, 0) - for (a, b), s in zip(nlines, nscores): - if s < t: - continue - plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) - plt.scatter(a[1], a[0], **PLTOPTS) - plt.scatter(b[1], b[0], **PLTOPTS) - plt.gca().xaxis.set_major_locator(plt.NullLocator()) - plt.gca().yaxis.set_major_locator(plt.NullLocator()) - plt.imshow(im) - plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight") - plt.show() - plt.close() + print(nlines) + # for i, t in enumerate([0.01, 0.95, 0.96, 0.97, 0.98, 0.99]): + plt.gca().set_axis_off() + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.margins(0, 0) + for (a, b), s in zip(nlines, nscores): + # if s < t: + # continue + plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) + plt.scatter(a[1], a[0], **PLTOPTS) + plt.scatter(b[1], b[0], **PLTOPTS) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + plt.imshow(im) + plt.savefig(imname.replace(".png", f"-{0.1:.02f}.svg"), bbox_inches="tight") + plt.show() + plt.close() if __name__ == "__main__": diff --git 
a/deploy/creepageDistanceModel/config.py b/deploy/creepageDistanceModel/config.py new file mode 100644 index 0000000..a074b3c --- /dev/null +++ b/deploy/creepageDistanceModel/config.py @@ -0,0 +1,1112 @@ +import numpy as np + +import string +import sys +import json +import re +import copy +from keyword import kwlist +import warnings + + +try: + from collections.abc import Iterable, Mapping, Callable +except ImportError: + from collections import Iterable, Mapping, Callable + +yaml_support = True + +try: + import yaml +except ImportError: + try: + import ruamel.yaml as yaml + except ImportError: + yaml = None + yaml_support = False + +if sys.version_info >= (3, 0): + basestring = str +else: + from io import open + +__all__ = ['Box', 'ConfigBox', 'BoxList', 'SBox', + 'BoxError', 'BoxKeyError'] +__author__ = 'Chris Griffith' +__version__ = '3.2.4' + +BOX_PARAMETERS = ('default_box', 'default_box_attr', 'conversion_box', + 'frozen_box', 'camel_killer_box', 'box_it_up', + 'box_safe_prefix', 'box_duplicates', 'ordered_box') + +_first_cap_re = re.compile('(.)([A-Z][a-z]+)') +_all_cap_re = re.compile('([a-z0-9])([A-Z])') + + +class BoxError(Exception): + """Non standard dictionary exceptions""" + + +class BoxKeyError(BoxError, KeyError, AttributeError): + """Key does not exist""" + + +# Abstract converter functions for use in any Box class + + +def _to_json(obj, filename=None, + encoding="utf-8", errors="strict", **json_kwargs): + json_dump = json.dumps(obj, + ensure_ascii=False, **json_kwargs) + if filename: + with open(filename, 'w', encoding=encoding, errors=errors) as f: + f.write(json_dump if sys.version_info >= (3, 0) else + json_dump.decode("utf-8")) + else: + return json_dump + + +def _from_json(json_string=None, filename=None, + encoding="utf-8", errors="strict", multiline=False, **kwargs): + if filename: + with open(filename, 'r', encoding=encoding, errors=errors) as f: + if multiline: + data = [json.loads(line.strip(), **kwargs) for line in f + if line.strip() and not line.strip().startswith("#")] + else: + data = json.load(f, **kwargs) + elif json_string: + data = json.loads(json_string, **kwargs) + else: + raise BoxError('from_json requires a string or filename') + return data + + +def _to_yaml(obj, filename=None, default_flow_style=False, + encoding="utf-8", errors="strict", + **yaml_kwargs): + if filename: + with open(filename, 'w', + encoding=encoding, errors=errors) as f: + yaml.dump(obj, stream=f, + default_flow_style=default_flow_style, + **yaml_kwargs) + else: + return yaml.dump(obj, + default_flow_style=default_flow_style, + **yaml_kwargs) + + +def _from_yaml(yaml_string=None, filename=None, + encoding="utf-8", errors="strict", + **kwargs): + if filename: + with open(filename, 'r', + encoding=encoding, errors=errors) as f: + data = yaml.load(f, **kwargs) + elif yaml_string: + data = yaml.load(yaml_string, **kwargs) + else: + raise BoxError('from_yaml requires a string or filename') + return data + + +# Helper functions + + +def _safe_key(key): + try: + return str(key) + except UnicodeEncodeError: + return key.encode("utf-8", "ignore") + + +def _safe_attr(attr, camel_killer=False, replacement_char='x'): + """Convert a key into something that is accessible as an attribute""" + allowed = string.ascii_letters + string.digits + '_' + + attr = _safe_key(attr) + + if camel_killer: + attr = _camel_killer(attr) + + attr = attr.replace(' ', '_') + + out = '' + for character in attr: + out += character if character in allowed else "_" + out = out.strip("_") + + try: + int(out[0]) + 
except (ValueError, IndexError): + pass + else: + out = '{0}{1}'.format(replacement_char, out) + + if out in kwlist: + out = '{0}{1}'.format(replacement_char, out) + + return re.sub('_+', '_', out) + + +def _camel_killer(attr): + """ + CamelKiller, qu'est-ce que c'est? + + Taken from http://stackoverflow.com/a/1176023/3244542 + """ + try: + attr = str(attr) + except UnicodeEncodeError: + attr = attr.encode("utf-8", "ignore") + + s1 = _first_cap_re.sub(r'\1_\2', attr) + s2 = _all_cap_re.sub(r'\1_\2', s1) + return re.sub('_+', '_', s2.casefold() if hasattr(s2, 'casefold') else + s2.lower()) + + +def _recursive_tuples(iterable, box_class, recreate_tuples=False, **kwargs): + out_list = [] + for i in iterable: + if isinstance(i, dict): + out_list.append(box_class(i, **kwargs)) + elif isinstance(i, list) or (recreate_tuples and isinstance(i, tuple)): + out_list.append(_recursive_tuples(i, box_class, + recreate_tuples, **kwargs)) + else: + out_list.append(i) + return tuple(out_list) + + +def _conversion_checks(item, keys, box_config, check_only=False, + pre_check=False): + """ + Internal use for checking if a duplicate safe attribute already exists + + :param item: Item to see if a dup exists + :param keys: Keys to check against + :param box_config: Easier to pass in than ask for specfic items + :param check_only: Don't bother doing the conversion work + :param pre_check: Need to add the item to the list of keys to check + :return: the original unmodified key, if exists and not check_only + """ + if box_config['box_duplicates'] != 'ignore': + if pre_check: + keys = list(keys) + [item] + + key_list = [(k, + _safe_attr(k, camel_killer=box_config['camel_killer_box'], + replacement_char=box_config['box_safe_prefix'] + )) for k in keys] + if len(key_list) > len(set(x[1] for x in key_list)): + seen = set() + dups = set() + for x in key_list: + if x[1] in seen: + dups.add("{0}({1})".format(x[0], x[1])) + seen.add(x[1]) + if box_config['box_duplicates'].startswith("warn"): + warnings.warn('Duplicate conversion attributes exist: ' + '{0}'.format(dups)) + else: + raise BoxError('Duplicate conversion attributes exist: ' + '{0}'.format(dups)) + if check_only: + return + # This way will be slower for warnings, as it will have double work + # But faster for the default 'ignore' + for k in keys: + if item == _safe_attr(k, camel_killer=box_config['camel_killer_box'], + replacement_char=box_config['box_safe_prefix']): + return k + + +def _get_box_config(cls, kwargs): + return { + # Internal use only + '__converted': set(), + '__box_heritage': kwargs.pop('__box_heritage', None), + '__created': False, + '__ordered_box_values': [], + # Can be changed by user after box creation + 'default_box': kwargs.pop('default_box', False), + 'default_box_attr': kwargs.pop('default_box_attr', cls), + 'conversion_box': kwargs.pop('conversion_box', True), + 'box_safe_prefix': kwargs.pop('box_safe_prefix', 'x'), + 'frozen_box': kwargs.pop('frozen_box', False), + 'camel_killer_box': kwargs.pop('camel_killer_box', False), + 'modify_tuples_box': kwargs.pop('modify_tuples_box', False), + 'box_duplicates': kwargs.pop('box_duplicates', 'ignore'), + 'ordered_box': kwargs.pop('ordered_box', False) + } + + +class Box(dict): + """ + Improved dictionary access through dot notation with additional tools. + + :param default_box: Similar to defaultdict, return a default value + :param default_box_attr: Specify the default replacement. 
+ WARNING: If this is not the default 'Box', it will not be recursive + :param frozen_box: After creation, the box cannot be modified + :param camel_killer_box: Convert CamelCase to snake_case + :param conversion_box: Check for near matching keys as attributes + :param modify_tuples_box: Recreate incoming tuples with dicts into Boxes + :param box_it_up: Recursively create all Boxes from the start + :param box_safe_prefix: Conversion box prefix for unsafe attributes + :param box_duplicates: "ignore", "error" or "warn" when duplicates exists + in a conversion_box + :param ordered_box: Preserve the order of keys entered into the box + """ + + _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml', + 'from_yaml', 'from_json'] + + def __new__(cls, *args, **kwargs): + """ + Due to the way pickling works in python 3, we need to make sure + the box config is created as early as possible. + """ + obj = super(Box, cls).__new__(cls, *args, **kwargs) + obj._box_config = _get_box_config(cls, kwargs) + return obj + + def __init__(self, *args, **kwargs): + self._box_config = _get_box_config(self.__class__, kwargs) + if self._box_config['ordered_box']: + self._box_config['__ordered_box_values'] = [] + if (not self._box_config['conversion_box'] and + self._box_config['box_duplicates'] != "ignore"): + raise BoxError('box_duplicates are only for conversion_boxes') + if len(args) == 1: + if isinstance(args[0], basestring): + raise ValueError('Cannot extrapolate Box from string') + if isinstance(args[0], Mapping): + for k, v in args[0].items(): + if v is args[0]: + v = self + self[k] = v + self.__add_ordered(k) + elif isinstance(args[0], Iterable): + for k, v in args[0]: + self[k] = v + self.__add_ordered(k) + + else: + raise ValueError('First argument must be mapping or iterable') + elif args: + raise TypeError('Box expected at most 1 argument, ' + 'got {0}'.format(len(args))) + + box_it = kwargs.pop('box_it_up', False) + for k, v in kwargs.items(): + if args and isinstance(args[0], Mapping) and v is args[0]: + v = self + self[k] = v + self.__add_ordered(k) + + if (self._box_config['frozen_box'] or box_it or + self._box_config['box_duplicates'] != 'ignore'): + self.box_it_up() + + self._box_config['__created'] = True + + def __add_ordered(self, key): + if (self._box_config['ordered_box'] and + key not in self._box_config['__ordered_box_values']): + self._box_config['__ordered_box_values'].append(key) + + def box_it_up(self): + """ + Perform value lookup for all items in current dictionary, + generating all sub Box objects, while also running `box_it_up` on + any of those sub box objects. 
+ """ + for k in self: + _conversion_checks(k, self.keys(), self._box_config, + check_only=True) + if self[k] is not self and hasattr(self[k], 'box_it_up'): + self[k].box_it_up() + + def __hash__(self): + if self._box_config['frozen_box']: + hashing = 54321 + for item in self.items(): + hashing ^= hash(item) + return hashing + raise TypeError("unhashable type: 'Box'") + + def __dir__(self): + allowed = string.ascii_letters + string.digits + '_' + kill_camel = self._box_config['camel_killer_box'] + items = set(dir(dict) + ['to_dict', 'to_json', + 'from_json', 'box_it_up']) + # Only show items accessible by dot notation + for key in self.keys(): + key = _safe_key(key) + if (' ' not in key and key[0] not in string.digits and + key not in kwlist): + for letter in key: + if letter not in allowed: + break + else: + items.add(key) + + for key in self.keys(): + key = _safe_key(key) + if key not in items: + if self._box_config['conversion_box']: + key = _safe_attr(key, camel_killer=kill_camel, + replacement_char=self._box_config[ + 'box_safe_prefix']) + if key: + items.add(key) + if kill_camel: + snake_key = _camel_killer(key) + if snake_key: + items.remove(key) + items.add(snake_key) + + if yaml_support: + items.add('to_yaml') + items.add('from_yaml') + + return list(items) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + if isinstance(default, dict) and not isinstance(default, Box): + return Box(default) + if isinstance(default, list) and not isinstance(default, BoxList): + return BoxList(default) + return default + + def copy(self): + return self.__class__(super(self.__class__, self).copy()) + + def __copy__(self): + return self.__class__(super(self.__class__, self).copy()) + + def __deepcopy__(self, memodict=None): + out = self.__class__() + memodict = memodict or {} + memodict[id(self)] = out + for k, v in self.items(): + out[copy.deepcopy(k, memodict)] = copy.deepcopy(v, memodict) + return out + + def __setstate__(self, state): + self._box_config = state['_box_config'] + self.__dict__.update(state) + + def __getitem__(self, item, _ignore_default=False): + try: + value = super(Box, self).__getitem__(item) + except KeyError as err: + if item == '_box_config': + raise BoxKeyError('_box_config should only exist as an ' + 'attribute and is never defaulted') + if self._box_config['default_box'] and not _ignore_default: + return self.__get_default(item) + raise BoxKeyError(str(err)) + else: + return self.__convert_and_store(item, value) + + def keys(self): + if self._box_config['ordered_box']: + return self._box_config['__ordered_box_values'] + return super(Box, self).keys() + + def values(self): + return [self[x] for x in self.keys()] + + def items(self): + return [(x, self[x]) for x in self.keys()] + + def __get_default(self, item): + default_value = self._box_config['default_box_attr'] + if default_value is self.__class__: + return self.__class__(__box_heritage=(self, item), + **self.__box_config()) + elif isinstance(default_value, Callable): + return default_value() + elif hasattr(default_value, 'copy'): + return default_value.copy() + return default_value + + def __box_config(self): + out = {} + for k, v in self._box_config.copy().items(): + if not k.startswith("__"): + out[k] = v + return out + + def __convert_and_store(self, item, value): + if item in self._box_config['__converted']: + return value + if isinstance(value, dict) and not isinstance(value, Box): + value = self.__class__(value, __box_heritage=(self, item), + **self.__box_config()) + self[item] = 
value + elif isinstance(value, list) and not isinstance(value, BoxList): + if self._box_config['frozen_box']: + value = _recursive_tuples(value, self.__class__, + recreate_tuples=self._box_config[ + 'modify_tuples_box'], + __box_heritage=(self, item), + **self.__box_config()) + else: + value = BoxList(value, __box_heritage=(self, item), + box_class=self.__class__, + **self.__box_config()) + self[item] = value + elif (self._box_config['modify_tuples_box'] and + isinstance(value, tuple)): + value = _recursive_tuples(value, self.__class__, + recreate_tuples=True, + __box_heritage=(self, item), + **self.__box_config()) + self[item] = value + self._box_config['__converted'].add(item) + return value + + def __create_lineage(self): + if (self._box_config['__box_heritage'] and + self._box_config['__created']): + past, item = self._box_config['__box_heritage'] + if not past[item]: + past[item] = self + self._box_config['__box_heritage'] = None + + def __getattr__(self, item): + try: + try: + value = self.__getitem__(item, _ignore_default=True) + except KeyError: + value = object.__getattribute__(self, item) + except AttributeError as err: + if item == "__getstate__": + raise AttributeError(item) + if item == '_box_config': + raise BoxError('_box_config key must exist') + kill_camel = self._box_config['camel_killer_box'] + if self._box_config['conversion_box'] and item: + k = _conversion_checks(item, self.keys(), self._box_config) + if k: + return self.__getitem__(k) + if kill_camel: + for k in self.keys(): + if item == _camel_killer(k): + return self.__getitem__(k) + if self._box_config['default_box']: + return self.__get_default(item) + raise BoxKeyError(str(err)) + else: + if item == '_box_config': + return value + return self.__convert_and_store(item, value) + + def __setitem__(self, key, value): + if (key != '_box_config' and self._box_config['__created'] and + self._box_config['frozen_box']): + raise BoxError('Box is frozen') + if self._box_config['conversion_box']: + _conversion_checks(key, self.keys(), self._box_config, + check_only=True, pre_check=True) + super(Box, self).__setitem__(key, value) + self.__add_ordered(key) + self.__create_lineage() + + def __setattr__(self, key, value): + if (key != '_box_config' and self._box_config['frozen_box'] and + self._box_config['__created']): + raise BoxError('Box is frozen') + if key in self._protected_keys: + raise AttributeError("Key name '{0}' is protected".format(key)) + if key == '_box_config': + return object.__setattr__(self, key, value) + try: + object.__getattribute__(self, key) + except (AttributeError, UnicodeEncodeError): + if (key not in self.keys() and + (self._box_config['conversion_box'] or + self._box_config['camel_killer_box'])): + if self._box_config['conversion_box']: + k = _conversion_checks(key, self.keys(), + self._box_config) + self[key if not k else k] = value + elif self._box_config['camel_killer_box']: + for each_key in self: + if key == _camel_killer(each_key): + self[each_key] = value + break + else: + self[key] = value + else: + object.__setattr__(self, key, value) + self.__add_ordered(key) + self.__create_lineage() + + def __delitem__(self, key): + if self._box_config['frozen_box']: + raise BoxError('Box is frozen') + super(Box, self).__delitem__(key) + if (self._box_config['ordered_box'] and + key in self._box_config['__ordered_box_values']): + self._box_config['__ordered_box_values'].remove(key) + + def __delattr__(self, item): + if self._box_config['frozen_box']: + raise BoxError('Box is frozen') + if item == 
'_box_config':
+            raise BoxError('"_box_config" is protected')
+        if item in self._protected_keys:
+            raise AttributeError("Key name '{0}' is protected".format(item))
+        try:
+            object.__getattribute__(self, item)
+        except AttributeError:
+            del self[item]
+        else:
+            object.__delattr__(self, item)
+        if (self._box_config['ordered_box'] and
+                item in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].remove(item)
+
+    def pop(self, key, *args):
+        if args:
+            if len(args) != 1:
+                raise BoxError('pop() takes only one optional'
+                               ' argument "default"')
+            try:
+                item = self[key]
+            except KeyError:
+                return args[0]
+            else:
+                del self[key]
+                return item
+        try:
+            item = self[key]
+        except KeyError:
+            raise BoxKeyError('{0}'.format(key))
+        else:
+            del self[key]
+            return item
+
+    def clear(self):
+        self._box_config['__ordered_box_values'] = []
+        super(Box, self).clear()
+
+    def popitem(self):
+        try:
+            key = next(self.__iter__())
+        except StopIteration:
+            raise BoxKeyError('Empty box')
+        return key, self.pop(key)
+
+    def __repr__(self):
+        return '<Box: {0}>'.format(str(self.to_dict()))
+
+    def __str__(self):
+        return str(self.to_dict())
+
+    def __iter__(self):
+        for key in self.keys():
+            yield key
+
+    def __reversed__(self):
+        for key in reversed(list(self.keys())):
+            yield key
+
+    def to_dict(self):
+        """
+        Turn the Box and sub Boxes back into a native
+        python dictionary.
+
+        :return: python dictionary of this Box
+        """
+        out_dict = dict(self)
+        for k, v in out_dict.items():
+            if v is self:
+                out_dict[k] = out_dict
+            elif hasattr(v, 'to_dict'):
+                out_dict[k] = v.to_dict()
+            elif hasattr(v, 'to_list'):
+                out_dict[k] = v.to_list()
+        return out_dict
+
+    def update(self, item=None, **kwargs):
+        if not item:
+            item = kwargs
+        iter_over = item.items() if hasattr(item, 'items') else item
+        for k, v in iter_over:
+            if isinstance(v, dict):
+                # Box objects must be created in case they are already
+                # in the `converted` box_config set
+                v = self.__class__(v)
+                if k in self and isinstance(self[k], dict):
+                    self[k].update(v)
+                    continue
+            if isinstance(v, list):
+                v = BoxList(v)
+            try:
+                self.__setattr__(k, v)
+            except (AttributeError, TypeError):
+                self.__setitem__(k, v)
+
+    def setdefault(self, item, default=None):
+        if item in self:
+            return self[item]
+
+        if isinstance(default, dict):
+            default = self.__class__(default)
+        if isinstance(default, list):
+            default = BoxList(default)
+        self[item] = default
+        return default
+
+    def to_json(self, filename=None,
+                encoding="utf-8", errors="strict", **json_kwargs):
+        """
+        Transform the Box object into a JSON string.
+
+        :param filename: If provided will save to file
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param json_kwargs: additional arguments to pass to json.dump(s)
+        :return: string of JSON or return of `json.dump`
+        """
+        return _to_json(self.to_dict(), filename=filename,
+                        encoding=encoding, errors=errors, **json_kwargs)
+
+    @classmethod
+    def from_json(cls, json_string=None, filename=None,
+                  encoding="utf-8", errors="strict", **kwargs):
+        """
+        Transform a json object string into a Box object. If the incoming
+        json is a list, you must use BoxList.from_json.
+ + :param json_string: string to pass to `json.loads` + :param filename: filename to open and pass to `json.load` + :param encoding: File encoding + :param errors: How to handle encoding errors + :param kwargs: parameters to pass to `Box()` or `json.loads` + :return: Box object from json data + """ + bx_args = {} + for arg in kwargs.copy(): + if arg in BOX_PARAMETERS: + bx_args[arg] = kwargs.pop(arg) + + data = _from_json(json_string, filename=filename, + encoding=encoding, errors=errors, **kwargs) + + if not isinstance(data, dict): + raise BoxError('json data not returned as a dictionary, ' + 'but rather a {0}'.format(type(data).__name__)) + return cls(data, **bx_args) + + if yaml_support: + def to_yaml(self, filename=None, default_flow_style=False, + encoding="utf-8", errors="strict", + **yaml_kwargs): + """ + Transform the Box object into a YAML string. + + :param filename: If provided will save to file + :param default_flow_style: False will recursively dump dicts + :param encoding: File encoding + :param errors: How to handle encoding errors + :param yaml_kwargs: additional arguments to pass to yaml.dump + :return: string of YAML or return of `yaml.dump` + """ + return _to_yaml(self.to_dict(), filename=filename, + default_flow_style=default_flow_style, + encoding=encoding, errors=errors, **yaml_kwargs) + + @classmethod + def from_yaml(cls, yaml_string=None, filename=None, + encoding="utf-8", errors="strict", + loader=yaml.SafeLoader, **kwargs): + """ + Transform a yaml object string into a Box object. + + :param yaml_string: string to pass to `yaml.load` + :param filename: filename to open and pass to `yaml.load` + :param encoding: File encoding + :param errors: How to handle encoding errors + :param loader: YAML Loader, defaults to SafeLoader + :param kwargs: parameters to pass to `Box()` or `yaml.load` + :return: Box object from yaml data + """ + bx_args = {} + for arg in kwargs.copy(): + if arg in BOX_PARAMETERS: + bx_args[arg] = kwargs.pop(arg) + + data = _from_yaml(yaml_string=yaml_string, filename=filename, + encoding=encoding, errors=errors, + Loader=loader, **kwargs) + if not isinstance(data, dict): + raise BoxError('yaml data not returned as a dictionary' + 'but rather a {0}'.format(type(data).__name__)) + return cls(data, **bx_args) + + +class BoxList(list): + """ + Drop in replacement of list, that converts added objects to Box or BoxList + objects as necessary. 
+ """ + + def __init__(self, iterable=None, box_class=Box, **box_options): + self.box_class = box_class + self.box_options = box_options + self.box_org_ref = self.box_org_ref = id(iterable) if iterable else 0 + if iterable: + for x in iterable: + self.append(x) + if box_options.get('frozen_box'): + def frozen(*args, **kwargs): + raise BoxError('BoxList is frozen') + + for method in ['append', 'extend', 'insert', 'pop', + 'remove', 'reverse', 'sort']: + self.__setattr__(method, frozen) + + def __delitem__(self, key): + if self.box_options.get('frozen_box'): + raise BoxError('BoxList is frozen') + super(BoxList, self).__delitem__(key) + + def __setitem__(self, key, value): + if self.box_options.get('frozen_box'): + raise BoxError('BoxList is frozen') + super(BoxList, self).__setitem__(key, value) + + def append(self, p_object): + if isinstance(p_object, dict): + try: + p_object = self.box_class(p_object, **self.box_options) + except AttributeError as err: + if 'box_class' in self.__dict__: + raise err + elif isinstance(p_object, list): + try: + p_object = (self if id(p_object) == self.box_org_ref else + BoxList(p_object)) + except AttributeError as err: + if 'box_org_ref' in self.__dict__: + raise err + super(BoxList, self).append(p_object) + + def extend(self, iterable): + for item in iterable: + self.append(item) + + def insert(self, index, p_object): + if isinstance(p_object, dict): + p_object = self.box_class(p_object, **self.box_options) + elif isinstance(p_object, list): + p_object = (self if id(p_object) == self.box_org_ref else + BoxList(p_object)) + super(BoxList, self).insert(index, p_object) + + def __repr__(self): + return "".format(self.to_list()) + + def __str__(self): + return str(self.to_list()) + + def __copy__(self): + return BoxList((x for x in self), + self.box_class, + **self.box_options) + + def __deepcopy__(self, memodict=None): + out = self.__class__() + memodict = memodict or {} + memodict[id(self)] = out + for k in self: + out.append(copy.deepcopy(k)) + return out + + def __hash__(self): + if self.box_options.get('frozen_box'): + hashing = 98765 + hashing ^= hash(tuple(self)) + return hashing + raise TypeError("unhashable type: 'BoxList'") + + def to_list(self): + new_list = [] + for x in self: + if x is self: + new_list.append(new_list) + elif isinstance(x, Box): + new_list.append(x.to_dict()) + elif isinstance(x, BoxList): + new_list.append(x.to_list()) + else: + new_list.append(x) + return new_list + + def to_json(self, filename=None, + encoding="utf-8", errors="strict", + multiline=False, **json_kwargs): + """ + Transform the BoxList object into a JSON string. 
+ + :param filename: If provided will save to file + :param encoding: File encoding + :param errors: How to handle encoding errors + :param multiline: Put each item in list onto it's own line + :param json_kwargs: additional arguments to pass to json.dump(s) + :return: string of JSON or return of `json.dump` + """ + if filename and multiline: + lines = [_to_json(item, filename=False, encoding=encoding, + errors=errors, **json_kwargs) for item in self] + with open(filename, 'w', encoding=encoding, errors=errors) as f: + f.write("\n".join(lines).decode('utf-8') if + sys.version_info < (3, 0) else "\n".join(lines)) + else: + return _to_json(self.to_list(), filename=filename, + encoding=encoding, errors=errors, **json_kwargs) + + @classmethod + def from_json(cls, json_string=None, filename=None, encoding="utf-8", + errors="strict", multiline=False, **kwargs): + """ + Transform a json object string into a BoxList object. If the incoming + json is a dict, you must use Box.from_json. + + :param json_string: string to pass to `json.loads` + :param filename: filename to open and pass to `json.load` + :param encoding: File encoding + :param errors: How to handle encoding errors + :param multiline: One object per line + :param kwargs: parameters to pass to `Box()` or `json.loads` + :return: BoxList object from json data + """ + bx_args = {} + for arg in kwargs.copy(): + if arg in BOX_PARAMETERS: + bx_args[arg] = kwargs.pop(arg) + + data = _from_json(json_string, filename=filename, encoding=encoding, + errors=errors, multiline=multiline, **kwargs) + + if not isinstance(data, list): + raise BoxError('json data not returned as a list, ' + 'but rather a {0}'.format(type(data).__name__)) + return cls(data, **bx_args) + + if yaml_support: + def to_yaml(self, filename=None, default_flow_style=False, + encoding="utf-8", errors="strict", + **yaml_kwargs): + """ + Transform the BoxList object into a YAML string. + + :param filename: If provided will save to file + :param default_flow_style: False will recursively dump dicts + :param encoding: File encoding + :param errors: How to handle encoding errors + :param yaml_kwargs: additional arguments to pass to yaml.dump + :return: string of YAML or return of `yaml.dump` + """ + return _to_yaml(self.to_list(), filename=filename, + default_flow_style=default_flow_style, + encoding=encoding, errors=errors, **yaml_kwargs) + + @classmethod + def from_yaml(cls, yaml_string=None, filename=None, + encoding="utf-8", errors="strict", + loader=yaml.SafeLoader, + **kwargs): + """ + Transform a yaml object string into a BoxList object. + + :param yaml_string: string to pass to `yaml.load` + :param filename: filename to open and pass to `yaml.load` + :param encoding: File encoding + :param errors: How to handle encoding errors + :param loader: YAML Loader, defaults to SafeLoader + :param kwargs: parameters to pass to `BoxList()` or `yaml.load` + :return: BoxList object from yaml data + """ + bx_args = {} + for arg in kwargs.copy(): + if arg in BOX_PARAMETERS: + bx_args[arg] = kwargs.pop(arg) + + data = _from_yaml(yaml_string=yaml_string, filename=filename, + encoding=encoding, errors=errors, + Loader=loader, **kwargs) + if not isinstance(data, list): + raise BoxError('yaml data not returned as a list' + 'but rather a {0}'.format(type(data).__name__)) + return cls(data, **bx_args) + + def box_it_up(self): + for v in self: + if hasattr(v, 'box_it_up') and v is not self: + v.box_it_up() + + +class ConfigBox(Box): + """ + Modified box object to add object transforms. 
+
+    Allows for built-in transforms like:
+
+    cns = ConfigBox(my_bool='yes', my_int='5', my_list='5,4,3,3,2')
+
+    cns.bool('my_bool')  # True
+    cns.int('my_int')  # 5
+    cns.list('my_list', mod=lambda x: int(x))  # [5, 4, 3, 3, 2]
+    """
+
+    _protected_keys = dir({}) + ['to_dict', 'bool', 'int', 'float',
+                                 'list', 'getboolean', 'to_json', 'to_yaml',
+                                 'getfloat', 'getint',
+                                 'from_json', 'from_yaml']
+
+    def __getattr__(self, item):
+        """Config file keys are stored in lower case, be a little more
+        loosey goosey"""
+        try:
+            return super(ConfigBox, self).__getattr__(item)
+        except AttributeError:
+            return super(ConfigBox, self).__getattr__(item.lower())
+
+    def __dir__(self):
+        return super(ConfigBox, self).__dir__() + ['bool', 'int', 'float',
+                                                   'list', 'getboolean',
+                                                   'getfloat', 'getint']
+
+    def bool(self, item, default=None):
+        """ Return value of key as a boolean
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: approximated bool of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+
+        if isinstance(item, (bool, int)):
+            return bool(item)
+
+        if (isinstance(item, str) and
+                item.lower() in ('n', 'no', 'false', 'f', '0')):
+            return False
+
+        return True if item else False
+
+    def int(self, item, default=None):
+        """ Return value of key as an int
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: int of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        return int(item)
+
+    def float(self, item, default=None):
+        """ Return value of key as a float
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: float of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        return float(item)
+
+    def list(self, item, default=None, spliter=",", strip=True, mod=None):
+        """ Return value of key as a list
+
+        :param item: key of value to transform
+        :param mod: function to map against list
+        :param default: value to return if item does not exist
+        :param spliter: character to split str on
+        :param strip: clean the list with the `strip`
+        :return: list of items
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        if strip:
+            item = item.lstrip('[').rstrip(']')
+        out = [x.strip() if strip else x for x in item.split(spliter)]
+        if mod:
+            return list(map(mod, out))
+        return out
+
+    # loose configparser compatibility
+
+    def getboolean(self, item, default=None):
+        return self.bool(item, default)
+
+    def getint(self, item, default=None):
+        return self.int(item, default)
+
+    def getfloat(self, item, default=None):
+        return self.float(item, default)
+
+    def __repr__(self):
+        return '<ConfigBox: {0}>'.format(str(self.to_dict()))
+
+
+class SBox(Box):
+    """
+    ShorthandBox (SBox) allows for
+    property access of `dict` `json` and `yaml`
+    """
+    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
+                                 'json', 'yaml', 'from_yaml', 'from_json',
+                                 'dict']
+
+    @property
+    def dict(self):
+        return self.to_dict()
+
+    @property
+    def json(self):
+        return self.to_json()
+
+    if yaml_support:
+        @property
+        def yaml(self):
+            return self.to_yaml()
+
+    def __repr__(self):
+        return '<ShorthandBox: {0}>'.format(str(self.to_dict()))
+
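+# Example usage (a sketch added for illustration, not part of the original
+# file): in L-CNN these two globals are typically filled once at startup from
+# a yaml config, e.g.
+#   C.update(C.from_yaml(filename="config/wireframe.yaml"))
+#   M.update(C.model)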
+# C is a dict storing all the configuration
+C = Box()
+
+# shortcut for C.model
+M = Box()
diff --git a/deploy/creepageDistanceModel/models_creepage.py b/deploy/creepageDistanceModel/models_creepage.py
new file mode 100644
index 0000000..6ae75be
--- /dev/null
+++ b/deploy/creepageDistanceModel/models_creepage.py
@@ -0,0 +1,523 @@
+import os
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import uuid
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .config import M
+
+
+# Width and height that inputs are normalized to
+NORMALIZATION_WIDTH = 64
+NORMALIZATION_HEIGHT = 512
+# Maximum pixel value is 255
+PIXS_MAX_VALUE = 255.0
+# Dataset orientation types: top-bottom and left-right
+TB_DATATYPE = "tb"
+LR_DATATYPE = "lr"
+# Pixel-distance tolerance used when computing accuracy
+ACC_PX_THRESH = 16
+# Random seed
+RANDOM_SEED = 1024
+
+__all__ = ["HourglassNet", "hg"]
+
+
+class Bottleneck2D(nn.Module):
+    expansion = 2
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck2D, self).__init__()
+
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
+        self.bn3 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.bn1(x)
+        out = self.relu(out)
+        out = self.conv1(out)
+
+        out = self.bn2(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        out = self.bn3(out)
+        out = self.relu(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+
+        return out
+
+
+class Hourglass(nn.Module):
+    def __init__(self, block, num_blocks, planes, depth):
+        super(Hourglass, self).__init__()
+        self.depth = depth
+        self.block = block
+        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+    def _make_residual(self, block, num_blocks, planes):
+        layers = []
+        for i in range(0, num_blocks):
+            layers.append(block(planes * block.expansion, planes))
+        return nn.Sequential(*layers)
+
+    def _make_hour_glass(self, block, num_blocks, planes, depth):
+        hg = []
+        for i in range(depth):
+            res = []
+            for j in range(3):
+                res.append(self._make_residual(block, num_blocks, planes))
+            if i == 0:
+                res.append(self._make_residual(block, num_blocks, planes))
+            hg.append(nn.ModuleList(res))
+        return nn.ModuleList(hg)
+
+    def _hour_glass_forward(self, n, x):
+        up1 = self.hg[n - 1][0](x)
+        low1 = F.max_pool2d(x, 2, stride=2)
+        low1 = self.hg[n - 1][1](low1)
+
+        if n > 1:
+            low2 = self._hour_glass_forward(n - 1, low1)
+        else:
+            low2 = self.hg[n - 1][3](low1)
+        low3 = self.hg[n - 1][2](low2)
+        up2 = F.interpolate(low3, scale_factor=2)
+        out = up1 + up2
+        return out
+
+    def forward(self, x):
+        return self._hour_glass_forward(self.depth, x)
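+
+# Reading note (added comment, not in the original file): _hour_glass_forward
+# is recursive — at each level the input is kept at full resolution (up1),
+# max-pooled and recursed (low2), then upsampled (up2) and summed, so the
+# module fuses features across `depth` resolution levels.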
+
+class HourglassNet(nn.Module):
+    """Hourglass model from Newell et al ECCV 2016"""
+
+    def __init__(self, block, head, depth, num_stacks, num_blocks, num_classes):
+        super(HourglassNet, self).__init__()
+
+        self.inplanes = 64
+        self.num_feats = 128
+        self.num_stacks = num_stacks
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_residual(block, self.inplanes, 1)
+        self.layer2 = self._make_residual(block, self.inplanes, 1)
+        self.layer3 = self._make_residual(block, self.num_feats, 1)
+        self.maxpool = nn.MaxPool2d(2, stride=2)
+
+        # build hourglass modules
+        ch = self.num_feats * block.expansion
+        # vpts = []
+        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+        for i in range(num_stacks):
+            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
+            res.append(self._make_residual(block, self.num_feats, num_blocks))
+            fc.append(self._make_fc(ch, ch))
+            score.append(head(ch, num_classes))
+            # vpts.append(VptsHead(ch))
+            # vpts.append(nn.Linear(ch, 9))
+            # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
+            # score[i].bias.data[0] += 4.6
+            # score[i].bias.data[2] += 4.6
+            if i < num_stacks - 1:
+                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
+                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
+        self.hg = nn.ModuleList(hg)
+        self.res = nn.ModuleList(res)
+        self.fc = nn.ModuleList(fc)
+        self.score = nn.ModuleList(score)
+        # self.vpts = nn.ModuleList(vpts)
+        self.fc_ = nn.ModuleList(fc_)
+        self.score_ = nn.ModuleList(score_)
+
+    def _make_residual(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                )
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def _make_fc(self, inplanes, outplanes):
+        bn = nn.BatchNorm2d(inplanes)
+        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
+        return nn.Sequential(conv, bn, self.relu)
+
+    def forward(self, x):
+        out = []
+        # out_vps = []
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.maxpool(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        for i in range(self.num_stacks):
+            y = self.hg[i](x)
+            y = self.res[i](y)
+            y = self.fc[i](y)
+            score = self.score[i](y)
+            # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
+            # pre_vpts = pre_vpts.reshape(-1, 256)
+            # vpts = self.vpts[i](x)
+            out.append(score)
+            # out_vps.append(vpts)
+            if i < self.num_stacks - 1:
+                fc_ = self.fc_[i](y)
+                score_ = self.score_[i](score)
+                x = x + fc_ + score_
+
+        return out[::-1], y  # , out_vps[::-1]
+
+
+def hg(**kwargs):
+    model = HourglassNet(
+        Bottleneck2D,
+        head=kwargs.get("head", lambda c_in, c_out: nn.Conv2d(c_in, c_out, 1)),
+        depth=kwargs["depth"],
+        num_stacks=kwargs["num_stacks"],
+        num_blocks=kwargs["num_blocks"],
+        num_classes=kwargs["num_classes"],
+    )
+    return model
+
+
+FEATURE_DIM = 8
+
+
+class LineVectorizer(nn.Module):
+    def __init__(self, backbone):
+        super().__init__()
+        self.backbone = backbone
+
+        lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
+        scale_factor = M.n_pts0 // M.n_pts1
+        self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+        self.fc2 = nn.Sequential(
+            nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
+            nn.ReLU(inplace=True),
+            nn.Linear(M.dim_fc, M.dim_fc),
+            nn.ReLU(inplace=True),
+            nn.Linear(M.dim_fc, 1),
+        )
+
+    def forward(self, image, junc, jtyp, Lpos):
+        result = self.backbone(image)
+        h = result["preds"]
+        x = self.fc1(result["feature"])
+        n_batch, n_channel, row, col = x.shape
+
+        xs, ys, fs, ps, idx = [], [], [], [], [0]
+        i = 0
+        p, label, feat = self.sample_lines(
+            junc, jtyp, Lpos, h["jmap"][i], h["joff"][i]
+        )
+        # print("p.shape:", p.shape)
+        ys.append(label)
+        ps.append(p)
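+        # Descriptive note (added): the block below interpolates M.n_pts0
+        # evenly spaced points along each candidate line via self.lambda_,
+        # bilinearly samples the M.dim_loi-channel feature map x at those
+        # points, and max-pools them down to M.n_pts1 points per line.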
+ fs.append(feat) + + p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5 + p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY] + px, py = p[:, 1].contiguous(), p[:, 0].contiguous() + px0 = px.floor().clamp(min=0, max=int(NORMALIZATION_WIDTH / 4)-1) + py0 = py.floor().clamp(min=0, max=int(NORMALIZATION_HEIGHT / 4)-1) + px1 = (px0 + 1).clamp(min=0, max=int(NORMALIZATION_WIDTH / 4)-1) + py1 = (py0 + 1).clamp(min=0, max=int(NORMALIZATION_HEIGHT / 4)-1) + px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long() + + # xp: [N_LINE, N_CHANNEL, N_POINT] + xp = ( + ( + x[i, :, py0l, px0l] * (px1 - px) * (py1 - py) + + x[i, :, py0l, px1l] * (px - px0) * (py1 - py) + + x[i, :, py1l, px0l] * (px1 - px) * (py - py0) + + x[i, :, py1l, px1l] * (px - px0) * (py - py0) + ) + .reshape(n_channel, -1, M.n_pts0) + .permute(1, 0, 2) + ) + xp = self.pooling(xp) + xs.append(xp) + idx.append(idx[-1] + xp.shape[0]) + + x, y = torch.cat(xs), torch.cat(ys) + f = torch.cat(fs) + x = x.reshape(-1, M.n_pts1 * M.dim_loi) + x = torch.cat([x, f], 1) + x = self.fc2(x).flatten() + + p = torch.cat(ps) + s = torch.sigmoid(x) + b = s > 0.5 + lines = [] + score = [] + for i in range(n_batch): + p0 = p[idx[i] : idx[i + 1]] + s0 = s[idx[i] : idx[i + 1]] + mask = b[idx[i] : idx[i + 1]] + p0 = p0[mask] + s0 = s0[mask] + if len(p0) == 0: + lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device)) + score.append(torch.zeros([1, M.n_out_line], device=p.device)) + else: + v, arg = torch.sort(s0,descending=True) + # arg = torch.argsort(s0, descending=True) + p0, s0 = p0[arg], s0[arg] + lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)]) + score.append(s0[None, torch.arange(M.n_out_line) % len(s0)]) + return torch.cat(lines), torch.cat(score) + + def sample_lines(self, junc,jtyp,Lpos, jmap, joff): + with torch.no_grad(): + n_type = jmap.shape[0] + jmap = non_maximum_suppression(jmap).reshape(n_type, -1) + # jmap = jmap.reshape(n_type, -1) + joff = joff.reshape(n_type, 2, -1) + max_K = M.n_dyn_junc // n_type + N = len(junc) + K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K) + if K < 2: + K = 2 + device = jmap.device + + # index: [N_TYPE, K] + score, index = torch.topk(jmap, k=K) + y = (index / int(NORMALIZATION_WIDTH / 4)).float() + torch.gather(joff[:, 0], 1, index) + 0.5 + x = (index % int(NORMALIZATION_WIDTH / 4)).float() + torch.gather(joff[:, 1], 1, index) + 0.5 + + # xy: [N_TYPE, K, 2] + xy = torch.cat([y[..., None], x[..., None]], dim=-1) + xy_ = xy[..., None, :] + del x, y, index + + # dist: [N_TYPE, K, N] + dist = torch.sum((xy_ - junc) ** 2, -1) + cost, match = torch.min(dist, -1) + + # xy: [N_TYPE * K, 2] + # match: [N_TYPE, K] + for t in range(n_type): + match[t, jtyp[match[t]] != t] = N + match[cost > 1.5 * 1.5] = N + match = match.flatten() + + _ = torch.arange(n_type * K, device=device) + u, v = torch.meshgrid(_, _) + u, v = u.flatten(), v.flatten() + up, vp = match[u], match[v] + label = Lpos[up, vp] + + c = u < v + # sample lines + u, v, label = u[c], v[c], label[c] + xy = xy.reshape(n_type * K, 2) + xyu, xyv = xy[u], xy[v] + + u2v = xyu - xyv + u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6) + feat = torch.cat( + [ + xyu / torch.tensor([int(NORMALIZATION_HEIGHT / 4),int(NORMALIZATION_WIDTH / 4)]).to(device) * M.use_cood, + xyv / torch.tensor([int(NORMALIZATION_HEIGHT / 4),int(NORMALIZATION_WIDTH / 4)]).to(device) * M.use_cood, + u2v * M.use_slop, + (u[:, None] > K).float(), + (v[:, None] > K).float(), + ], + 1, + ) + line = 
+def non_maximum_suppression(a):
+    # output = F.max_pool2d(a, 3, stride=1)
+    # ap = F.interpolate(output.unsqueeze(0), size=a.shape[1:], mode='bilinear', align_corners=True)
+    # ap = ap.squeeze(0)
+    # mask = (a == ap).float().clamp(min=0.0)
+    # return a * mask
+
+    # au = a.unsqueeze(0)
+    # output, indices = F.max_pool2d(au, 3, stride=1, return_indices=True)
+    # ap = F.max_unpool2d(output, indices, 3, stride=1, output_size=au.shape)
+    # ap = ap.squeeze(0)
+    # mask = (a == ap).float().clamp(min=0.0)
+    # return a * mask
+    # the commented-out variants above are equivalent to the following:
+    ap = F.max_pool2d(a.unsqueeze(0), 3, stride=1, padding=1)
+    ap = ap.squeeze(0)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+class Bottleneck1D(nn.Module):
+    def __init__(self, inplanes, outplanes):
+        super(Bottleneck1D, self).__init__()
+
+        planes = outplanes // 2
+        self.op = nn.Sequential(
+            nn.BatchNorm1d(inplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, planes, kernel_size=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, outplanes, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.op(x)
+
+
+class MultitaskHead(nn.Module):
+    def __init__(self, input_channels, num_class):
+        super(MultitaskHead, self).__init__()
+
+        m = int(input_channels / 4)
+        heads = []
+        for output_channels in sum(M.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(M.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
+
+
+class MultitaskLearner(nn.Module):
+    def __init__(self, backbone):
+        super(MultitaskLearner, self).__init__()
+        self.backbone = backbone
+        head_size = M.head_size
+        self.num_class = sum(sum(head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in head_size])
+
+    def forward(self, image):
+        outputs, feature = self.backbone(image)
+        result = {"feature": feature}
+        batch, channel, row, col = outputs[0].shape
+
+        n_jtyp = 1  # single junction type
+
+        offset = self.head_off
+        output = outputs[0]
+        output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+        jmap = output[0 : offset[0]].reshape(n_jtyp, 2, batch, row, col)
+        joff = output[offset[1] : offset[2]].reshape(n_jtyp, 2, batch, row, col)
+        result["preds"] = {
+            "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+            "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+        }
+        return result
+
+
+def pline(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def plambda(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+
+
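+# Greedy removal of overlapping segments: each candidate (the callers pass
+# them sorted by score) is projected onto every line kept so far with
+# pline/plambda; overlapped parts of its [start, end] parameter range are
+# trimmed away, and the segment is dropped once nothing remains.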
+def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                min(
+                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                )
+                > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(score)
+    return np.array(nlines), np.array(nscores)
\ No newline at end of file
diff --git a/deploy/creepageDistanceModel/wireframe.yaml b/deploy/creepageDistanceModel/wireframe.yaml
new file mode 100644
index 0000000..5f5ecd8
--- /dev/null
+++ b/deploy/creepageDistanceModel/wireframe.yaml
@@ -0,0 +1,48 @@
+model:
+  image:
+      mean: [109.730, 103.832, 98.681]
+      stddev: [22.275, 22.124, 23.229]
+
+  batch_size: 32
+  batch_size_eval: 2
+
+  # backbone multi-task parameters
+  head_size: [[2], [1], [2]]
+
+  # backbone parameters
+  backbone: stacked_hourglass
+  depth: 4
+  num_stacks: 2
+  num_blocks: 1
+
+  # sampler parameters
+  ## static sampler
+  n_stc_posl: 300
+  n_stc_negl: 40
+
+  ## dynamic sampler
+  n_dyn_junc: 100
+  n_dyn_posl: 100
+  n_dyn_negl: 80
+  n_dyn_othr: 200
+
+  # LOIPool layer parameters
+  n_pts0: 32
+  n_pts1: 8
+
+  # line verification network parameters
+  dim_loi: 128
+  dim_fc: 1024
+
+  # maximum junction and line outputs
+  n_out_junc: 250
+  n_out_line: 50
+
+  # additional ablation study parameters
+  use_cood: 0
+  use_slop: 0
+
+  # junction threshold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
+
diff --git a/deploy/infer_onnx_onnxruntime_cpu.py b/deploy/infer_onnx_onnxruntime_cpu.py
new file mode 100644
index 0000000..b8136b7
--- /dev/null
+++ b/deploy/infer_onnx_onnxruntime_cpu.py
@@ -0,0 +1,87 @@
+# !/usr/bin/env python
+# -- coding: utf-8 --
+# @Author zengxiaohui
+# Datatime:7/26/2021 10:12 AM
+# @File:infer_net1_onnx.py
+import argparse
+import time
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import skimage.io
+import torch
+from imutils import paths
+import onnxruntime  # onnxruntime 1.8.1
+
+from deploy.creepageDistanceModel.models_creepage import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH, postprocess
+
+print(onnxruntime.get_device())
+
+from deploy.torch2onnx import get_image
+
+PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+def c(x):
+    return sm.to_rgba(x)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ONNX Runtime (CPU) inference demo")
+    parser.add_argument('--devices',
+                        default=r"0",
+                        help="GPU device id")
+    parser.add_argument('--onnx_path',
+                        default=r"/home/zengxh/workspace/lcnn/logs/210726-144038-88f281a-baseline/checkpoint_best.onnx",
+                        help="path of the exported ONNX model")
+    parser.add_argument('--predict_dir',
+                        default=r"/home/zengxh/medias/data/ext/creepageDistance/20210714/smallimg/tb/org",
+                        help="directory of images to predict")
+    parser.add_argument('--predict_type',
+                        default=r"tb",
+                        help="image orientation type: tb (top-bottom) or lr (left-right)")
+    opt = parser.parse_args()
+
+    options = onnxruntime.SessionOptions()
+    options.enable_profiling = True
+    ort_session = onnxruntime.InferenceSession(opt.onnx_path, options)
+    # ort_session.set_providers([ort_session.get_providers()[1]])  # force CPU inference
+
+    image_paths = list(paths.list_images(opt.predict_dir))
+    for image_path in image_paths[:10]:
+        # get_image returns (tensor, pad, w0, h0); read the raw image separately
+        # so it can be displayed and used to rescale the predictions
+        im = skimage.io.imread(image_path)
+        if im.ndim == 2:
+            im = np.repeat(im[:, :, None], 3, 2)
+        im = im[:, :, :3]
+        image, pad, w0, h0 = get_image(image_path, opt.predict_type)
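+        # Dummy junction inputs: they mirror the exported graph's input
+        # signature but are never fed to the ONNX session below, which only
+        # consumes `image`.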
+        junc = torch.zeros(1, 2).cuda()
+        jtyp = torch.zeros(1, dtype=torch.uint8).cuda()
+        Lpos = torch.zeros(2, 2, dtype=torch.uint8).cuda()
+
+        start = time.time()
+        lines, score = ort_session.run(['lines', "score"], {'image': image.numpy(), })
+        print(time.time() - start)
+
+        lines = lines[0] / (int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)) * im.shape[:2]
+        scores = score[0]
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+        # for i, t in enumerate([0.01, 0.95, 0.96, 0.97, 0.98, 0.99]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            # if s < t:
+            #     continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.show()
+        plt.close()
diff --git a/deploy/infer_onnx_onnxruntime_gpu.py b/deploy/infer_onnx_onnxruntime_gpu.py
new file mode 100644
index 0000000..a48689b
--- /dev/null
+++ b/deploy/infer_onnx_onnxruntime_gpu.py
@@ -0,0 +1,92 @@
+# !/usr/bin/env python
+# -- coding: utf-8 --
+# @Author zengxiaohui
+# Datatime:7/26/2021 10:12 AM
+# @File:infer_net1_onnx.py
+import argparse
+import os
+import time
+import skimage.io
+import numpy as np
+import copy
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import torch
+from imutils import paths
+import onnxruntime  # cuda10.2 == onnxruntime-gpu 1.5.2
+
+from deploy.creepageDistanceModel.models_creepage import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH, postprocess
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+print(onnxruntime.get_device())
+
+from deploy.torch2onnx import get_image
+
+PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+def c(x):
+    return sm.to_rgba(x)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ONNX Runtime (GPU) inference demo")
+    parser.add_argument('--devices',
+                        default=r"0",
+                        help="GPU device id")
+    parser.add_argument('--onnx_path',
+                        default=r"/home/zengxh/workspace/lcnn/logs/210726-144038-88f281a-baseline/checkpoint_best.onnx",
+                        help="path of the exported ONNX model")
+    parser.add_argument('--predict_dir',
+                        default=r"/home/zengxh/medias/data/ext/creepageDistance/20210714/smallimg/tb/org",
+                        help="directory of images to predict")
+    parser.add_argument('--predict_type',
+                        default=r"tb",
+                        help="image orientation type: tb (top-bottom) or lr (left-right)")
+    opt = parser.parse_args()
+
+    options = onnxruntime.SessionOptions()
+    options.enable_profiling = True
+    ort_session = onnxruntime.InferenceSession(opt.onnx_path, options)
+
+    image_paths = list(paths.list_images(opt.predict_dir))
+
+    for image_path in image_paths[:10]:
+        im_o = skimage.io.imread(image_path)
+        if im_o.ndim == 2:
+            im_o = np.repeat(im_o[:, :, None], 3, 2)
+        im_o = im_o[:, :, :3]
+
+        # The first image is slow and the rest are fast: the hardware has to
+        # warm up from its power-saving state, e.g. the CPU ramping up from a
+        # low-frequency idle state to its normal clocks.
+        # db is better suited to the GPU, while angle and crnn are the
+        # opposite and run faster on the CPU.
+        image, pad, w0, h0 = get_image(image_path, opt.predict_type)
+        junc = torch.zeros(1, 2).cuda()
+        jtyp = torch.zeros(1, dtype=torch.uint8).cuda()
+        Lpos = torch.zeros(2, 2, dtype=torch.uint8).cuda()
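+        # After the session runs, predictions are mapped from the letterboxed
+        # NORMALIZATION_HEIGHT x NORMALIZATION_WIDTH frame back to the original
+        # image: undo the padding, rescale to (w0, h0), and for "tb" inputs
+        # swap the axes back, since get_image transposed the image.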
+        start = time.time()
+        nlines, nscores = ort_session.run(['lines', "score"], {'image': image.numpy(), })
+        nlines[:, :, 1] = (nlines[:, :, 1] - pad[0]) * w0 / (NORMALIZATION_WIDTH - pad[0] * 2)  # x
+        nlines[:, :, 0] = (nlines[:, :, 0] - pad[1]) * h0 / (NORMALIZATION_HEIGHT - pad[1] * 2)  # y
+        if "tb" == opt.predict_type:
+            nlines = nlines[:, :, [1, 0]]
+            nlines[:, :, 0] = w0 - nlines[:, :, 0]  # y
+        print(time.time() - start)
+
+        # for i, t in enumerate([0.01, 0.95, 0.96, 0.97, 0.98, 0.99]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            # if s < t:
+            #     continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im_o)
+        plt.show()
+        plt.close()
diff --git a/deploy/infer_onnx_tensorrt.py b/deploy/infer_onnx_tensorrt.py
new file mode 100644
index 0000000..be84701
--- /dev/null
+++ b/deploy/infer_onnx_tensorrt.py
@@ -0,0 +1,41 @@
+# !/usr/bin/env python
+# -- coding: utf-8 --
+# @Author zengxiaohui
+# Datatime:7/26/2021 10:12 AM
+# @File:infer_net1_onnx.py
+import argparse
+import onnx
+import onnx_tensorrt.backend as backend
+import torch
+from imutils import paths
+
+from deploy.torch2onnx import get_image
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ONNX TensorRT inference demo")
+    parser.add_argument('--devices',
+                        default=r"0",
+                        help="GPU device id")
+    parser.add_argument('--onnx_path',
+                        default=r"/home/zengxh/workspace/lcnn/logs/210726-144038-88f281a-baseline/checkpoint_best.onnx",
+                        help="path of the exported ONNX model")
+    parser.add_argument('--predict_dir',
+                        default=r"/home/zengxh/medias/data/ext/creepageDistance/20210714/smallimg/tb/org",
+                        help="directory of images to predict")
+    parser.add_argument('--predict_type',
+                        default=r"tb",
+                        help="image orientation type: tb (top-bottom) or lr (left-right)")
+    opt = parser.parse_args()
+
+    model = onnx.load(opt.onnx_path)
+    engine = backend.prepare(model, device='CUDA:0')
+
+    image_paths = list(paths.list_images(opt.predict_dir))
+    image_path = image_paths[0]
+    # get_image returns (tensor, pad, w0, h0); only the tensor is needed here
+    image, pad, w0, h0 = get_image(image_path, opt.predict_type)
+    image = image.cuda()
+    junc = torch.zeros(1, 2).cuda()
+    jtyp = torch.zeros(1, dtype=torch.uint8).cuda()
+    Lpos = torch.zeros(2, 2, dtype=torch.uint8).cuda()
+
+    ret = engine.run(image)
+    print(ret)
\ No newline at end of file
diff --git a/deploy/infer_torch.py b/deploy/infer_torch.py
new file mode 100644
index 0000000..54c8f86
--- /dev/null
+++ b/deploy/infer_torch.py
@@ -0,0 +1,129 @@
+# !/usr/bin/env python
+# -- coding: utf-8 --
+# @Author zengxiaohui
+# Datatime:7/26/2021 10:12 AM
+# @File:infer_net1_onnx.py
+import argparse
+import time
+
+import copy
+import skimage.io
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import pprint
+import torch
+from imutils import paths
+import numpy as np
+from deploy.creepageDistanceModel.config import M, C
+from deploy.creepageDistanceModel.models_creepage import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH, postprocess, hg, \
+    MultitaskHead, MultitaskLearner, LineVectorizer
+
+from deploy.torch2onnx import get_image
+
+PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+def c(x):
+    return sm.to_rgba(x)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="PyTorch inference demo")
+    parser.add_argument('--config_file',
+                        default=r"/home/zengxh/workspace/lcnn/deploy/creepageDistanceModel/wireframe.yaml",
+                        help="model config file")
+    parser.add_argument('--checkpoint_path',
+                        default=r"/home/zengxh/workspace/lcnn/logs/210726-144038-88f281a-baseline/checkpoint_best.pth",
+                        help="path of the trained checkpoint")
+    parser.add_argument('--predict_dir',
+                        default=r"/home/zengxh/medias/data/ext/creepageDistance/20210714/smallimg/tb/org",
+                        help="directory of images to predict")
+    parser.add_argument('--predict_type',
+                        default=r"tb",
+                        help="image orientation type: tb (top-bottom) or lr (left-right)")
+    opt = parser.parse_args()
+    config_file = opt.config_file
+    C.update(C.from_yaml(filename=config_file))
+    M.update(C.model)
+    pprint.pprint(C, indent=4)
+
+    checkpoint = torch.load(opt.checkpoint_path, map_location='cpu')
+
+    # Load model
+    model = hg(
+        depth=M.depth,
+        head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
+        num_stacks=M.num_stacks,
+        num_blocks=M.num_blocks,
+        num_classes=sum(sum(M.head_size, [])),
+    )
+    model = MultitaskLearner(model)
+    model = LineVectorizer(model)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    model = model.cuda()
+    model.eval()
+
+    image_paths = list(paths.list_images(opt.predict_dir))
+    for image_path in image_paths[:10]:
+        im_o = skimage.io.imread(image_path)
+        if im_o.ndim == 2:
+            im_o = np.repeat(im_o[:, :, None], 3, 2)
+        im_o = im_o[:, :, :3]
+
+        image, pad, w0, h0 = get_image(image_path, opt.predict_type)
+        junc = torch.zeros(1, 2).cuda()
+        jtyp = torch.zeros(1, dtype=torch.uint8).cuda()
+        Lpos = torch.zeros(2, 2, dtype=torch.uint8).cuda()
+
+        start = time.time()
+        lines, score = model(image.cuda(), junc, jtyp, Lpos)
+
+        lines = lines[0].cpu().numpy() / (int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)) * (NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH)
+        scores = score[0].detach().cpu().numpy()
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (NORMALIZATION_HEIGHT ** 2 + NORMALIZATION_WIDTH ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        if len(nscores) < 2:
+            continue
+        elif len(nscores) == 2:
+            pass
+        else:
+            # keep the best-scoring line on each side of the x midpoint
+            middle_x = (np.max(nlines[:, :, 1]) - np.min(nlines[:, :, 1])) / 2 + np.min(nlines[:, :, 1])
+            x_mean = np.mean(nlines[:, :, 1], axis=1)
+            line1 = nlines[np.where(x_mean > middle_x)][np.argmax(nscores[np.where(x_mean > middle_x)])]
+            line2 = nlines[np.where(x_mean < middle_x)][np.argmax(nscores[np.where(x_mean < middle_x)])]
+            nlines = np.concatenate((line1[None, :, :], line2[None, :, :]), axis=0)
+
+        nlines[:, :, 1] = (nlines[:, :, 1] - pad[0]) * w0 / (NORMALIZATION_WIDTH - pad[0] * 2)  # x
+        nlines[:, :, 0] = (nlines[:, :, 0] - pad[1]) * h0 / (NORMALIZATION_HEIGHT - pad[1] * 2)  # y
+        result_lines = copy.deepcopy(nlines)
+        if "tb" == opt.predict_type:
+            result_lines[:, :, 1] = nlines[:, :, 0]  # x
+            result_lines[:, :, 0] = w0 - nlines[:, :, 1]  # y
+
+        print(time.time() - start)
+        # for i, t in enumerate([0.01, 0.95, 0.96, 0.97, 0.98, 0.99]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(result_lines, nscores[:2]):
+            # if s < t:
+            #     continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im_o)
+        plt.show()
+        plt.close()
\ No newline at end of file
diff --git a/deploy/torch2onnx.py b/deploy/torch2onnx.py
new file mode 100644
index 0000000..75b5654
--- /dev/null
+++ b/deploy/torch2onnx.py
@@ -0,0 +1,231 @@
+# !/usr/bin/env python
+# -- coding: utf-8 --
+# @Author zengxiaohui
+# Datatime:7/13/2021 1:06 PM
+# @File: predict images and generate standard json label files
+import argparse
+import copy
+import cv2
+import os
+import pprint
+import random
+import numpy as np
+import onnx
+import skimage.transform
+import torch
+from imutils import paths
+import torch.nn as nn
+
+from dataset.constants import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH, TB_DATATYPE
+from deploy.creepageDistanceModel.config import M, C
+from deploy.creepageDistanceModel.models_creepage import MultitaskHead, MultitaskLearner, LineVectorizer, hg
+from python_developer_tools.cv.utils.torch_utils import recursive_to, init_seeds, init_cudnn
+from python_developer_tools.cv.datasets.datasets_utils import letterbox
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+def get_image(image_path, datatype):
+    img = cv2.imread(image_path)
+    # origin_img = copy.deepcopy(img)
+    if "tb" == datatype:
+        # rotate top-bottom strips so all inputs share one orientation
+        img = cv2.transpose(img)
+        img = cv2.flip(img, 1)
+    h0, w0 = img.shape[:2]  # orig hw
+    im, ratio, pad = letterbox(img, [NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH], auto=False, scaleFill=True)
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # BGR -> RGB (cvtColor returns a new array)
+    if im.ndim == 2:
+        im = np.repeat(im[:, :, None], 3, 2)
+    im = im[:, :, :3]
+    im_resized = skimage.transform.resize(im, (NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH)) * 255
+    image = (im_resized - [109.730, 103.832, 98.681]) / [22.275, 22.124, 23.229]
+    image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
+    return image, pad, w0, h0
+
+
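+# Wrapper that re-implements the numpy postprocessing with torch ops, so that
+# rescaling, overlap removal and the final two-line selection are traced into
+# the exported ONNX graph instead of being redone by every consumer.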
+class mymodel(nn.Module):
+    def __init__(self, ckpt_path):
+        super(mymodel, self).__init__()
+        checkpoint = torch.load(ckpt_path, map_location='cpu')
+
+        # Load model
+        model = hg(
+            depth=M.depth,
+            head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
+            num_stacks=M.num_stacks,
+            num_blocks=M.num_blocks,
+            num_classes=sum(sum(M.head_size, [])),
+        )
+        model = MultitaskLearner(model)
+        model = LineVectorizer(model)
+        model.load_state_dict(checkpoint["model_state_dict"])
+        model = model.cuda()
+        model.eval()
+
+        self.model = model
+
+        self.lines_tensor = torch.tensor([NORMALIZATION_HEIGHT / 4, NORMALIZATION_WIDTH / 4]).cuda()
+        self.le9_tensor = torch.tensor(1e-9).cuda()
+        self.zero_tensor = torch.tensor(0).cuda()
+        self.one_tensor = torch.tensor(1).cuda()
+        self.ten_tensor = torch.tensor(10).cuda()
+
+    def pline(self, x1, y1, x2, y2, x, y):
+        px = x2 - x1
+        py = y2 - y1
+        dd = px * px + py * py
+        u = ((x - x1) * px + (y - y1) * py) / torch.max(self.le9_tensor, dd)
+        dx = x1 + u * px - x
+        dy = y1 + u * py - y
+        return dx * dx + dy * dy
+
+    def plambda(self, x1, y1, x2, y2, x, y):
+        px = x2 - x1
+        py = y2 - y1
+        dd = px * px + py * py
+        return ((x - x1) * px + (y - y1) * py) / torch.max(self.le9_tensor, dd)
+
+    def postprocess(self, lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+        nlines, nscores = [], []
+        for (p, q), score in zip(lines, scores):
+            start, end = 0, 1
+            for a, b in nlines:
+                if (
+                    torch.min(
+                        torch.max(self.pline(*p, *q, *a), self.pline(*p, *q, *b)),
+                        torch.max(self.pline(*a, *b, *p), self.pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
+                ):
+                    continue
+                lambda_a = self.plambda(*p, *q, *a)
+                lambda_b = self.plambda(*p, *q, *b)
+                if lambda_a > lambda_b:
+                    lambda_a, lambda_b = lambda_b, lambda_a
+                lambda_a -= tol
+                lambda_b += tol
+
+                # case 1: skip (if not do_clip)
+                if start < lambda_a and lambda_b < end:
+                    continue
+
+                # not intersect
+                if lambda_b < start or lambda_a > end:
+                    continue
+
+                # cover
+                if lambda_a <= start and end <= lambda_b:
+                    start = 10
+                    break
+
+                # case 2 & 3:
+                if lambda_a <= start and start <= lambda_b:
+                    start = lambda_b
+                if lambda_a <= end and end <= lambda_b:
+                    end = lambda_a
+
+                if start >= end:
+                    break
+
+            if start >= end:
+                continue
+
+            t_cat = torch.cat(((p + (q - p) * start).view(1, 2), (p + (q - p) * end).view(1, 2)), 0)
+            nlines.append(t_cat)
+            nscores.append(score)
+        nlines = torch.cat([nline.unsqueeze(0) for nline in nlines])
+        nscores = torch.tensor(nscores)
+        return nlines, nscores
+
+    def forward(self, image, junc, jtyp, Lpos):
+        lines, score = self.model(image, junc, jtyp, Lpos)
+        lines = torch.div(lines[0], torch.tensor([NORMALIZATION_HEIGHT / 4, NORMALIZATION_WIDTH / 4]).cuda())
+        lines = torch.mul(lines, torch.tensor([NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH]).cuda())
+        scores = score[0]
+        len_lines = lines.shape[0]
+        for i in range(1, len_lines):
+            if torch.equal(lines[i], lines[0]):
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (NORMALIZATION_HEIGHT ** 2 + NORMALIZATION_WIDTH ** 2) ** 0.5
+        nlines, nscores = self.postprocess(lines, scores, diag * 0.01, 0, False)
+
+        if len(nscores) > 2:
+            xnlines = nlines[:, :, 1]
+
+            # nonzeroindex = torch.nonzero(xnlines)
+            nlines_min = torch.min(xnlines)  # torch.min(xnlines[nonzeroindex])
+            middle_x = (torch.max(xnlines) - nlines_min) / 2 + nlines_min
+            x_mean = torch.mean(xnlines, axis=1)
+
+            xmeangt = x_mean > middle_x
+            xmeanlt = x_mean < middle_x
+            indexline1 = torch.argmax(nscores[xmeangt])
+            indexline2 = torch.argmax(nscores[xmeanlt])
+            line1 = nlines[xmeangt][indexline1]
+            line2 = nlines[xmeanlt][indexline2]
+            nlines = torch.cat((line1.unsqueeze(0), line2.unsqueeze(0)), 0)
+            return nlines, torch.tensor([nscores[xmeangt][indexline1],
+                                         nscores[xmeanlt][indexline2]])
+        return nlines, nscores
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Export the trained checkpoint to ONNX")
+    parser.add_argument('--config_file',
+                        default=r"/home/zengxh/workspace/lcnn/deploy/creepageDistanceModel/wireframe.yaml",
+                        help="model config file")
+    parser.add_argument('--checkpoint_path',
+                        default=r"/home/zengxh/workspace/lcnn/logs/210726-144038-88f281a-baseline/checkpoint_best.pth",
+                        help="path of the trained checkpoint")
+    parser.add_argument('--onnx_path',
+                        default=r"/home/zengxh/workspace/lcnn/logs/210726-144038-88f281a-baseline/checkpoint_best.onnx",
+                        help="output path for the exported ONNX model")
+    parser.add_argument('--predict_dir',
+                        default=r"/home/zengxh/medias/data/ext/creepageDistance/20210714/smallimg/tb/org",
+                        help="directory of sample images used to trace the export")
+    parser.add_argument('--predict_type',
+                        default=r"tb",
+                        help="image orientation type: tb (top-bottom) or lr (left-right)")
+    opt = parser.parse_args()
+    config_file = opt.config_file
+    C.update(C.from_yaml(filename=config_file))
+    M.update(C.model)
+    pprint.pprint(C, indent=4)
+
+    init_seeds(0)
+
+    device_name = "cuda"
+    init_cudnn()
+    print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    device = torch.device(device_name)
+
+    mymodel = mymodel(opt.checkpoint_path)
+
+    image_paths = list(paths.list_images(opt.predict_dir))
+    image_path = image_paths[0]
+    image, pad, w0, h0 = get_image(image_path, opt.predict_type)
+    junc = torch.zeros(1, 2).cuda()
+    jtyp = torch.zeros(1, dtype=torch.uint8).cuda()
+    Lpos = torch.zeros(2, 2, dtype=torch.uint8).cuda()
+    torch.onnx.export(mymodel, (image.cuda(), junc, jtyp, Lpos), opt.onnx_path,
+                      opset_version=11,
+                      verbose=True,
+                      export_params=True,  # export the trained parameters as well
+                      do_constant_folding=True,  # fold constants to optimize the graph
+                      # dynamic_axes={
+                      #     'input': {0: 'image', 1: 'junc', 2: 'jtyp', 3: 'Lpos'},
+                      #     'output': {0: 'lines', 1: 'score'}
+                      # },
+                      input_names=['image', "junc", "jtyp", "Lpos"],
+                      output_names=['lines', "score"], )
+    print("export finished")
+    # check the exported model
+    onnxmodel = onnx.load(opt.onnx_path)
+    onnx.checker.check_model(onnxmodel)
+    print("check finished")
diff --git a/eval-mAPJ.py b/eval-mAPJ.py
index 9c80e75..f030c92 100755
--- a/eval-mAPJ.py
+++ b/eval-mAPJ.py
@@ -84,7 +84,7 @@ def evaluate_wireframe(im_list, gt_list, juncs_wf):
     all_jc_gt = []
     for i, (im_fn, gt_fn, junc_wf) in enumerate(zip(im_list, gt_list, juncs_wf)):
         im = cv2.imread(im_fn)
-        im = cv2.resize(im, (128, 128))
+        im = cv2.resize(im, (int(NORMALIZATION_WIDTH / 4), int(NORMALIZATION_HEIGHT / 4)))
         with np.load(gt_fn) as npz:
             junc_gt = npz["junc"][:, :2]
 
@@ -107,7 +107,7 @@ def evaluate_afm(im_list, gt_list, afm):
     afm.sort()
     for i, (im_fn, gt_fn, afm_fn) in enumerate(zip(im_list, gt_list, afm)):
         im = cv2.imread(im_fn)
-        im = cv2.resize(im, (128, 128))
+        im = cv2.resize(im, (int(NORMALIZATION_WIDTH / 4), int(NORMALIZATION_HEIGHT / 4)))
         with np.load(gt_fn) as npz:
             junc_gt = npz["junc"][:, :2]
 
@@ -117,8 +117,9 @@ def evaluate_afm(im_list, gt_list, afm):
         afm_score = -fafm["scores"]
         h = fafm["h"]
         w = fafm["w"]
-        afm_line[:, :, 0] *= 128 / h
-        afm_line[:, :, 1] *= 128 / w
+
+        afm_line[:, :, 0] *= int(NORMALIZATION_HEIGHT / 4) / h
+        afm_line[:, :, 1] *= int(NORMALIZATION_WIDTH / 4) / w
 
         jun_c = []
         for line, score in zip(afm_line, afm_score):
@@ -143,8 +144,9 @@ def load_wf():
         juncs = loadmat(mat)["junctions"]
         if len(juncs) == 0:
             continue
-        juncs[:, 0] *= 128 / img.shape[1]
-        juncs[:, 1] *= 128 / img.shape[0]
+
+        juncs[:, 0] *= int(NORMALIZATION_WIDTH / 4) / img.shape[1]
+        juncs[:, 1] *= int(NORMALIZATION_HEIGHT / 4) / img.shape[0]
         # juncs += 0.5
         for j in juncs:
             pts[i][tuple(j)] += 1
diff --git a/lcnn/datasets.py b/lcnn/datasets.py
index f8dadff..497a65b 100644
--- a/lcnn/datasets.py
+++ b/lcnn/datasets.py
@@ -11,6 +11,7 @@
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
 
+from dataset.constants import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH
 from lcnn.config import M
 
 
@@ -60,8 +61,10 @@ def __getitem__(self, idx):
                 lpre[i] = lpre[i, ::-1]
         ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
         ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+
+        feat_1 = (lpre[:, :, :2] / (int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4))).astype(np.float32)
         feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
+            feat_1.reshape(-1, 4) * M.use_cood,
             ldir * M.use_slop,
             lpre[:, :, 2],
         ]
diff --git a/lcnn/models/line_vectorizer.py b/lcnn/models/line_vectorizer.py
index a40e922..23973b9 100644
--- a/lcnn/models/line_vectorizer.py
+++ b/lcnn/models/line_vectorizer.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT
 from lcnn.config import M
 
 FEATURE_DIM = 8
@@ -67,20 +68,20 @@
 
         p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
         p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-        px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-        px0 = px.floor().clamp(min=0, max=127)
-        py0 = py.floor().clamp(min=0, max=127)
-        px1 = (px0 + 1).clamp(min=0, max=127)
-        py1 = (py0 + 1).clamp(min=0, max=127)
+        px, py = p[:, 1].contiguous(), p[:, 0].contiguous()
+        px0 = px.floor().clamp(min=0, max=int(NORMALIZATION_WIDTH / 4)-1)
+        py0 = py.floor().clamp(min=0, max=int(NORMALIZATION_HEIGHT / 4)-1)
+        px1 = (px0 + 1).clamp(min=0, max=int(NORMALIZATION_WIDTH / 4)-1)
+        py1 = (py0 + 1).clamp(min=0, max=int(NORMALIZATION_HEIGHT / 4)-1)
         px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
 
         # xp: [N_LINE, N_CHANNEL, N_POINT]
         xp = (
             (
-                x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                x[i, :, py0l, px0l] * (px1 - px) * (py1 - py)
+                + x[i, :, py0l, px1l] * (px - px0) * (py1 - py)
+                + x[i, :, py1l, px0l] * (px1 - px) * (py - py0)
+                + x[i, :, py1l, px1l] * (px - px0) * (py - py0)
             )
             .reshape(n_channel, -1, M.n_pts0)
             .permute(1, 0, 2)
@@ -171,8 +172,8 @@ def sample_lines(self, meta, jmap, joff, mode):
 
         # index: [N_TYPE, K]
         score, index = torch.topk(jmap, k=K)
-        y = (index / 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-        x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+        y = (index // int(NORMALIZATION_WIDTH / 4)).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+        x = (index % int(NORMALIZATION_WIDTH / 4)).float() + torch.gather(joff[:, 1], 1, index) + 0.5
 
         # xy: [N_TYPE, K, 2]
         xy = torch.cat([y[..., None], x[..., None]], dim=-1)
@@ -230,8 +231,8 @@ def sample_lines(self, meta, jmap, joff, mode):
         u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
         feat = torch.cat(
             [
-                xyu / 128 * M.use_cood,
-                xyv / 128 * M.use_cood,
+                xyu / torch.tensor([int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)]).to(device) * M.use_cood,
+                xyv / torch.tensor([int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)]).to(device) * M.use_cood,
                 u2v * M.use_slop,
                 (u[:, None] > K).float(),
                 (v[:, None] > K).float(),
diff --git a/lcnn/trainer.py b/lcnn/trainer.py
index ec3b8c0..49ffa08 100644
--- a/lcnn/trainer.py
+++ b/lcnn/trainer.py
@@ -18,6 +18,7 @@
 
 from lcnn.config import C, M
 from lcnn.utils import recursive_to
+from python_developer_tools.files.common import mkdir
 
 
 class Trainer(object):
@@ -57,8 +58,9 @@ def run_tensorboard(self):
             os.makedirs(board_out)
         self.writer = SummaryWriter(board_out)
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
+        mkdir(os.path.abspath(board_out))
         p = subprocess.Popen(
-            ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
+            ["/home/zengxh/anaconda3/envs/CreepageDistance/bin/tensorboard", f"--logdir={os.path.abspath(board_out)}", f"--port={C.io.tensorboard_port}", "--host=0.0.0.0"]
         )
 
         def killme():
@@ -323,6 +325,7 @@ def pprint(*args):
 
 def _launch_tensorboard(board_out, port, out):
     os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    mkdir(board_out)
     p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
 
     def kill():
diff --git a/misc/draw-wireframe.py b/misc/draw-wireframe.py
index 5055274..ee1daab 100755
--- a/misc/draw-wireframe.py
+++ b/misc/draw-wireframe.py
@@ -56,8 +56,8 @@ def draw(args):
 
     with np.load(gt_name) as fgt:
         gt_line = fgt["lpos"][:, :, :2]
-        gt_line[:, :, 0] *= img.shape[0] / 128
-        gt_line[:, :, 1] *= img.shape[1] / 128
+        gt_line[:, :, 0] *= img.shape[0] / int(NORMALIZATION_HEIGHT / 4)
+        gt_line[:, :, 1] *= img.shape[1] / int(NORMALIZATION_WIDTH / 4)
 
     with np.load(afm_name) as fafm:
         afm_line = fafm["lines"].reshape(-1, 2, 2)[:, :, ::-1]
diff --git a/misc/plot-sAP.py b/misc/plot-sAP.py
index 9fbfef9..9ba3b4e 100755
--- a/misc/plot-sAP.py
+++ b/misc/plot-sAP.py
@@ -50,8 +50,8 @@ def wireframe_score(T=10):
     for i, (gt_name, matf) in enumerate(zip(gts, mat_files)):
line_pred = scipy.io.loadmat(matf)["lines"].reshape(-1, 2, 2) img = cv2.imread(matf.replace(".mat", ".jpg")) - line_pred[:, :, 0] *= 128 / img.shape[1] - line_pred[:, :, 1] *= 128 / img.shape[0] + line_pred[:, :, 0] *= int(NORMALIZATION_WIDTH / 4) / img.shape[1] + line_pred[:, :, 1] *= int(NORMALIZATION_HEIGHT / 4) / img.shape[0] line_pred = line_pred[:, :, ::-1] with np.load(gt_name) as fgt: @@ -117,8 +117,8 @@ def line_score(threshold=10): afm_score = -fafm["scores"] h = fafm["h"] w = fafm["w"] - afm_line[:, :, 0] *= 128 / h - afm_line[:, :, 1] *= 128 / w + afm_line[:, :, 0] *= int(NORMALIZATION_HEIGHT / 4) / h + afm_line[:, :, 1] *= int(NORMALIZATION_WIDTH / 4) / w for i, ((a, b), s) in enumerate(zip(lcnn_line, lcnn_score)): if i > 0 and (lcnn_line[i] == lcnn_line[0]).all(): lcnn_line = lcnn_line[:i] diff --git a/post.py b/post.py index 0e5acd4..5adca48 100755 --- a/post.py +++ b/post.py @@ -80,16 +80,16 @@ def handle(allname): scores = f["score"] with np.load(gtname) as f: gtlines = f["lpos"][:, :, :2] - gtlines[:, :, 0] *= im.shape[0] / 128 - gtlines[:, :, 1] *= im.shape[1] / 128 + gtlines[:, :, 0] *= im.shape[0] / int(NORMALIZATION_HEIGHT / 4) + gtlines[:, :, 1] *= im.shape[1] / int(NORMALIZATION_WIDTH / 4) for i in range(1, len(lines)): if (lines[i] == lines[0]).all(): lines = lines[:i] scores = scores[:i] break - lines[:, :, 0] *= im.shape[0] / 128 - lines[:, :, 1] *= im.shape[1] / 128 + lines[:, :, 0] *= im.shape[0] / int(NORMALIZATION_HEIGHT / 4) + lines[:, :, 1] *= im.shape[1] / int(NORMALIZATION_WIDTH / 4) diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 for threshold in thresholds: @@ -119,8 +119,8 @@ def handle(allname): npz_name.replace(".npz", f"_{i}.png"), dpi=500, bbox_inches=0 ) - nlines[:, :, 0] *= 128 / im.shape[0] - nlines[:, :, 1] *= 128 / im.shape[1] + nlines[:, :, 0] *= int(NORMALIZATION_HEIGHT / 4) / im.shape[0] + nlines[:, :, 1] *= int(NORMALIZATION_WIDTH / 4) / im.shape[1] np.savez_compressed(npz_name, lines=nlines, score=nscores) parmap(handle, inputs, 12)