Skip to content

Commit

Permalink
Update card
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliiznak committed Sep 17, 2024
1 parent 7461f6b commit 4524da7
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 34 deletions.
3 changes: 0 additions & 3 deletions src/LearningProcessVisualizer/LearningProcessVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ const LearningProcessVisualizer: Component = () => {
<p>Learning Rate: {data.learningRate}</p> */}
</div>
);
case 'iteration':
console.log("Rendering iteration step");
return {/* <div>Iteration {data.iteration} completed, Loss: {data.loss?.toFixed(4)}</div>; */}
default:
console.log("Unknown step:", currentPhase);
return null;
Expand Down
35 changes: 16 additions & 19 deletions src/TrainingControl/TrainingControls.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Component, createEffect, createSignal, Show } from "solid-js";
import { css, keyframes } from "@emotion/css";
import { actions, store } from '../store';
import { actions, setStore, store } from '../store';
import TrainingStepsVisualizer from './TrainingStepsVisualizer';
import TrainingStatus from "./TrainingStatus";
import { colors } from '../styles/colors';
Expand Down Expand Up @@ -29,7 +29,7 @@ const styles = {
title: css`
font-size: ${typography.fontSize.xl};
font-weight: ${typography.fontWeight.bold};
margin-bottom: 1rem;
margin-bottom: 0.5rem;
color: ${colors.text};
`,
controlsContainer: css`
Expand Down Expand Up @@ -124,7 +124,6 @@ const styles = {
box-shadow: none;
}
`,
// New iteration indicator style
iterationIndicator: css`
text-align: center;
margin: 1rem 0;
Expand All @@ -133,7 +132,6 @@ const styles = {
color: ${colors.primary};
animation: ${fadeIn} 0.5s ease-in-out;
`,
// Enhanced notification style
notification: css`
position: fixed;
top: 20%;
Expand All @@ -152,18 +150,10 @@ const styles = {
const TrainingControls: Component = () => {
const [zoomRange, setZoomRange] = createSignal<[number, number]>([0, 100]);
const [chartType, setChartType] = createSignal<'bar' | 'line'>('line');
const [isLossCalculated, setIsLossCalculated] = createSignal(false);
const [showNotification, setShowNotification] = createSignal(false);
const [currentIteration, setCurrentIteration] = createSignal(1);

createEffect(() => {
const { currentPhase } = store.trainingState;
if (currentPhase === 'forward' || currentPhase === 'backward') {
setIsLossCalculated(false);
} else if (currentPhase === 'loss') {
setIsLossCalculated(true);
}
});


createEffect(() => {
if (
Expand All @@ -176,6 +166,8 @@ const TrainingControls: Component = () => {
setCurrentIteration(prev => prev + 1);
// Reset for the next iteration
actions.trainingStateReset();
// Set the current phase to 'idle' to enable the Forward button
setStore('trainingState', 'currentPhase', 'idle');
}
});

Expand All @@ -185,22 +177,22 @@ const TrainingControls: Component = () => {
return iteration / totalIterations;
};

const getLossColor = (loss: number) => {
const getLossColor = (loss: number | null) => {
if (loss === null) return colors.error;
if (loss < 0.2) return colors.success;
if (loss < 0.5) return colors.warning;
return colors.error;
};

const isForwardDisabled = () =>
(store.trainingState.currentPhase !== 'idle' && store.trainingState.currentPhase !== 'update') ||
store.trainingState.backwardStepGradients.length > 0;
store.trainingState.currentPhase !== 'idle';

const isLossDisabled = () =>
store.trainingState.forwardStepResults.length === 0 ||
store.trainingState.backwardStepGradients.length > 0;
store.trainingState.currentPhase !== 'forward';

const isBackwardDisabled = () => store.trainingState.currentPhase !== 'loss';
const isUpdateWeightsDisabled = () => store.trainingState.backwardStepGradients.length === 0;
const isUpdateWeightsDisabled = () => store.trainingState.currentPhase !== 'backward';
const isResetDisabled = () => store.trainingState.forwardStepResults.length === 0;

const singleStepForward = () => {
Expand All @@ -221,8 +213,13 @@ const TrainingControls: Component = () => {
const stepBackward = () => {
if (!isBackwardDisabled()) actions.stepBackward();
};

const updateWeights = () => {
if (!isUpdateWeightsDisabled()) actions.updateWeights();
if (!isUpdateWeightsDisabled()) {
actions.updateWeights();
// Set the current phase to 'update' to trigger the effect
setStore('trainingState', 'currentPhase', 'update');
}
};

const handleWheel = (e: WheelEvent) => {
Expand Down
12 changes: 8 additions & 4 deletions src/TrainingControl/WeightUpdateStep.tsx
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import { Component, For } from "solid-js";
import { css } from "@emotion/css";
import { FaSolidArrowRight, FaSolidArrowDown } from 'solid-icons/fa';
import { colors } from '../styles/colors';

const WeightUpdateStep: Component<{ oldWeights: number[], newWeights: number[] }> = (props) => {
const styles = {
container: css`
display: flex;
flex-direction: column;
align-items: center;
background-color: #e6f7ff;
border: 1px solid #91d5ff;
background-color: ${colors.surface};
border: 1px solid ${colors.border};
border-radius: 0.25rem;
padding: 0.5rem;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
Expand All @@ -21,17 +22,19 @@ const WeightUpdateStep: Component<{ oldWeights: number[], newWeights: number[] }
`,
stepIcon: css`
font-size: 1rem;
color: #1890ff;
color: ${colors.primary};
`,
stepLabel: css`
font-size: 0.75rem;
font-weight: bold;
margin-top: 0.25rem;
color: ${colors.text};
`,
weightList: css`
width: 100%;
margin-top: 0.25rem;
font-size: 0.625rem;
color: ${colors.textLight};
`,
weightItem: css`
display: flex;
Expand All @@ -41,9 +44,10 @@ const WeightUpdateStep: Component<{ oldWeights: number[], newWeights: number[] }
`,
weightValue: css`
font-family: monospace;
color: ${colors.text};
`,
arrow: css`
color: #52c41a;
color: ${colors.secondary};
margin: 0 0.125rem;
`,
};
Expand Down
21 changes: 14 additions & 7 deletions src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const initialState: AppState = {
currentLoss: null,
forwardStepResults: [],
backwardStepGradients: [],
weightUpdateResults: [],
lossHistory: [],
},

Expand All @@ -47,6 +48,8 @@ const initialState: AppState = {
layerOutputs: []
},

trainingRuns: [],

};

export const [store, setStore] = createStore(initialState);
Expand Down Expand Up @@ -98,6 +101,7 @@ function trainingStateReset() {
backwardStepGradients: [],
lossHistory: [],
currentLoss: null,
weightUpdateResults: [],
currentPhase: 'idle'
});
});
Expand Down Expand Up @@ -128,6 +132,7 @@ function singleStepForward() {

console.log("Forward step completed. Result:", result);
batch(() => {
setStore('trainingState', 'currentPhase', 'forward');
setStore('trainingState', 'forwardStepResults', [...store.trainingState.forwardStepResults, { input: result.input, output: result.output }]);
setStore('simulationResult', { input: result.input, output: result.output, layerOutputs: layerOutputs });
});
Expand Down Expand Up @@ -188,14 +193,13 @@ function stepBackward() {

if (result && Array.isArray(result)) {
console.log("Updating store with result");
try {
console.log('here stepBackward', result)


batch(() => {
setStore('trainingState', 'currentPhase', 'backward');
setStore('trainingState', 'backwardStepGradients', result);
console.log("Store updated successfully");
} catch (error) {
console.error("Error updating store:", error);
}
});

} else {
console.log("No valid result from stepBackward");
}
Expand All @@ -211,8 +215,11 @@ function updateWeights() {
const result = store.trainer.updateWeights(store.trainingConfig.learningRate);

setStore('trainingStepResult', result);
setStore('trainingState', 'weightUpdateResults', result); // Add this line
setStore('network', store.trainer.network);
setStore('trainingState', 'currentPhase', 'update');
setStore('trainingState', 'currentPhase', 'idle');

console.log("Weights updated successfully");
});
}

Expand Down
2 changes: 2 additions & 0 deletions src/styles/colors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ export const colors = {

// Add textDark if needed
textDark: 'var(--text-light)', // Example value

success: '#28a745', // Add this line
};
5 changes: 4 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ export type AppState = {

// Training state
trainingState: {
currentPhase: 'idle',
currentPhase: 'idle'| 'forward'| 'loss' | 'backward' | 'update',
iteration: 0,
currentLoss: null,
forwardStepResults: [],
backwardStepGradients: [],
weightUpdateResults: [],
lossHistory: [],
};

Expand All @@ -38,6 +39,8 @@ export type AppState = {
trainingRuns: TrainingRun[]; // Add this line
};

type TrainingRun = any

export type BackwardStepGradients = {
neuron: number;
weights: number;
Expand Down

0 comments on commit 4524da7

Please sign in to comment.