-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathHopfield.hs
315 lines (258 loc) · 11.8 KB
/
Hopfield.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}
-- | Base Hopfield model, providing training and running.
module Hopfield.Hopfield (
Pattern
, Weights
, LearningType (Hebbian, Storkey)
-- * Hopfield data structure
, HopfieldData ()
, weights
, patterns
, buildHopfieldData
-- * Running
, update
, addPatterns
, repeatedUpdate
, updateChain
, matchPattern
, computeH
-- * Energy
, energy
) where
import Control.Monad
import Control.Monad.Random (MonadRandom)
import Data.List (foldl')
import Data.Maybe
import Data.Vector ((!))
import qualified Data.Vector as V
import Data.Vector.Generic.Mutable (write)
import Hopfield.Common
import Hopfield.Util
-- | The learning rule used to derive the weight matrix from the training
-- patterns.
data LearningType
  = Hebbian -- ^ Hebbian learning (see 'train')
  | Storkey -- ^ Storkey learning (see 'trainStorkey')
  deriving (Eq, Show, Read)
-- | A trained Hopfield network: its weight matrix together with the
-- patterns that were used to produce it.
data HopfieldData = HopfieldData
  { weights  :: Weights    -- ^ the weights of the network
  , patterns :: [Pattern]  -- ^ the patterns which were used to train it
  } deriving (Show)
-- | Runs @f ws pat@ only after the inputs pass validation.
--
-- The checks run in this order, and the first failure aborts with 'error'
-- carrying its message:
--
-- 1. the weight matrix itself ('validWeights'),
-- 2. the pattern's values ('validPattern'),
-- 3. that the pattern size matches the weight matrix ('validWeightsPatternSize').
--
-- Usage: @checkWsPat functionTakingWeightsAndPattern@.
checkWsPat :: (Weights -> Pattern -> a) -> Weights -> Pattern -> a
checkWsPat f ws pat =
  -- 'msum' on Maybe picks the first 'Just' and does not evaluate later checks.
  case msum [validWeights ws, validPattern pat, validWeightsPatternSize ws pat] of
    Just problem -> error problem
    Nothing      -> f ws pat
-- | @update weights pattern@: performs a single asynchronous update step.
-- A randomly chosen neuron of @pattern@ whose value disagrees with the sign
-- computed by the network (@weights@) is flipped.
--
-- Returns @Nothing@ when no neuron would change, meaning the pattern is
-- stable. Otherwise returns @Just@ the pattern with that one neuron flipped.
-- Aborts with 'error' if the weights or the pattern are invalid, or if their
-- sizes differ.
--
-- Pre: @length weights == length pattern@
update :: MonadRandom m => Weights -> Pattern -> m (Maybe Pattern)
update ws pat = checkWsPat update_ ws pat
-- | @repeatedUpdate weights pattern@: keeps applying 'update' steps to
-- @pattern@ until no neuron changes, and returns that stable state with
-- respect to the network (@weights@).
-- Aborts with 'error' on invalid weights or pattern, or on mismatched sizes.
--
-- Pre: @length weights == length pattern@
repeatedUpdate :: (MonadRandom m) => Weights -> Pattern -> m Pattern
repeatedUpdate ws pat = checkWsPat repeatedUpdate_ ws pat
-- | @computeH weights pattern i@: the value neuron @i@ would take given the
-- current pattern. This is the sign of the weighted sum of the neuron values:
-- 1 when the sum is non-negative, -1 otherwise.
-- Aborts with 'error' on invalid weights or pattern, or on mismatched sizes.
computeH :: Weights -> Pattern -> Int -> Int
computeH ws pat i = checkWsPat atNeuron ws pat
  where
    atNeuron w p = computeH_ w p i
-- | @energy weights pattern@: computes the network energy of @pattern@,
-- @-1/2 * sum_ij w_ij * x_i * x_j@, for the network given by @weights@.
-- Aborts with 'error' on invalid weights or pattern, or on mismatched sizes.
--
-- Pre: @length weights == length pattern@
energy :: Weights -> Pattern -> Double
energy ws pat = checkWsPat energy_ ws pat
-- | @buildHopfieldData learningType patterns@: trains a Hopfield network on
-- @patterns@ using the chosen learning rule, so that the patterns are
-- intended to become stable states. The result can be used with
-- 'matchPattern' to run a pattern against the network.
--
-- Aborts with 'error' when the pattern list is empty, when the patterns are
-- empty, or when the patterns differ in length.
buildHopfieldData :: LearningType -> [Pattern] -> HopfieldData
buildHopfieldData _ [] = error "Train patterns are empty"
buildHopfieldData learningType pats@(firstPat : _)
  | patLen == 0 =
      error "Cannot have empty patterns"
  | not (all ((== patLen) . V.length) pats) =
      error "All training patterns must have the same length"
  | otherwise =
      HopfieldData (trainWith pats) pats
  where
    patLen = V.length firstPat
    trainWith = case learningType of
      Hebbian -> train
      Storkey -> trainStorkey
-- | @train patterns@: builds the Hebbian weight matrix. Entry @(i, j)@ is
-- @sum_p (p_i * p_j) / n@ for @i /= j@, where @n@ is the pattern length, and
-- 0 on the diagonal. The training patterns thus tend to become stable
-- points of the network.
--
-- No need to check the pattern sizes here; 'buildHopfieldData' already does.
train :: [Pattern] -> Weights
train pats = vector2D [ map (weightAt i) idxs | i <- idxs ]
  where
    n = V.length (head pats)
    idxs = [0 .. n - 1]
    weightAt i j = correlation i j ./. n
    correlation i j
      | i == j    = 0
      | otherwise = sum (map (\p -> (p ! i) * (p ! j)) pats)
-- | See `computeH`.
--
-- Unchecked worker: sums @w_ik@ over the pattern, negating the term when
-- @x_k@ is not positive. Each neuron is therefore read as +1 when positive
-- and -1 otherwise. Returns 1 when the sum is non-negative and -1 otherwise.
-- Uses 'V.unsafeIndex', so the caller must guarantee that row @i@ of @ws@ is
-- at least as long as @pat@.
computeH_ :: Weights -> Pattern -> Int -> Int
computeH_ ws pat i = {-# SCC "computeHall" #-} toSign (accumulate 0 0.0)
  where
    row = ws ! i
    size = {-# SCC "computeHvlength" #-} V.length pat
    toSign :: Double -> Int
    toSign total
      | total >= 0 = 1
      | otherwise  = -1
    -- Strict left-to-right accumulation, so the floating-point summation
    -- order is that of a plain index loop.
    accumulate :: Int -> Double -> Double
    accumulate !k !acc
      | k == size = acc
      | otherwise =
          let w = row `V.unsafeIndex` k
              term = if pat `V.unsafeIndex` k > 0 then w else -w
          in accumulate (k + 1) (acc + term)
-- | See `update`.
--
-- Unchecked worker. The neuron indices are visited in a random order
-- (shuffled), and the first neuron whose current value differs from
-- 'computeH_' is flipped. Returns @Nothing@ if no neuron is updatable.
-- (Note: an earlier version collected all updatable neurons and then picked
-- a random one; stopping at the first hit in a shuffled order is about
-- 2 times faster.)
update_ :: MonadRandom m => Weights -> Pattern -> m (Maybe Pattern)
update_ ws pat = do
  order <- shuffle . toArray $ [0 .. V.length pat - 1]
  -- TODO avoid the Array -> List conversion
  -- 'filter' is lazy, so scanning stops at the first updatable neuron.
  return $ fmap (negateAt pat) (listToMaybe (filter isUpdatable order))
  where
    isUpdatable i = pat ! i /= computeH_ ws pat i
    negateAt vec idx =
      let old = vec ! idx -- seq only brings small saving here
      in old `seq` V.modify (\mv -> write mv idx (-old)) vec
-- | See `repeatedUpdate`.
--
-- Unchecked worker: repeatedly applies 'update_' until it yields @Nothing@.
repeatedUpdate_ :: (MonadRandom m) => Weights -> Pattern -> m Pattern
repeatedUpdate_ ws = repeatUntilNothing (update_ ws)
-- | @matchPattern hopfieldData pattern@:
-- Computes the stable state of @pattern@ under the network stored in
-- @hopfieldData@, then tries to find it among the training patterns stored
-- there.
--
-- Returns:
--
--   * @Right i@, the index of the matching pattern in @patterns@, if a match exists;
--
--   * @Left converged@, the converged pattern (the stable state), otherwise.
--
-- Aborts with 'error' if @pattern@ contains values other than 1/-1, or if its
-- size does not match the network.
--
-- Pre: @length weights == length pattern@
matchPattern :: MonadRandom m => HopfieldData -> Pattern -> m (Either Pattern Int)
matchPattern (HopfieldData ws pats) pat
  -- Without these checks, an oversized pattern makes 'computeH_' read past
  -- the end of a weight row (it uses unsafeIndex). A non +-1 value such as 0
  -- never stabilises, because negating it leaves it unchanged.
  | Just e <- validPattern pat = error e
  | Just e <- validWeightsPatternSize ws pat = error e
  | otherwise = do
      converged_pattern <- repeatedUpdate_ ws pat
      return $ findInList pats converged_pattern
-- | Like `repeatedUpdate`, but collects every pattern visited on the way to
-- convergence.
--
-- The argument pattern IS prepended: the list starts with @pat@ and
-- continues with the patterns produced by successive 'update_' steps, as
-- collected by 'unfoldrSelfM'. The last pattern in the list is the
-- converged pattern.
--
-- Aborts with 'error' if @pat@ contains values other than 1/-1, or if its
-- size does not match the network's weights.
--
-- POST: The returned list is not empty (it always contains at least @pat@).
updateChain :: (MonadRandom m) => HopfieldData -> Pattern -> m [Pattern]
updateChain (HopfieldData ws _pats) pat
  | Just e <- validPattern pat = error e
  -- A size mismatch would let 'computeH_' index past a weight row.
  | Just e <- validWeightsPatternSize ws pat = error e
  | otherwise = (pat:) `liftM` unfoldrSelfM (update_ ws) pat
-- | Stores additional patterns in an already trained network by updating its
-- weights with the given learning rule, one pattern at a time.
-- The new patterns are appended to the stored training patterns.
--
-- Do not overuse this function: storing too many patterns decreases the
-- network's capacity to recall them.
--
-- Aborts with 'error' if any added pattern contains values other than 1/-1,
-- or does not match the size of the weight matrix.
addPatterns :: LearningType -> HopfieldData -> [Pattern] -> HopfieldData
addPatterns learning (HopfieldData ws pats) addedPats
  | any (isJust . validPattern) addedPats = error "invalid patterns in addPatterns"
  | any (isJust . validWeightsPatternSize ws) addedPats = error "pattern does not match weights in addPatterns"
  | otherwise = HopfieldData new_ws (pats ++ addedPats)
  where
    -- Strict left fold: each step consumes the previous weights, so a lazy
    -- foldl would only build up a chain of deferred updates.
    new_ws = foldl' (updateWeightsGivenNewPattern learning) ws addedPats
-- | Updates the weight matrix when a new pattern is stored in the network.
--
-- Storkey: delegates to 'updateWeightsStorkey'.
--
-- Hebbian: adds @pat_i * pat_j / n@ to every off-diagonal entry, where @n@
-- is the number of neurons. The diagonal stays 0. This matches the
-- normalisation and zero diagonal used by 'train', and keeps the result
-- acceptable to 'validWeights'.
updateWeightsGivenNewPattern :: LearningType -> Weights -> Pattern -> Weights
updateWeightsGivenNewPattern Storkey ws pat = updateWeightsStorkey ws pat
updateWeightsGivenNewPattern Hebbian ws pat = vector2D updated_ws
  where
    -- Number of neurons. This is also the divisor: the previous code
    -- divided by n-1 and added a non-zero term to the diagonal.
    n = V.length ws
    neurons = [0 .. n - 1]
    updated_ws = [ [ updated i j | j <- neurons ] | i <- neurons ]
    updated i j
      | i == j    = 0
      | otherwise = ws ! i ! j + (pat ! i * pat ! j) ./. n
-- | See `energy`.
--
-- Unchecked worker: @-1/2 * sum_i sum_j w_ij * x_i * x_j@, summed with @i@
-- in the outer position and @j@ in the inner position.
energy_ :: Weights -> Pattern -> Double
energy_ ws pat = negate (pairSum / 2)
  where
    idx = [0 .. V.length pat - 1]
    pairSum = sum [ (ws ! i ! j) *. (pat ! i * pat ! j) | i <- idx, j <- idx ]
-- | Checks that every value of a pattern is either 1 or -1.
-- Returns @Nothing@ on success. Otherwise returns an error message naming the
-- first offending value.
validPattern :: Pattern -> Maybe String
validPattern pat = fmap describe (V.find isInvalid pat)
  where
    isInvalid x = x /= 1 && x /= -1
    describe x = "Pattern contains invalid value " ++ show x
-- | @validWeightsPatternSize weights pattern@: returns @Just@ an error message
-- when @pattern@ does not have one value per row of @weights@, and @Nothing@
-- otherwise.
validWeightsPatternSize :: Weights -> Pattern -> Maybe String
validWeightsPatternSize ws pat =
  "Pattern size must match network size" <$ guard (V.length ws /= V.length pat)
-- | Checks the validity of a weight matrix. Returns @Nothing@ when it is
-- valid, otherwise @Just@ the message of the first failing check:
--
-- * It is non-empty
--
-- * It is square
--
-- * All diagonal elements are exactly zero
--
-- * It is symmetric (up to an absolute tolerance of 0.0001)
--
-- These checks hold for both Hebbian and Storkey.
--
-- There is deliberately no magnitude check. The previous check was written as
-- @null [abs w > 1 | ...]@, which is only True for an empty matrix (already
-- rejected above), so it never fired. Hebbian weights (@sum_p x_i x_j / n@
-- in 'train') can legitimately exceed 1 in absolute value when there are
-- more patterns than neurons.
validWeights :: Weights -> Maybe String
validWeights ws
  | n == 0
  = Just "Weight matrix must be non-empty"
  | V.any (\row -> V.length row /= n) ws
  = Just "Weight matrix has to be a square matrix"
  | any (/= 0) [ ws ! i ! i | i <- [0 .. n - 1] ]
  = Just "Weight matrix first diagonal must be zero"
  -- |a - b| is symmetric and the diagonal is already known to be zero, so
  -- checking the strict upper triangle is enough.
  | not $ and [ abs ((ws ! i ! j) - (ws ! j ! i)) < 0.0001 | i <- [0 .. n - 1], j <- [i + 1 .. n - 1] ]
  = Just "Weight matrix must be symmetric"
  | otherwise = Nothing
  where
    n = V.length ws
-- Storkey training gives the Hopfield network a larger storage capacity and
-- wider basins of attraction than Hebbian training.
-- For more details see:
-- http://homepages.inf.ed.ac.uk/amos/publications/Storkey1997IncreasingtheCapacityoftheHopfieldNetworkwithoutSacrificingFunctionality.pdf
-- | @storkeyHiddenSum ws pat i j@ computes entry @(i, j)@ of the hidden
-- ("local field") matrix used by the Storkey rule during training:
-- @h_ij = sum over k, with k /= i and k /= j, of w_ik * pat_k@.
storkeyHiddenSum :: Weights -> Pattern -> Int -> Int -> Double
storkeyHiddenSum ws pat i j = sum (map term ks)
  where
    row = ws ! i
    ks = filter (\k -> k /= i && k /= j) [0 .. V.length ws - 1]
    term k = row ! k *. (pat ! k)
-- | @updateWeightsGivenIndicesStorkey ws pat i j@ computes the new value of
-- entry @(i, j)@ of the weight matrix after training on pattern @pat@:
-- @w_ij + (1/n) * (x_i x_j - h_ji x_i - h_ij x_j)@. The diagonal is kept at 0.
updateWeightsGivenIndicesStorkey :: Weights -> Pattern -> Int -> Int -> Double
updateWeightsGivenIndicesStorkey ws pat i j
  | i /= j    = w_ij + (1 :: Int) ./. n * (fromIntegral (x_i * x_j) - h_ji *. x_i - h_ij *. x_j)
  | otherwise = 0.0
  where
    n = V.length ws
    w_ij = ws ! i ! j
    x_i = pat ! i
    x_j = pat ! j
    h_ij = storkeyHiddenSum ws pat i j
    h_ji = storkeyHiddenSum ws pat j i
-- | @updateWeightsStorkey ws pat@ applies one Storkey training step for
-- pattern @pat@. Every entry is recomputed from the previous matrix @ws@.
updateWeightsStorkey :: Weights -> Pattern -> Weights
updateWeightsStorkey ws pat = vector2D (map row idx)
  where
    idx = [0 .. V.length ws - 1]
    row i = map (updateWeightsGivenIndicesStorkey ws pat i) idx
-- | @trainStorkey pats@ trains the Hopfield network with the Storkey rule.
-- It starts from an all-zero @n x n@ matrix, where @n@ is the pattern
-- length, and applies 'updateWeightsStorkey' for each training pattern in
-- order (@pats@).
--
-- No need to check pats ws size, buildHopfieldData does it (this also
-- guarantees @pats@ is non-empty, so 'head' is safe here).
trainStorkey :: [Pattern] -> Weights
trainStorkey pats = foldl' updateWeightsStorkey start_ws pats
  where
    -- Strict fold: each step consumes the previous matrix, so a lazy foldl
    -- would only accumulate a chain of deferred updates.
    n = V.length (head pats)
    start_ws = vector2D (replicate n (replicate n 0))