Skip to content

Commit

Permalink
Grouped classes and clusters in the figure's legend
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Troisemaine authored and Colin Troisemaine committed Jul 4, 2024
1 parent 82f4252 commit 1a832b4
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 30 deletions.
Binary file modified backend/__pycache__/server.cpython-310.pyc
Binary file not shown.
121 changes: 95 additions & 26 deletions backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from sklearn.tree import DecisionTreeClassifier
from models.PBNModel import PBNModel
from sklearn.manifold import TSNE
import plotly.graph_objects as go
from PyPDF2 import PdfMerger
import plotly.express as px
# import plotly.express as px
from flask_cors import CORS
from sklearn import tree
import pandas as pd
Expand All @@ -35,7 +36,7 @@
import gc
import re

os.environ["OMP_NUM_THREADS"] = '1' # For the k-means warning...
# os.environ["OMP_NUM_THREADS"] = '1' # For the k-means warning...

app = Flask(__name__)
logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -291,18 +292,51 @@ def getDatasetTSNE():

image_filename = image_datetime_string + '.png'

tsne_target_wrapped = wrap_list(tsne_target, separator='<br>')

session['last_sent_points'] = pd.DataFrame({'point_index_in_df': dataset.index[mask],
'point_class': tsne_target_wrapped})

fig = px.scatter(x=np.array(tsne_array[tsne_array.columns[0]]),
y=np.array(tsne_array[tsne_array.columns[1]]),
color=tsne_target_wrapped,
title="T-SNE of the original " + dataset_name + " dataset")
fig.update_layout(yaxis={'title': None},
tsne_target_wrapped = np.array(wrap_list(tsne_target, separator='<br>'))

session['last_sent_points'] = pd.DataFrame({'point_index_in_df': dataset.index[mask], 'point_class': tsne_target_wrapped})

# We will create two groups of items: known and unknown data, and plot the unknown data first
scatterplots_data = []

unknown_data_mask = tsne_target_wrapped == "Unknown"
scatterplots_data.append(
go.Scatter(x=np.array(tsne_array[tsne_array.columns[0]])[unknown_data_mask],
y=np.array(tsne_array[tsne_array.columns[1]])[unknown_data_mask],
mode='markers',
legendgroup="unknown",
legendgrouptitle_text="Unknown data",
name="Unknown",
visible=True)
)

known_data_mask = ~unknown_data_mask
known_tsne_array = tsne_array[known_data_mask]
known_tsne_target_wrapped = tsne_target_wrapped[known_data_mask]
for t in np.sort(np.unique(known_tsne_target_wrapped)): # And the items in this group are sorted alphabetically
scatterplots_data.append(
go.Scatter(x=np.array(known_tsne_array[known_tsne_array.columns[0]])[known_tsne_target_wrapped == t],
y=np.array(known_tsne_array[known_tsne_array.columns[1]])[known_tsne_target_wrapped == t],
mode='markers',
legendgroup="classes",
legendgrouptitle_text="Known classes",
name=t,
visible=True)
)

fig = go.Figure(data=scatterplots_data)
# fig.update_layout(legend=dict(groupclick="toggleitem")) # Toggle the visibility of just the item clicked on by the user, not the whole group

# Simple un-ordered alternative:
# fig = px.scatter(x=np.array(tsne_array[tsne_array.columns[0]]),
# y=np.array(tsne_array[tsne_array.columns[1]]),
# color=tsne_target_wrapped)

fig.update_layout(title="T-SNE of the " + dataset_name + " dataset",
yaxis={'title': None},
xaxis={'title': None},
margin=dict(l=0, r=0, t=40, b=0))

graphJSON = plotly.io.to_json(fig, pretty=True)
# fig.savefig(os.path.join(image_folder_path, image_filename), dpi=fig.dpi, bbox_inches='tight')

Expand Down Expand Up @@ -408,7 +442,7 @@ def runClustering():
clustering_prediction = kmeans_model.fit_predict(filtered_dataset[unknown_mask])

full_target = np.array(["Class " + str(t) for t in dataset[target_name]])
full_target[unknown_mask] = np.array(["Clust " + str(pred) for pred in clustering_prediction])
full_target[unknown_mask] = np.array(["Cluster " + str(pred) for pred in clustering_prediction])

saveClusteringResultsInSession(clustering_prediction, target_name, dataset, known_classes, unknown_classes,
selected_features)
Expand All @@ -433,7 +467,7 @@ def runClustering():
clustering_prediction = clustering_prediction.labels_

full_target = np.array(["Class " + str(t) for t in dataset[target_name]])
full_target[unknown_mask] = ["Clust " + str(pred) for pred in clustering_prediction]
full_target[unknown_mask] = ["Cluster " + str(pred) for pred in clustering_prediction]

saveClusteringResultsInSession(clustering_prediction, target_name, dataset, known_classes, unknown_classes,
selected_features)
Expand Down Expand Up @@ -613,19 +647,54 @@ def generateClusteringImage(dataset_name, model_name, show_unknown_only, full_ta

target_to_plot = full_target[mask]

tsne_target_wrapped = wrap_list(target_to_plot, separator='<br>')

session['last_sent_points'] = pd.DataFrame(
{'point_index_in_df': session['loaded_datasets'].get(dataset_name).index[mask],
'point_class': tsne_target_wrapped})

fig = px.scatter(x=np.array(tsne_array[tsne_array.columns[0]]),
y=np.array(tsne_array[tsne_array.columns[1]]),
color=tsne_target_wrapped,
title="T-SNE of the original " + dataset_name + " dataset colored by " + model_name)
fig.update_layout(yaxis={'title': None},
tsne_target_wrapped = np.array(wrap_list(target_to_plot, separator='<br>'))

session['last_sent_points'] = pd.DataFrame({'point_index_in_df': session['loaded_datasets'].get(dataset_name).index[mask], 'point_class': tsne_target_wrapped})

# We will create two groups of items: clusters and classes, and plot the clusters first
scatterplots_data = []

clusters_data_mask = np.array([t.startswith("Cluster ") for t in tsne_target_wrapped])
clusters_tsne_target_wrapped = tsne_target_wrapped[clusters_data_mask]
clusters_tsne_array = tsne_array[clusters_data_mask]
for t in np.sort(np.unique(clusters_tsne_target_wrapped)): # And the items in the group are sorted alphabetically
scatterplots_data.append(
go.Scatter(x=np.array(clusters_tsne_array[clusters_tsne_array.columns[0]])[clusters_tsne_target_wrapped == t],
y=np.array(clusters_tsne_array[clusters_tsne_array.columns[1]])[clusters_tsne_target_wrapped == t],
mode='markers',
legendgroup="clusters",
legendgrouptitle_text="Generated clusters",
name=t,
visible=True)
)

classes_data_mask = ~clusters_data_mask
classes_tsne_target_wrapped = tsne_target_wrapped[classes_data_mask]
classes_tsne_array = tsne_array[classes_data_mask]
for t in np.sort(np.unique(classes_tsne_target_wrapped)):
scatterplots_data.append(
go.Scatter(x=np.array(classes_tsne_array[classes_tsne_array.columns[0]])[classes_tsne_target_wrapped == t],
y=np.array(classes_tsne_array[classes_tsne_array.columns[1]])[classes_tsne_target_wrapped == t],
mode='markers',
legendgroup="classes",
legendgrouptitle_text="Known classes",
name=t,
visible=True)
)

fig = go.Figure(data=scatterplots_data)
# fig.update_layout(legend=dict(groupclick="toggleitem")) # Toggle the visibility of just the item clicked on by the user, not the whole group

# Simple un-ordered alternative:
# fig = px.scatter(x=np.array(tsne_array[tsne_array.columns[0]]),
# y=np.array(tsne_array[tsne_array.columns[1]]),
# color=tsne_target_wrapped)

fig.update_layout(title="T-SNE of the " + dataset_name + " dataset colored by " + model_name,
yaxis={'title': None},
xaxis={'title': None},
margin=dict(l=0, r=0, t=40, b=0))

graphJSON = plotly.io.to_json(fig, pretty=True)
# fig.savefig(os.path.join(image_folder_path, image_filename), dpi=fig.dpi, bbox_inches='tight')

Expand Down Expand Up @@ -898,7 +967,7 @@ def getThreadResults():
clustering_prediction = model.predict_new_data(np.array(dataset[selected_features])[unknown_mask])

full_target = np.array(["Class " + str(t) for t in dataset[target_name]])
full_target[unknown_mask] = np.array(["Clust " + str(pred) for pred in clustering_prediction])
full_target[unknown_mask] = np.array(["Cluster " + str(pred) for pred in clustering_prediction])

saveClusteringResultsInSession(clustering_prediction, target_name, dataset, known_classes, unknown_classes,
selected_features)
Expand Down
73 changes: 71 additions & 2 deletions frontend/src/components/DataVisualization.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ import { faRotate } from "@fortawesome/free-solid-svg-icons";


class DataVisualization extends React.Component {

// constructor(props) {
// super(props);
//
// this.state = {
// lastLegendClickTime: 0
// }
// }

render() {
return (
<Container style={{height:"100%"}}>
Expand Down Expand Up @@ -57,8 +66,10 @@ class DataVisualization extends React.Component {
layout={this.props.image_to_display === null ? {} : this.props.image_to_display.layout}
style={{height: "97%", width: "100%", objectFit: "contain"}}
alt="T-SNE of the data"
onClick={(e) => this.props.handlePointClick(e)}
onLegendClick={(e) => this.handleLegendClick(e)}
onClick={(e) => {this.props.handlePointClick(e)}}
useResizeHandler={true}
config={{responsive: true}}
/>
</div>
</center>
Expand Down Expand Up @@ -95,6 +106,64 @@ class DataVisualization extends React.Component {
</Container>
)
}

handleLegendClick(event) {
// const currentTime = new Date().getTime();
// const timeDiff = currentTime - this.state.lastLegendClickTime;
// this.setState({lastLegendClickTime: currentTime});
//
// if (timeDiff < 300) {
// // Double click behavior
// console.log('Double click on legend:', event);
// } else {
// // Single click behavior
// console.log('Single click on legend:', event);
// }

const group_title = event.node.__data__[0].groupTitle

// If the clicked element has a group_title, it's a legend group title
if (group_title) {
const clicked_legend_group_title = group_title.text

const some_traces_are_visible = this.props.image_to_display.data.some(trace =>
(clicked_legend_group_title === "Generated clusters" && trace.legendgroup === "clusters" && trace.visible === true)
|| (clicked_legend_group_title === "Unknown data" && trace.legendgroup === "unknown" && trace.visible === true)
|| (clicked_legend_group_title === "Known classes" && trace.legendgroup === "classes" && trace.visible === true)
)

const updatedData = this.props.image_to_display.data.map(trace => {
if ((clicked_legend_group_title === "Generated clusters" && trace.legendgroup === "clusters")
|| (clicked_legend_group_title === "Unknown data" && trace.legendgroup === "unknown")
|| (clicked_legend_group_title === "Known classes" && trace.legendgroup === "classes")) {
// If some traces in this group are visible, set them all to 'legendonly'
// Otherwise, set all their visibility to true
return { ...trace, visible: some_traces_are_visible ? 'legendonly' : true }
} else {
return trace
}
})

// Update the plot with the new data
this.props.updateImageToDisplayData(updatedData)
}
// Otherwise, it's an individual element of the legend
else {
const traceIndex = event.curveNumber; // This is the index of the clicked element in the legend
const updatedData = this.props.image_to_display.data.map((trace, i) => {
if (i === traceIndex) {
return {...trace, visible: trace.visible === true ? 'legendonly' : true}
} else {
return trace
}
})

// Update the plot with the new data
this.props.updateImageToDisplayData(updatedData)
}

return false; // Prevent default legend item toggle behavior
}
}

export default DataVisualization;
export default DataVisualization;
15 changes: 13 additions & 2 deletions frontend/src/components/FullPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ class FullPage extends React.Component {
if (serverPromise.status === 200) {
serverPromise.json().then(response => {
this.setState({selected_class_feature: feature_name})
const new_formatted_class_values = response['unique_values'].map((feature, index) => ({"name": feature, "checked": true, "index": index, "used": true}))
const sorted_unique_values = response['unique_values'].sort() // We want the class modalities to be sorted alphabetically
const new_formatted_class_values = sorted_unique_values.map((feature, index) => ({"name": feature, "checked": true, "index": index, "used": true}))
this.setState({class_values_to_display: new_formatted_class_values})
this.setState({search_filtered_unique_values_list: this.getUpdatedFilteredList(new_formatted_class_values, this.state.unique_values_search_query)})
this.setState({n_known_classes: this.getNumberOfCheckedValues(new_formatted_class_values)})
Expand Down Expand Up @@ -1111,7 +1112,7 @@ class FullPage extends React.Component {

onClearCacheButtonClick = () => {
Swal.fire({
title: 'Clear the server\' cached data',
title: 'Clear the server\'s cached data',
text: "Clear the computed t-SNEs and saved images from the server's files. The processing time of the next requests will increase.",
showDenyButton: true,
confirmButtonText: 'Clear',
Expand Down Expand Up @@ -1175,6 +1176,15 @@ class FullPage extends React.Component {
})
}

updateImageToDisplayData = updatedData => {
this.setState(prevState => ({
image_to_display: {
...prevState.image_to_display,
data: updatedData
}
}));
};

render() {
return (
<Row style={{height: '100%', width:"99%"}} className="d-flex flex-row justify-content-center align-items-center">
Expand Down Expand Up @@ -1221,6 +1231,7 @@ class FullPage extends React.Component {
<Col className="col-lg-6 col-12 d-flex flex-column justify-content-center" style={{height: "98%"}}>
<Row className="my_row mx-lg-0 mb-lg-0 py-2 d-flex flex-row" style={{flexGrow:'1', height:"100%"}}>
<DataVisualization image_to_display={this.state.image_to_display}
updateImageToDisplayData={this.updateImageToDisplayData}
onRawDataButtonClick={this.onRawDataButtonClick}

onShowUnknownOnlySwitchChange={this.onShowUnknownOnlySwitchChange}
Expand Down

0 comments on commit 1a832b4

Please sign in to comment.