Skip to content

Commit

Permalink
dash sounds
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Dec 8, 2024
1 parent d953f6f commit 54d3c82
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 16 deletions.
File renamed without changes.
File renamed without changes.
153 changes: 137 additions & 16 deletions extra/dashboard/dashboard.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import dash
from dash import html, dcc
from dash import html, dcc, ctx
import plotly.graph_objs as go
from dash.dependencies import Input, Output
from dash.dependencies import Input, Output, State
import boto3
import json
from collections import defaultdict
import os
import base64
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px

s3 = boto3.client('s3')
BUCKET_NAME = 'exo-benchmarks'
Expand Down Expand Up @@ -51,21 +55,55 @@ def load_data_from_s3():

app.layout = html.Div([
html.H1('Benchmark Performance Dashboard'),
html.Button('Test Sound', id='test-sound-button', n_clicks=0),
html.Div(id='graphs-container'),
html.Audio(id='success-sound', src='assets/pokemon_evolve.mp3', preload="auto", style={'display': 'none'}),
html.Audio(id='failure-sound', src='assets/gta5_wasted.mp3', preload="auto", style={'display': 'none'}),
html.Audio(id='startup-sound', src='assets/pokemon_evolve.mp3', preload="auto", style={'display': 'none'}),
html.Div(id='audio-trigger', style={'display': 'none'}),
dcc.Store(id='previous-data', storage_type='memory'),
dcc.Interval(
id='interval-component',
interval=300000, # Update every 5 minutes
interval=10000, # Update every 10 seconds
n_intervals=0
)
])

@app.callback(
Output('graphs-container', 'children'),
Input('interval-component', 'n_intervals')
[Output('graphs-container', 'children'),
Output('previous-data', 'data'),
Output('audio-trigger', 'children')],
[Input('interval-component', 'n_intervals')],
[State('previous-data', 'data')]
)
def update_graphs(n):
def update_graphs(n, previous_data):
config_data = load_data_from_s3()
graphs = []
trigger_sound = None

if previous_data:
for config_name, data in config_data.items():
if config_name in previous_data and data and previous_data[config_name]:
current_prompt_tps = data[-1]['prompt_tps']
previous_prompt_tps = previous_data[config_name][-1]['prompt_tps']

# Add clear logging for TPS changes
if current_prompt_tps != previous_prompt_tps:
print("\n" + "="*50)
print(f"Config: {config_name}")
print(f"Previous TPS: {previous_prompt_tps}")
print(f"Current TPS: {current_prompt_tps}")
print(f"Change: {current_prompt_tps - previous_prompt_tps}")

if current_prompt_tps > previous_prompt_tps:
print("🔼 TPS INCREASED - Should play success sound")
trigger_sound = 'success'
elif current_prompt_tps < previous_prompt_tps:
print("🔽 TPS DECREASED - Should play failure sound")
trigger_sound = 'failure'

if current_prompt_tps != previous_prompt_tps:
print("="*50 + "\n")

for config_name, data in config_data.items():
timestamps = [d['timestamp'] for d in data]
Expand All @@ -74,8 +112,12 @@ def update_graphs(n):
commits = [d['commit'] for d in data]
run_ids = [d['run_id'] for d in data]

fig = go.Figure()
# Create subplot with 2 columns
fig = make_subplots(rows=1, cols=2,
subplot_titles=('Performance Over Time', 'Generation TPS Distribution'),
column_widths=[0.7, 0.3])

# Time series plot (left)
fig.add_trace(go.Scatter(
x=timestamps,
y=prompt_tps,
Expand All @@ -84,7 +126,7 @@ def update_graphs(n):
hovertemplate='Commit: %{text}<br>TPS: %{y}<extra></extra>',
text=commits,
customdata=run_ids
))
), row=1, col=1)

fig.add_trace(go.Scatter(
x=timestamps,
Expand All @@ -94,16 +136,55 @@ def update_graphs(n):
hovertemplate='Commit: %{text}<br>TPS: %{y}<extra></extra>',
text=commits,
customdata=run_ids
))
), row=1, col=1)

# Calculate statistics
gen_tps_array = np.array(generation_tps)
stats = {
'Mean': np.mean(gen_tps_array),
'Std Dev': np.std(gen_tps_array),
'Min': np.min(gen_tps_array),
'Max': np.max(gen_tps_array)
}

# Histogram plot (right)
fig.add_trace(go.Histogram(
x=generation_tps,
name='Generation TPS Distribution',
nbinsx=10,
showlegend=False
), row=1, col=2)

# Add statistics as annotations
stats_text = '<br>'.join([f'{k}: {v:.2f}' for k, v in stats.items()])
fig.add_annotation(
x=0.98,
y=0.98,
xref='paper',
yref='paper',
text=stats_text,
showarrow=False,
font=dict(size=12),
align='left',
bgcolor='rgba(255, 255, 255, 0.8)',
bordercolor='black',
borderwidth=1
)

fig.update_layout(
title=f'Performance Metrics - {config_name}',
xaxis_title='Timestamp',
yaxis_title='Tokens per Second',
height=500,
showlegend=True,
hovermode='x unified',
clickmode='event'
)

# Update x and y axis labels
fig.update_xaxes(title_text='Timestamp', row=1, col=1)
fig.update_xaxes(title_text='Generation TPS', row=1, col=2)
fig.update_yaxes(title_text='Tokens per Second', row=1, col=1)
fig.update_yaxes(title_text='Count', row=1, col=2)

graphs.append(html.Div([
dcc.Graph(
figure=fig,
Expand All @@ -112,19 +193,59 @@ def update_graphs(n):
)
]))

return graphs
return graphs, config_data, trigger_sound

@app.callback(
Output('_', 'children'),
Input({'type': 'dynamic-graph', 'index': dash.ALL}, 'clickData')
Output('graphs-container', 'children', allow_duplicate=True),
Input({'type': 'dynamic-graph', 'index': dash.ALL}, 'clickData'),
prevent_initial_call=True
)
def handle_click(clickData):
if clickData and clickData['points'][0].get('customdata'):
run_id = clickData['points'][0]['customdata']
if clickData and clickData[0] and clickData[0]['points'][0].get('customdata'):
run_id = clickData[0]['points'][0]['customdata']
url = f'https://github.com/exo-explore/exo/actions/runs/{run_id}'
import webbrowser
webbrowser.open_new_tab(url)
return dash.no_update

app.clientside_callback(
"""
function(trigger, test_clicks) {
if (!trigger && !test_clicks) return window.dash_clientside.no_update;
if (test_clicks > 0 && dash_clientside.callback_context.triggered[0].prop_id.includes('test-sound-button')) {
console.log('Test button clicked');
const audio = document.getElementById('startup-sound');
if (audio) {
audio.currentTime = 0;
audio.play().catch(e => console.log('Error playing audio:', e));
}
} else if (trigger) {
console.log('Audio trigger received:', trigger);
if (trigger === 'success') {
console.log('Playing success sound');
const audio = document.getElementById('success-sound');
if (audio) {
audio.currentTime = 0;
audio.play().catch(e => console.log('Error playing success sound:', e));
}
} else if (trigger === 'failure') {
console.log('Playing failure sound');
const audio = document.getElementById('failure-sound');
if (audio) {
audio.currentTime = 0;
audio.play().catch(e => console.log('Error playing failure sound:', e));
}
}
}
return window.dash_clientside.no_update;
}
""",
Output('audio-trigger', 'children', allow_duplicate=True),
[Input('audio-trigger', 'children'),
Input('test-sound-button', 'n_clicks')],
prevent_initial_call=True
)

if __name__ == '__main__':
app.run_server(debug=True)
2 changes: 2 additions & 0 deletions extra/dashboard/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
boto3==1.35.76
dash==2.18.2
numpy
pandas

0 comments on commit 54d3c82

Please sign in to comment.