diff --git a/python/circuitsvis/activations.py b/python/circuitsvis/activations.py index 1f634c2..8abd427 100644 --- a/python/circuitsvis/activations.py +++ b/python/circuitsvis/activations.py @@ -13,6 +13,9 @@ def text_neuron_activations( second_dimension_name: Optional[str] = "Neuron", first_dimension_labels: Optional[List[str]] = None, second_dimension_labels: Optional[List[str]] = None, + first_dimension_default: Optional[int] = 0, + second_dimension_default: Optional[int] = 0, + show_selectors: Optional[bool] = True, ) -> RenderedHTML: """Show activations (colored by intensity) for each token in a text or set of texts. @@ -54,4 +57,7 @@ def text_neuron_activations( secondDimensionName=second_dimension_name, firstDimensionLabels=first_dimension_labels, secondDimensionLabels=second_dimension_labels, + firstDimensionDefault=first_dimension_default, + secondDimensionDefault=second_dimension_default, + showSelectors=show_selectors, ) diff --git a/react/src/activations/TextNeuronActivations.tsx b/react/src/activations/TextNeuronActivations.tsx index 7540b0f..9e4f1b7 100644 --- a/react/src/activations/TextNeuronActivations.tsx +++ b/react/src/activations/TextNeuronActivations.tsx @@ -36,7 +36,10 @@ export function TextNeuronActivations({ firstDimensionName = "Layer", secondDimensionName = "Neuron", firstDimensionLabels, - secondDimensionLabels + secondDimensionLabels, + firstDimensionDefault = 0, + secondDimensionDefault = 0, + showSelectors = true }: TextNeuronActivationsProps) { // If there is only one sample (i.e. if tokens is an array of strings), cast tokens and activations to an array with // a single element @@ -68,8 +71,8 @@ export function TextNeuronActivations({ const [sampleNumbers, setSampleNumbers] = useState([ ...Array(samplesPerPage).keys() ]); - const [layerNumber, setLayerNumber] = useState(0); - const [neuronNumber, setNeuronNumber] = useState(0); + const [layerNumber, setLayerNumber] = useState(firstDimensionDefault); + const [neuronNumber, setNeuronNumber] = useState(secondDimensionDefault); useEffect(() => { // When the user changes the samplesPerPage, update the sampleNumbers @@ -96,77 +99,62 @@ export function TextNeuronActivations({ return ( - - - - - - - - - - - - - - - {/* Only show the sample selector if there is more than one sample */} - {numberOfSamples > 1 && ( + {showSelectors && ( + + - - )} - - - {/* Only show the sample per page selector if there is more than one sample */} - {numberOfSamples > 1 && ( - - )} - - + + + {/* Only show the sample per page selector if there is more than one sample */} + {numberOfSamples > 1 && ( + + + + + + + )} + + + )}