From 8d3b210973b4569d4e63555b17675c2ddc2522e6 Mon Sep 17 00:00:00 2001 From: jeanmalod Date: Mon, 5 Mar 2018 21:37:23 +0000 Subject: [PATCH] added graph with correct scores --- main.py | 18 ++--- splitter.py | 192 ++++++++++++++++++---------------------------------- 2 files changed, 75 insertions(+), 135 deletions(-) diff --git a/main.py b/main.py index e4f59e4..5662958 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from splitter import pieces, edges, compatibility, print_results, reconstruct, graph_init, result_cleaner, graph_create, print_graph, get_best_match +from splitter import pieces, edges, compatibility, graph_create, print_graph, get_best_match, is_best_match, reconstruct from PIL import Image import math import scipy.spatial.distance as SSD @@ -15,7 +15,7 @@ def show_pieces(np_array): x= 0 y = 0 - result = Image.new('RGB', (4839,3198)) + result = Image.new('RGB', (1000,668)) print(len(np_array)) for i in range(0,len(np_array)): x=0 @@ -50,13 +50,13 @@ def jpg_image_to_array(image_path): im_arr = im_arr.reshape((image.size[1], image.size[0], 3)) return im_arr -img = jpg_image_to_array('forest.jpeg') -np_image = pieces(img,10,10) +img = jpg_image_to_array('fez.jpg') +np_image = pieces(img,5,5) show_pieces(np_image) result_edge = edges(np_image,4) results = compatibility(result_edge) -print(get_best_match(0,0,results, 1)) -#dico = result_cleaner(results) -#graphres = graph_create(dico) - -#print_graph(graphres) +graphres = graph_create(results) +print(is_best_match(0,0,0,1,results,3)) +res_place =reconstruct(graphres, results) +print(res_place[0][0][1]) +print_graph(graphres) diff --git a/splitter.py b/splitter.py index e7b0ba9..f64c925 100644 --- a/splitter.py +++ b/splitter.py @@ -30,7 +30,6 @@ def pieces(image, width_num, height_num): split_width = width // width_num y=0 i = 0 - #Image.fromarray(image).show() for i in range(0, height_num): x = 0 for j in range(0, width_num): @@ -62,7 +61,7 @@ def edges(np_image, target) : def compatibility(np_pieces): shape_results = (len(np_pieces),len(np_pieces[0]),4,len(np_pieces),len(np_pieces[0])) results = np.empty(shape_results) - results[:] = 1000 + results.fill(np.nan) for x in range(0, len(np_pieces)): for y in range(0,len(np_pieces[x])): print('Piece ' + ' ' + str(x) + ' ' + str(y) + ' finished processing...') @@ -72,8 +71,6 @@ def compatibility(np_pieces): if x != xx or y != yy: zz = getInverse(z) results[x][y][z][xx][yy] = get_score(np_pieces[x][y][z],np_pieces[xx][yy][zz],z) - #for x, elem in enumerate(results): - #results[x].sort() return results #Return the relevant edge position given a position @@ -89,9 +86,11 @@ def getInverse(p): def get_best_match(piece_X, piece_Y, results, edge): piece_result = results[piece_X][piece_Y][edge] - print(piece_result) - print(piece_result.argmin()) - return np.unravel_index(piece_result.argmin(), piece_result.shape) + idx = np.argsort(piece_result.ravel())[:2] + x, y = np.unravel_index(idx[0], piece_result.shape) + xx, yy = np.unravel_index(idx[1], piece_result.shape) + score = piece_result[x][y] / piece_result[xx][yy] + return (x,y,score) def get_score(piece1, piece2, edge): piece1 = matplotlib.colors.rgb_to_hsv(piece1 / float(256)) @@ -142,134 +141,75 @@ def get_score(piece1, piece2, edge): return(mahalanobis_distp1p2 + mahalanobis_distp2p1) -def print_results(results): - piece_num = 0 - results_matrix = np.full((len(results), 16), -1, dtype='float') - for x, result in enumerate(results): - edge_num = x % 4 - if x % 4 == 0 and x != 0: - piece_num += 1 - print('Piece number ' + str(piece_num) + ' edge number ' + str(edge_num)) - print(np.array(result)) - for x in range(0, len(results)): - for y in range(0, len(results[x])): - edge = results[x][y][1] - score = results[x][y][0] - results_matrix[x,edge] = score - print(results_matrix) - - +""" +Determine if two piece are best matches to one another (Best buddy) +@return True if best buddy else False +""" +def is_best_match(x,y,xx,yy, results, edge): + x_match,y_match = get_best_match(x,y,results,edge)[:2] + return (x_match,y_match) == (xx,yy) +def print_graph(graphs): + for graph in graphs: + pos = nx.spring_layout(graph) + nx.draw_networkx_nodes(graph, pos, node_size=600) + nx.draw_networkx_edges(graph, pos, edge_color='black', width=3) + labels = nx.get_edge_attributes(graph, 'weight') + nx.draw_networkx_labels(graph, pos, font_size=20, font_family='sans-serif') + nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels) + plt.axis('off') + print(graph.adj.items()) + plt.show() -def reconstruct(results, np_pieces): - np_pieces = np.array(np_pieces) - total = np_pieces.shape[0] - h = np_pieces[0].shape[0] - w = np_pieces[0].shape[1] - target_width = math.sqrt(total) * w - target_height = math.sqrt(total) * h - x = 0 - y = 0 - result = Image.new('RGB', (int(target_width), int(target_height))) - img = Image.fromarray(np_pieces[0]) - w, h = img.size - result.paste(img, (x, y)) - for x in range(0, len(results)): - score, piece = results[x][0] - position = piece % 4 - piece = int(piece / 4) - print(str(piece) + ' ' + str(piece)) - target_image = np_pieces[piece] - target_image = Image.fromarray(target_image) - if position == 3 and piece == 2: - result.paste(target_image, (w, 0)) - if position == 0 and piece == 1: - result.paste(target_image, (0, h)) - if position == 0and piece == 3: - result.paste(target_image, (w, h)) - result.show() -# Remove any result < 10^-6 and equal to nan -# Return the smallest result divided by the second smallest in order to discern good results from inconclusive ones -def result_cleaner(results): - r_shape = results.shape - d = defaultdict(lambda: defaultdict(dict)) - for x in range(0, len(results)): - for y in range(0, len(results[x])): +def reconstruct(graphs, results): + res_shape = results.shape + h = res_shape[0] + w = res_shape[1] + solution = [[['#' for _ in range(4)] for _ in range(res_shape[1])] for _ in range(res_shape[0])] + for graph in graphs: + old_w = 1000000 + for (u, v, wt) in graph.edges.data('weight'): + edge = graph.get_edge_data(u, v)['edge'] + print('edge ',edge) + if is_best_match(u[0],u[1],v[0],v[1],results,edge): + solution[u[0]][u[1]][edge] = v + solution[v[0]][v[1]][edge] = u + return solution + +def display_results(results,images): + img_final = np.zeroes((1000,668,3)) + for x in range(results): + for y in range(results): for z in range(4): - s = results[x][y][z].shape - res = np.argsort(np.ravel(results[x][y][z]))[:2] - res_form = np.unravel_index(res, s) - lX = res_form[0][0] - lY = res_form[0][1] - lXX = res_form[1][0] - lYY = res_form[1][1] - d[x][y][z] = (results[x][y][z][lX][lY]/results[x][y][z][lXX][lYY], lX, lY) - return d - -def print_graph(graph): - pos = nx.spring_layout(graph) - nx.draw_networkx_nodes(graph, pos, node_size=300) - nx.draw_networkx_edges(graph, pos, edge_color='black', width=3) - labels = nx.get_edge_attributes(graph, 'weight') - print(labels) - nx.draw_networkx_labels(graph, pos, font_size=20, font_family='sans-serif') - nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels) - plt.axis('off') - plt.show() - - - - + if results[x][y][z] != 0: + target_x = results[x][y][z][0] + target_Y = results[x][y][z][1] + if z == 0: + insert_image(img_final, images[x][y], edge) + + if z == 1: + insert_image + if z == 2: + insert_image + if z ==3: + insert_image + results[target_x][target_Y][z] = 0 def graph_create(results): + graphs = [] graph = nx.Graph() for x in range(0, len(results)): for y in range(0, len(results[x])): + graph = nx.Graph() for z in range(4): - graph.add_edge((x,y), (results[x][y][z][1:]), weight=results[x][y][z][0]) - print(nx.shortest_path(graph, source=(0,0), target=(1,1))) - return graph - + match_x, match_y, match_score = get_best_match(x,y,results,z) + # Check if edge already exists, if it does only replace current edge if match_score is smaller than existing score (this can happen at edges) + if graph.get_edge_data((x,y),(match_x,match_y)) is not None and graph.get_edge_data((x,y),(match_x,match_y))['weight'] < match_score: + print('Edge already exists and has a lower score for piece', x, y) + else: + graph.add_edge((x,y), (match_x,match_y), weight=round(match_score,5), edge=z) + graphs.append(graph) + return graphs -def graph_init(results): - graphs = [] - nodes = [x for x in range(0,len(results))] - count = 0 - piece_num = 0 - graph = nx.Graph() - start = 0 - end = 10 - for y in range(0, len(results)): - graph = nx.Graph() - start = y*4 - end = start + 4 - for x in range(start, end): - print(x) - score, position = results[x] - piece_target = position // 4 - edge_target = position % 4 - piece_num = y - edge_num = x - print('Piece number ' + str(piece_num) + ' ' + ' to piece num ' + str(piece_target)) - graph.add_edge(piece_num, piece_target, weight=score, object={edge_num:edge_target}) - graphs.append(graph) - for graph in graphs: - pos = nx.spring_layout(graph) - #e1 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] == 0] - #e2 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] == 1] - #e3 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] ==2] - #e4 = [(u, v) for (u, v, d) in graph.edges(data=True) if d['object'][0] == 3] - nx.draw_networkx_nodes(graph, pos, node_size=300) - nx.draw_networkx_edges(graph, pos,edge_color='b',width=6) - #nx.draw_networkx_edges(graph, pos,edgelist=e1, edge_color='b',width=6) - #nx.draw_networkx_edges(graph, pos, edgelist=e2,edge_color='r',width=6) - #nx.draw_networkx_edges(graph, pos, edgelist=e3,edge_color='g',width=6) - #nx.draw_networkx_edges(graph, pos, edgelist=e4,edge_color='y',width=6) - labels = nx.get_edge_attributes(graph, 'weight') - labels2 = nx.get_edge_attributes(graph, 'object') - nx.draw_networkx_labels(graph, pos, font_size=20, font_family='sans-serif') - nx.draw_networkx_edge_labels(graph, pos, edge_labels=labels) - plt.axis('off') - plt.show()