Skip to content

Commit

Permalink
changed data paths to be relative and fixed pose model load
Browse files Browse the repository at this point in the history
  • Loading branch information
ajdroid committed Apr 30, 2018
1 parent 023bc86 commit e3f05e0
Showing 1 changed file with 100 additions and 70 deletions.
170 changes: 100 additions & 70 deletions retrieval-src/Retrieval-demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,25 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
"ExecuteTime": {
"end_time": "2018-04-30T14:21:29.137789Z",
"start_time": "2018-04-30T14:21:28.782419Z"
},
"collapsed": true
},
"outputs": [],
"outputs": [
{
"ename": "ImportError",
"evalue": "/opt/ros/kinetic/lib/python2.7/dist-packages/cv2.so: undefined symbol: PyCObject_Type",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-a0d1d6d79481>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'matplotlib'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'inline'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolors\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcolors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mneighbors\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNearestNeighbors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mLSHForest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mImportError\u001b[0m: /opt/ros/kinetic/lib/python2.7/dist-packages/cv2.so: undefined symbol: PyCObject_Type"
]
}
],
"source": [
"from __future__ import division\n",
"import numpy as np\n",
Expand Down Expand Up @@ -100,9 +116,7 @@
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Set output layer name. You can use sketch_net.blobs.keys() to list all layer\n",
Expand All @@ -113,9 +127,7 @@
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Set the transformer\n",
Expand Down Expand Up @@ -159,16 +171,14 @@
"outputs": [],
"source": [
"#TODO: specify photo folder for the retrieval and the segmentation folder\n",
"photo_paths = '/data1/ravikiran/SketchObjPartSegmentation/data/PASCAL_Parts_select_png/'\n",
"segPaths = '/data1/ravikiran/SketchObjPartSegmentation/data/pascal_parts_GT_no_aug/'"
"photo_paths = '../exp-src/data/PASCAL_Parts_select_png/'\n",
"segPaths = '../exp-src/data/pascal_parts_GT_no_aug/'"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand All @@ -188,9 +198,7 @@
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"def sketchyFeatExt(photo_paths):\n",
Expand Down Expand Up @@ -237,18 +245,33 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {
"collapsed": true
"ExecuteTime": {
"end_time": "2018-04-30T14:22:25.786704Z",
"start_time": "2018-04-30T14:22:25.780324Z"
}
},
"outputs": [],
"outputs": [
{
"ename": "ImportError",
"evalue": "No module named 'torch'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-2-0db45c47e31a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdirname\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/data1/abhijat/pytorch-resnet/\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mImportError\u001b[0m: No module named 'torch'"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"import os, sys\n",
"sys.path.append(os.path.dirname(\"/data1/abhijat/pytorch-resnet/\"))\n",
"import resnet_dilated_frozen_r5_D\n",
"sys.path.append(os.path.dirname(\"../exp-src/\"))\n",
"import resnet_dilated_frozen_r5_D_pose\n",
"gpu0 = 0\n",
"import scipy\n",
"from scipy import ndimage"
Expand Down Expand Up @@ -1593,17 +1616,15 @@
}
],
"source": [
"model = getattr(resnet_dilated_frozen_r5_D, \"Res_Deeplab\")()\n",
"model = getattr(resnet_dilated_frozen_r5_D_pose, \"Res_Deeplab\")()\n",
"model.eval()\n",
"model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand All @@ -1614,7 +1635,8 @@
}
],
"source": [
"!ls /data1/ravikiran/pytorch-resnet/snapshots/model_r5_p50x_D1_17000.pth"
"# Assuming that you downloaded the model file and put it in sketch-parse/retrieval-src/\n",
"!ls model_r5_p50x_D1_17000.pth"
]
},
{
Expand All @@ -1626,17 +1648,15 @@
"outputs": [],
"source": [
"# Path to sketch segmentation model goes here\n",
"model_path = '/data1/ravikiran/pytorch-resnet/snapshots/model_r5_p50x_D1_17000.pth' \n",
"model_path = 'model_r5_p50x_D1_17000.pth' \n",
"saved_state_dict = torch.load(model_path)\n",
"model.load_state_dict(saved_state_dict)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"def merge_parts(map_, i):\n",
Expand Down Expand Up @@ -1664,13 +1684,11 @@
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"def imInNet(imName, classLabel):\n",
" sketch_root = '/data1/ravikiran/SketchObjPartSegmentation/data/temp_annotation_processor/SVG/PNG_untouched/'\n",
" sketch_root = '../exp-src/data/temp_annotation_processor/SVG/PNG_untouched/'\n",
"# sketch_root = '/data1/ravikiran/SketchObjPartSegmentation/'\n",
" imName = os.path.join(sketch_root, classLabel, imName)\n",
" img = cv2.imread(imName)\n",
Expand All @@ -1685,9 +1703,7 @@
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"def sketchPTseg(sketchName, classLabel, selector=2):\n",
Expand All @@ -1710,7 +1726,6 @@
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
Expand Down Expand Up @@ -1741,7 +1756,6 @@
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
Expand Down Expand Up @@ -1804,9 +1818,7 @@
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Start and connect to MATLAB\n",
Expand Down Expand Up @@ -1850,9 +1862,7 @@
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"oneHotMap = {}\n",
Expand All @@ -1863,13 +1873,11 @@
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"im2class = {}\n",
"listDir = '/data1/ravikiran/SketchObjPartSegmentation/data/lists/train_val_lists/'\n",
"listDir = '../exp-src/data/lists/train_val_lists/'\n",
"for classLabel, classIdx in class2idx.items():\n",
" imListFile = os.path.join(listDir, 'chosen_train_'+ classLabel +'_list.txt')\n",
" with open(imListFile, 'r') as f:\n",
Expand All @@ -1883,9 +1891,7 @@
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Part info\n",
Expand Down Expand Up @@ -1945,9 +1951,7 @@
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"class2idx = {}\n",
Expand Down Expand Up @@ -2422,8 +2426,8 @@
" GQ = sketch2graph(querySegmap, classLabel=classLabel)\n",
" \n",
" # Retrieval image meta-data (PASCAL)\n",
"# gtDir = '/data1/ravikiran/SketchObjPartSegmentation/data/pascal_parts_GT_no_aug/'\n",
"# imDir = '/data1/ravikiran/SketchObjPartSegmentation/data/PASCAL_Parts_select_png/'\n",
"# gtDir = '../exp-src/data/pascal_parts_GT_no_aug/'\n",
"# imDir = '../exp-src/data/PASCAL_Parts_select_png/'\n",
" gtDir = segPaths\n",
" imDir = photo_paths\n",
" \n",
Expand All @@ -2448,22 +2452,20 @@
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# rate reranking on PASCAL\n",
"def rateReranking(querySketchPath, display='False', selector=2):\n",
" sketch_path = querySketchPath\n",
" classLabel = os.path.basename(os.path.dirname(sketch_path))\n",
" \n",
" vis_dir = os.path.join('/data1/ravikiran/SketchObjPartSegmentation/data/PASCAL_Parts_select_png/')\n",
" vis_dir = os.path.join('../exp-src/data/PASCAL_Parts_select_png/')\n",
" ret_dir = vis_dir\n",
" seg_dir = os.path.join('/data1/ravikiran/SketchObjPartSegmentation/data/pascal_parts_GT_no_aug/')\n",
" seg_dir = os.path.join('../exp-src/data/pascal_parts_GT_no_aug/')\n",
" \n",
" sketch_seg_dir = os.path.join('/data1/ravikiran/SketchObjPartSegmentation/data/temp_annotation_processor/SVG/PNG_untouched/', classLabel)\n",
" seg_path = os.path.join( '/data1/ravikiran/SketchObjPartSegmentation/data/temp_annotation_processor/SVG/PNG_untouched/', classLabel, os.path.basename(sketch_path))\n",
" sketch_seg_dir = os.path.join('../exp-src/data/temp_annotation_processor/SVG/PNG_untouched/', classLabel)\n",
" seg_path = os.path.join( '../exp-src/data/temp_annotation_processor/SVG/PNG_untouched/', classLabel, os.path.basename(sketch_path))\n",
" \n",
" feats, ret_list = sketchyFeatExt(ret_dir)\n",
" \n",
Expand Down Expand Up @@ -2539,7 +2541,6 @@
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
Expand Down Expand Up @@ -2627,7 +2628,7 @@
"source": [
"# getHere\n",
"rateReranking(\\\n",
"'/data1/ravikiran/SketchObjPartSegmentation/data/temp_annotation_processor/SVG/PNG_untouched/dog/n02106662_24019-5.png'\\\n",
"'../exp-src/data/temp_annotation_processor/SVG/PNG_untouched/dog/n02106662_24019-5.png'\\\n",
" , True, 1)"
]
},
Expand Down Expand Up @@ -2659,16 +2660,45 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat_minor": 1
}

0 comments on commit e3f05e0

Please sign in to comment.