From 32c4916e1c90a052a8db57bea78c13f61add6d1a Mon Sep 17 00:00:00 2001 From: "intern-william@nodeflux-67" Date: Mon, 27 Jun 2022 09:43:52 +0000 Subject: [PATCH] update embedding GUI --- app_dash.py | 3 +- pages/embedding.py | 93 ++++++++++++++++++++++++++++++++-------------- requirements.txt | 1 + 3 files changed, 68 insertions(+), 29 deletions(-) diff --git a/app_dash.py b/app_dash.py index 8a07d4c..7eef4a4 100644 --- a/app_dash.py +++ b/app_dash.py @@ -1,7 +1,8 @@ import dash +import dash_bootstrap_components as dbc import config -app = dash.Dash(__name__, use_pages=True) +app = dash.Dash(__name__, use_pages=True, external_stylesheets=[dbc.themes.BOOTSTRAP]) app.layout = dash.html.Div([ dash.page_container ]) diff --git a/pages/embedding.py b/pages/embedding.py index e21aae8..1bd8aca 100644 --- a/pages/embedding.py +++ b/pages/embedding.py @@ -1,7 +1,9 @@ import dash +import dash_bootstrap_components as dbc import pandas as pd import plotly.express as px import plotly.graph_objects as go +import math import requests import os import config @@ -11,9 +13,7 @@ def create_figure(df, min_uniqueness, max_uniqueness, min_sqrt_area, max_sqrt_ar fig = px.scatter(df[mask], x='embeddings_x', y='embeddings_y', color='label', size='sqrt_area', custom_data=['id'], hover_data=['uniqueness']) figure = go.FigureWidget(fig) figure.update_layout( - autosize=True, - width=1080, - height=566, + dragmode='lasso' ) return figure @@ -27,33 +27,70 @@ def main(name): global df df = pd.read_pickle(file_path) uniqueness_range = (0, 1) - sqrt_area_range = (0, df['sqrt_area'].max()) + sqrt_area_range = (0, math.ceil(df['sqrt_area'].max())) figure = create_figure(df, uniqueness_range[0], uniqueness_range[1], sqrt_area_range[0], sqrt_area_range[1]) - return dash.html.Div(children=[ - dash.html.H3('Embedding Visualization'), - dash.dcc.Graph( - id='graph', - figure=figure - ), - dash.html.P(f'{name}', id='name'), - dash.html.P(id='num_sample'), - dash.html.P('Filter by uniqueness:'), - dash.dcc.RangeSlider( - id='uniqueness-slider', - min=0, max=1, step=0.001, - marks={0: '0', 1: '1'}, - value=[0, 1] - ), - dash.html.P('Filter by sqrt area:'), - dash.dcc.RangeSlider( - id='sqrt-area-slider', - min=0, max=df['sqrt_area'].max(), step=df['sqrt_area'].max()/1000, - marks={0: '0', df['sqrt_area'].max(): f'{df["sqrt_area"].max()}'}, - value=[0, df['sqrt_area'].max()] - ), - ]) + return dbc.Container( + children = [ + dash.html.H1('Embedding Visualization', style = {'text-align':'center'}), + dbc.Container( + children = [ + dash.html.P( + ''' + Each node in this graph represents a bounding box in the dataset. + Use 'box select' or 'lasso select' to select nodes. Double click on a label to excluded the other label. + The slider will filter nodes that do not meet the criteria. Note that the slider don't update the selection. + Selected nodes on this graph can be previewed on Fiftyone through the button on the bottom of this page. + ''' + ), + ], + className = 'm-3' + ), + dash.html.Div( + dash.dcc.Graph( + id='graph', + figure=figure, + responsive=True, + style={'width' : '100%', 'height': '100%'} + ), + style={'width' : '100%', 'height': '60%'} + ), + dbc.Container( + children = [ + dash.html.P(f'{name}', id='name'), + dash.html.P(id='num_sample'), + ], + className = 'm-3' + ), + dbc.Container( + children = [ + dash.html.B('Filter by Uniqueness'), + dash.dcc.RangeSlider( + id='uniqueness-slider', + min=uniqueness_range[0], max=uniqueness_range[1], step=(uniqueness_range[1]-uniqueness_range[0])/1000, + marks={uniqueness_range[0]: uniqueness_range[0], uniqueness_range[1]: uniqueness_range[1]}, + value=[uniqueness_range[0], uniqueness_range[1]] + ) + ], + className = 'm-3' + ), + dbc.Container( + children = [ + dash.html.B('Filter by Sqrt Area'), + dash.dcc.RangeSlider( + id='sqrt-area-slider', + min=sqrt_area_range[0], max=sqrt_area_range[1], step=(sqrt_area_range[1]-sqrt_area_range[0])/1000, + marks={sqrt_area_range[0]: sqrt_area_range[0], sqrt_area_range[1]: sqrt_area_range[1]}, + value=[sqrt_area_range[0], sqrt_area_range[1]] + ), + ], + className = 'm-3' + ), + dbc.Button('Open Fiftyone', color='light', className='m-3', href=f'{config.url}:{config.port["flask"]}/fiftyone/{name}', external_link=True, target='_blank'), + ], + className = 'p-5 vh-100 vw-100' + ) @dash.callback( dash.Output(component_id='graph', component_property='figure'), @@ -88,7 +125,7 @@ def update(name, input_value): } ) return f'Number of patches: {len(ids)}' - + dash.register_page(__name__, path_template="/embedding/") def layout(name=None): diff --git a/requirements.txt b/requirements.txt index fbe2a4e..012ec22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ Flask==2.1.2 dash==2.5.1 +dash-bootstrap-components==1.1.0 plotly==5.9.0 ipywidgets==7.7.1 torch==1.11.0