Skip to content

Commit

Permalink
added graph with correct scores
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanmalod committed Mar 5, 2018
1 parent ce7406b commit 8d3b210
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 135 deletions.
18 changes: 9 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from splitter import pieces, edges, compatibility, print_results, reconstruct, graph_init, result_cleaner, graph_create, print_graph, get_best_match
from splitter import pieces, edges, compatibility, graph_create, print_graph, get_best_match, is_best_match, reconstruct
from PIL import Image
import math
import scipy.spatial.distance as SSD
Expand All @@ -15,7 +15,7 @@
def show_pieces(np_array):
x= 0
y = 0
result = Image.new('RGB', (4839,3198))
result = Image.new('RGB', (1000,668))
print(len(np_array))
for i in range(0,len(np_array)):
x=0
Expand Down Expand Up @@ -50,13 +50,13 @@ def jpg_image_to_array(image_path):
im_arr = im_arr.reshape((image.size[1], image.size[0], 3))
return im_arr

img = jpg_image_to_array('forest.jpeg')
np_image = pieces(img,10,10)
img = jpg_image_to_array('fez.jpg')
np_image = pieces(img,5,5)
show_pieces(np_image)
result_edge = edges(np_image,4)
results = compatibility(result_edge)
print(get_best_match(0,0,results, 1))
#dico = result_cleaner(results)
#graphres = graph_create(dico)

#print_graph(graphres)
graphres = graph_create(results)
print(is_best_match(0,0,0,1,results,3))
res_place =reconstruct(graphres, results)
print(res_place[0][0][1])
print_graph(graphres)
192 changes: 66 additions & 126 deletions splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def pieces(image, width_num, height_num):
split_width = width // width_num
y=0
i = 0
#Image.fromarray(image).show()
for i in range(0, height_num):
x = 0
for j in range(0, width_num):
Expand Down Expand Up @@ -62,7 +61,7 @@ def edges(np_image, target) :
def compatibility(np_pieces):
shape_results = (len(np_pieces),len(np_pieces[0]),4,len(np_pieces),len(np_pieces[0]))
results = np.empty(shape_results)
results[:] = 1000
results.fill(np.nan)
for x in range(0, len(np_pieces)):
for y in range(0,len(np_pieces[x])):
print('Piece ' + ' ' + str(x) + ' ' + str(y) + ' finished processing...')
Expand All @@ -72,8 +71,6 @@ def compatibility(np_pieces):
if x != xx or y != yy:
zz = getInverse(z)
results[x][y][z][xx][yy] = get_score(np_pieces[x][y][z],np_pieces[xx][yy][zz],z)
#for x, elem in enumerate(results):
#results[x].sort()
return results

#Return the relevant edge position given a position
Expand All @@ -89,9 +86,11 @@ def getInverse(p):

def get_best_match(piece_X, piece_Y, results, edge):
piece_result = results[piece_X][piece_Y][edge]
print(piece_result)
print(piece_result.argmin())
return np.unravel_index(piece_result.argmin(), piece_result.shape)
idx = np.argsort(piece_result.ravel())[:2]
x, y = np.unravel_index(idx[0], piece_result.shape)
xx, yy = np.unravel_index(idx[1], piece_result.shape)
score = piece_result[x][y] / piece_result[xx][yy]
return (x,y,score)

def get_score(piece1, piece2, edge):
piece1 = matplotlib.colors.rgb_to_hsv(piece1 / float(256))
Expand Down Expand Up @@ -142,134 +141,75 @@ def get_score(piece1, piece2, edge):

return(mahalanobis_distp1p2 + mahalanobis_distp2p1)

def print_results(results):
piece_num = 0
results_matrix = np.full((len(results), 16), -1, dtype='float')
for x, result in enumerate(results):
edge_num = x % 4
if x % 4 == 0 and x != 0:
piece_num += 1
print('Piece number ' + str(piece_num) + ' edge number ' + str(edge_num))
print(np.array(result))
for x in range(0, len(results)):
for y in range(0, len(results[x])):
edge = results[x][y][1]
score = results[x][y][0]
results_matrix[x,edge] = score
print(results_matrix)


"""
Determine if two piece are best matches to one another (Best buddy)
@return True if best buddy else False
"""
def is_best_match(x,y,xx,yy, results, edge):
x_match,y_match = get_best_match(x,y,results,edge)[:2]
return (x_match,y_match) == (xx,yy)

def print_graph(graphs):
for graph in graphs:
pos = nx.spring_layout(graph)
nx.draw_networkx_nodes(graph, pos, node_size=600)
nx.draw_networkx_edges(graph, pos, edge_color='black', width=3)
labels = nx.get_edge_attributes(graph, 'weight')
nx.draw_networkx_labels(graph, pos, font_size=20, font_family='sans-serif')
nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels)
plt.axis('off')
print(graph.adj.items())
plt.show()

def reconstruct(results, np_pieces):
np_pieces = np.array(np_pieces)
total = np_pieces.shape[0]
h = np_pieces[0].shape[0]
w = np_pieces[0].shape[1]
target_width = math.sqrt(total) * w
target_height = math.sqrt(total) * h
x = 0
y = 0
result = Image.new('RGB', (int(target_width), int(target_height)))
img = Image.fromarray(np_pieces[0])
w, h = img.size
result.paste(img, (x, y))
for x in range(0, len(results)):
score, piece = results[x][0]
position = piece % 4
piece = int(piece / 4)
print(str(piece) + ' ' + str(piece))
target_image = np_pieces[piece]
target_image = Image.fromarray(target_image)
if position == 3 and piece == 2:
result.paste(target_image, (w, 0))
if position == 0 and piece == 1:
result.paste(target_image, (0, h))
if position == 0and piece == 3:
result.paste(target_image, (w, h))
result.show()

# Remove any result < 10^-6 and equal to nan
# Return the smallest result divided by the second smallest in order to discern good results from inconclusive ones
def result_cleaner(results):
r_shape = results.shape
d = defaultdict(lambda: defaultdict(dict))
for x in range(0, len(results)):
for y in range(0, len(results[x])):
def reconstruct(graphs, results):
res_shape = results.shape
h = res_shape[0]
w = res_shape[1]
solution = [[['#' for _ in range(4)] for _ in range(res_shape[1])] for _ in range(res_shape[0])]
for graph in graphs:
old_w = 1000000
for (u, v, wt) in graph.edges.data('weight'):
edge = graph.get_edge_data(u, v)['edge']
print('edge ',edge)
if is_best_match(u[0],u[1],v[0],v[1],results,edge):
solution[u[0]][u[1]][edge] = v
solution[v[0]][v[1]][edge] = u
return solution

def display_results(results,images):
img_final = np.zeroes((1000,668,3))
for x in range(results):
for y in range(results):
for z in range(4):
s = results[x][y][z].shape
res = np.argsort(np.ravel(results[x][y][z]))[:2]
res_form = np.unravel_index(res, s)
lX = res_form[0][0]
lY = res_form[0][1]
lXX = res_form[1][0]
lYY = res_form[1][1]
d[x][y][z] = (results[x][y][z][lX][lY]/results[x][y][z][lXX][lYY], lX, lY)
return d

def print_graph(graph):
pos = nx.spring_layout(graph)
nx.draw_networkx_nodes(graph, pos, node_size=300)
nx.draw_networkx_edges(graph, pos, edge_color='black', width=3)
labels = nx.get_edge_attributes(graph, 'weight')
print(labels)
nx.draw_networkx_labels(graph, pos, font_size=20, font_family='sans-serif')
nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels)
plt.axis('off')
plt.show()




if results[x][y][z] != 0:
target_x = results[x][y][z][0]
target_Y = results[x][y][z][1]
if z == 0:
insert_image(img_final, images[x][y], edge)

if z == 1:
insert_image
if z == 2:
insert_image
if z ==3:
insert_image
results[target_x][target_Y][z] = 0

def graph_create(results):
graphs = []
graph = nx.Graph()
for x in range(0, len(results)):
for y in range(0, len(results[x])):
graph = nx.Graph()
for z in range(4):
graph.add_edge((x,y), (results[x][y][z][1:]), weight=results[x][y][z][0])
print(nx.shortest_path(graph, source=(0,0), target=(1,1)))
return graph

match_x, match_y, match_score = get_best_match(x,y,results,z)
# Check if edge already exists, if it does only replace current edge if match_score is smaller than existing score (this can happen at edges)
if graph.get_edge_data((x,y),(match_x,match_y)) is not None and graph.get_edge_data((x,y),(match_x,match_y))['weight'] < match_score:
print('Edge already exists and has a lower score for piece', x, y)
else:
graph.add_edge((x,y), (match_x,match_y), weight=round(match_score,5), edge=z)
graphs.append(graph)
return graphs


def graph_init(results):
graphs = []
nodes = [x for x in range(0,len(results))]
count = 0
piece_num = 0
graph = nx.Graph()
start = 0
end = 10
for y in range(0, len(results)):
graph = nx.Graph()
start = y*4
end = start + 4
for x in range(start, end):
print(x)
score, position = results[x]
piece_target = position // 4
edge_target = position % 4
piece_num = y
edge_num = x
print('Piece number ' + str(piece_num) + ' ' + ' to piece num ' + str(piece_target))
graph.add_edge(piece_num, piece_target, weight=score, object={edge_num:edge_target})
graphs.append(graph)
for graph in graphs:
pos = nx.spring_layout(graph)
#e1 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] == 0]
#e2 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] == 1]
#e3 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] ==2]
#e4 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] == 3]
nx.draw_networkx_nodes(graph, pos, node_size=300)
nx.draw_networkx_edges(graph, pos,edge_color='b',width=6)
#nx.draw_networkx_edges(graph, pos,edgelist=e1, edge_color='b',width=6)
#nx.draw_networkx_edges(graph, pos, edgelist=e2,edge_color='r',width=6)
#nx.draw_networkx_edges(graph, pos, edgelist=e3,edge_color='g',width=6)
#nx.draw_networkx_edges(graph, pos, edgelist=e4,edge_color='y',width=6)
labels = nx.get_edge_attributes(graph, 'weight')
labels2 = nx.get_edge_attributes(graph, 'object')
nx.draw_networkx_labels(graph, pos, font_size=20, font_family='sans-serif')
nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels)
plt.axis('off')
plt.show()

0 comments on commit 8d3b210

Please sign in to comment.