Skip to content

Commit

Permalink
update embedding GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
intern-william@nodeflux-67 committed Jun 27, 2022
1 parent 0101954 commit 32c4916
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 29 deletions.
3 changes: 2 additions & 1 deletion app_dash.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dash
import dash_bootstrap_components as dbc
import config

app = dash.Dash(__name__, use_pages=True)
app = dash.Dash(__name__, use_pages=True, external_stylesheets=[dbc.themes.BOOTSTRAP])
app.layout = dash.html.Div([
dash.page_container
])
Expand Down
93 changes: 65 additions & 28 deletions pages/embedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dash
import dash_bootstrap_components as dbc
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import math
import requests
import os
import config
Expand All @@ -11,9 +13,7 @@ def create_figure(df, min_uniqueness, max_uniqueness, min_sqrt_area, max_sqrt_ar
fig = px.scatter(df[mask], x='embeddings_x', y='embeddings_y', color='label', size='sqrt_area', custom_data=['id'], hover_data=['uniqueness'])
figure = go.FigureWidget(fig)
figure.update_layout(
autosize=True,
width=1080,
height=566,
dragmode='lasso'
)
return figure

Expand All @@ -27,33 +27,70 @@ def main(name):
global df
df = pd.read_pickle(file_path)
uniqueness_range = (0, 1)
sqrt_area_range = (0, df['sqrt_area'].max())
sqrt_area_range = (0, math.ceil(df['sqrt_area'].max()))

figure = create_figure(df, uniqueness_range[0], uniqueness_range[1], sqrt_area_range[0], sqrt_area_range[1])

return dash.html.Div(children=[
dash.html.H3('Embedding Visualization'),
dash.dcc.Graph(
id='graph',
figure=figure
),
dash.html.P(f'{name}', id='name'),
dash.html.P(id='num_sample'),
dash.html.P('Filter by uniqueness:'),
dash.dcc.RangeSlider(
id='uniqueness-slider',
min=0, max=1, step=0.001,
marks={0: '0', 1: '1'},
value=[0, 1]
),
dash.html.P('Filter by sqrt area:'),
dash.dcc.RangeSlider(
id='sqrt-area-slider',
min=0, max=df['sqrt_area'].max(), step=df['sqrt_area'].max()/1000,
marks={0: '0', df['sqrt_area'].max(): f'{df["sqrt_area"].max()}'},
value=[0, df['sqrt_area'].max()]
),
])
return dbc.Container(
children = [
dash.html.H1('Embedding Visualization', style = {'text-align':'center'}),
dbc.Container(
children = [
dash.html.P(
'''
Each node in this graph represents a bounding box in the dataset.
Use 'box select' or 'lasso select' to select nodes. Double click on a label to excluded the other label.
The slider will filter nodes that do not meet the criteria. Note that the slider don't update the selection.
Selected nodes on this graph can be previewed on Fiftyone through the button on the bottom of this page.
'''
),
],
className = 'm-3'
),
dash.html.Div(
dash.dcc.Graph(
id='graph',
figure=figure,
responsive=True,
style={'width' : '100%', 'height': '100%'}
),
style={'width' : '100%', 'height': '60%'}
),
dbc.Container(
children = [
dash.html.P(f'{name}', id='name'),
dash.html.P(id='num_sample'),
],
className = 'm-3'
),
dbc.Container(
children = [
dash.html.B('Filter by Uniqueness'),
dash.dcc.RangeSlider(
id='uniqueness-slider',
min=uniqueness_range[0], max=uniqueness_range[1], step=(uniqueness_range[1]-uniqueness_range[0])/1000,
marks={uniqueness_range[0]: uniqueness_range[0], uniqueness_range[1]: uniqueness_range[1]},
value=[uniqueness_range[0], uniqueness_range[1]]
)
],
className = 'm-3'
),
dbc.Container(
children = [
dash.html.B('Filter by Sqrt Area'),
dash.dcc.RangeSlider(
id='sqrt-area-slider',
min=sqrt_area_range[0], max=sqrt_area_range[1], step=(sqrt_area_range[1]-sqrt_area_range[0])/1000,
marks={sqrt_area_range[0]: sqrt_area_range[0], sqrt_area_range[1]: sqrt_area_range[1]},
value=[sqrt_area_range[0], sqrt_area_range[1]]
),
],
className = 'm-3'
),
dbc.Button('Open Fiftyone', color='light', className='m-3', href=f'{config.url}:{config.port["flask"]}/fiftyone/{name}', external_link=True, target='_blank'),
],
className = 'p-5 vh-100 vw-100'
)

@dash.callback(
dash.Output(component_id='graph', component_property='figure'),
Expand Down Expand Up @@ -88,7 +125,7 @@ def update(name, input_value):
}
)
return f'Number of patches: {len(ids)}'

dash.register_page(__name__, path_template="/embedding/<name>")

def layout(name=None):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Flask==2.1.2
dash==2.5.1
dash-bootstrap-components==1.1.0
plotly==5.9.0
ipywidgets==7.7.1
torch==1.11.0
Expand Down

0 comments on commit 32c4916

Please sign in to comment.