diff --git a/code/badja_data.py b/code/badja_data.py index c15bd33..f276591 100644 --- a/code/badja_data.py +++ b/code/badja_data.py @@ -6,6 +6,7 @@ import numpy as np import scipy.misc from random import shuffle +import random import json import os import cv2 @@ -13,7 +14,9 @@ from joint_catalog import SMALJointInfo -BADJA_PATH = os.path.dirname(sys.path[0]) # Assumes you are exectuting from "BADJA" root +print(sys.path) +BADJA_PATH = os.path.dirname('F:\\Datasets\\BADJA\\') # Assumes you are exectuting from "BADJA" root + IGNORE_ANIMALS = [ # "bear.json", # "camel.json", @@ -60,27 +63,34 @@ def __init__(self): print ("BADJA IMAGE file path: {0} is missing".format(file_name)) self.animal_dict[animal_id] = (filenames, segnames, joints, visible) - print ("Loaded BADJA dataset") - + + def get_loader(self): - for idx in xrange(int(1e6)): - animal_id = np.random.choice(self.animal_dict.keys()) - filenames, segnames, joints, visible = self.animal_dict[animal_id] + for idx in range(int(1e6)): + print(self.animal_dict.keys()) + animal_id = random.choice(list(self.animal_dict.keys())) + print("Animal ID: ", animal_id) + if animal_id == 2 or animal_id == 10: + print("This is either a cat or tiger iamge") + else: + filenames, segnames, joints, visible = self.animal_dict[animal_id] + print(len(filenames)) + #print(filenames) - image_id = np.random.randint(0, len(filenames)) + image_id = np.random.randint(0, len(filenames)) - seg_file = segnames[image_id] - image_file = filenames[image_id] + seg_file = segnames[image_id] + image_file = filenames[image_id] - joints = joints[image_id].copy() - joints = joints[self.smal_joint_info.annotated_classes] - visible = visible[image_id][self.smal_joint_info.annotated_classes] + joints = joints[image_id].copy() + joints = joints[self.smal_joint_info.annotated_classes] + visible = visible[image_id][self.smal_joint_info.annotated_classes] - rgb_img = scipy.misc.imread(image_file, mode='RGB') - sil_img = scipy.misc.imread(seg_file, mode='RGB') + rgb_img = scipy.misc.imread(image_file, mode='RGB') + sil_img = scipy.misc.imread(seg_file, mode='RGB') - rgb_h, rgb_w, _ = rgb_img.shape - sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST) + rgb_h, rgb_w, _ = rgb_img.shape + sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST) - yield rgb_img, sil_img, joints, visible, image_file \ No newline at end of file + yield rgb_img, sil_img, joints, visible, image_file diff --git a/code/view_badja.py b/code/view_badja.py index 86a7f2e..dc42e9b 100644 --- a/code/view_badja.py +++ b/code/view_badja.py @@ -15,10 +15,14 @@ def draw_joints_on_image(rgb_img, joints, visibility, region_colors, marker_types): joints = joints[:, ::-1] # OpenCV works in (x, y) rather than (i, j) - disp_img = rgb_img.copy() + disp_img = rgb_img.copy() for joint_coord, visible, color, marker_type in zip(joints, visibility, region_colors, marker_types): if visible: joint_coord = joint_coord.astype(int) + #print("Colours: ", color) + #print(marker_types) + color = [int(i) for i in color] + #print("Colour Value: ", int(color)) cv2.drawMarker(disp_img, tuple(joint_coord), color, marker_type, 30, thickness = 10) return disp_img @@ -29,6 +33,8 @@ def main(): data_loader = badja_data.get_loader() for rgb_img, sil_img, joints, visible, name in data_loader: + #print(visible) + #print(smal_joint_info.joint_colors) rgb_vis = draw_joints_on_image(rgb_img, joints, visible, smal_joint_info.joint_colors, smal_joint_info.annotated_markers) sil_vis = draw_joints_on_image(sil_img, joints, visible, smal_joint_info.joint_colors, smal_joint_info.annotated_markers) @@ -41,4 +47,4 @@ def main(): plt.pause(0.5) if __name__ == '__main__': - main() \ No newline at end of file + main()