Skip to content

Commit

Permalink
Merge branch 'main' into update_release_job
Browse files Browse the repository at this point in the history
  • Loading branch information
mgkwill authored Jul 25, 2023
2 parents 4eaf098 + a550c50 commit a695cdb
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ jobs:
repository: 'Lava'

- name: Run flakeheaven (flake8)
run: poetry run flakeheaven lint src/lava tests/
run: |
poetry run flakeheaven lint src/lava tests/
poetry run find tutorials/ -name '*.py' -exec flakeheaven lint {} \+
security-lint:
name: Security Lint Code
Expand Down
4 changes: 2 additions & 2 deletions src/lava/magma/core/callback_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def __init__(self,

def pre_run_callback(self,
board: NxBoard = None,
_var_id_to_var_model_map: dict = None
var_id_to_var_model_map: dict = None
) -> None:
for fx in self.pre_run_fxs:
fx(board)

def post_run_callback(self,
board: NxBoard = None,
_var_id_to_var_model_map: dict = None
var_id_to_var_model_map: dict = None
) -> None:
for fx in self.post_run_fxs:
fx(board)
2 changes: 1 addition & 1 deletion src/lava/magma/core/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def run(self,
raise ValueError("run_cfg must not be None when calling"
" Process.run() unless the process has already"
" been compiled.")
self.create_runtime(run_cfg, compile_config)
self.create_runtime(run_cfg=run_cfg, compile_config=compile_config)
self._runtime.start(condition)

def create_runtime(self, run_cfg: ty.Optional[RunConfig] = None,
Expand Down
19 changes: 14 additions & 5 deletions tutorials/end_to_end/convert_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _mean_input(num_neurons_exc, gamma, g_factor, weight, rate, bias):

return mean_inp


def _std_input(num_neurons_exc, gamma, g_factor, weight, rate):
'''
Calculate mean input to single neuron given mean excitatory weight.
Expand All @@ -62,6 +63,7 @@ def _std_input(num_neurons_exc, gamma, g_factor, weight, rate):
'''
return num_neurons_exc * (1 + gamma * g_factor**2) * weight ** 2 * rate


def _y_th(vth, mean, std, dv_exc, du_exc):
'''
Effective threshold, see Grytskyy et al. 2013.
Expand Down Expand Up @@ -89,6 +91,7 @@ def _y_th(vth, mean, std, dv_exc, du_exc):

return y_th


def _y_r(mean, std, dv_exc, du_exc):
'''
Effective reset, see Grytskyy et al. 2013.
Expand Down Expand Up @@ -116,12 +119,14 @@ def _y_r(mean, std, dv_exc, du_exc):

return y_r


def f(y):
'''
Derivative of transfer function of LIF neuron at given argument.
'''
return np.exp(y ** 2) * (1 + erf(y))


def _alpha(vth, mean, std, dv_exc, du_exc):
'''
Auxiliary variable describing contribution of weights for weight
Expand Down Expand Up @@ -152,6 +157,7 @@ def _alpha(vth, mean, std, dv_exc, du_exc):

return val


def _beta(vth, mean, std, dv_exc, du_exc):
'''
Auxiliary variable describing contribution of square of weights for
Expand All @@ -176,14 +182,16 @@ def _beta(vth, mean, std, dv_exc, du_exc):
Contribution of square of weights
'''
val = np.sqrt(np.pi) * (mean * dv_exc * 0.01) ** 2
val *= 1/(2 * std ** 2)
val *= 1 / (2 * std ** 2)
val *= (f(_y_th(vth, mean, std, dv_exc, du_exc)) * (vth - mean) / std
- f(_y_r(mean, std, dv_exc, du_exc)) * (-1 * mean) / std)

return val

def convert_rate_to_lif_params(shape_exc, dr_exc, bias_exc, shape_inh, dr_inh,
bias_inh, g_factor, q_factor, weights, **kwargs):

def convert_rate_to_lif_params(
shape_exc, dr_exc, bias_exc, shape_inh, dr_inh, bias_inh, g_factor,
q_factor, weights, **kwargs):
'''Convert rate parameters to LIF parameters.
The mapping is based on A unified view on weakly correlated recurrent
network, Grytskyy et al. 2013.
Expand Down Expand Up @@ -224,7 +232,8 @@ def convert_rate_to_lif_params(shape_exc, dr_exc, bias_exc, shape_inh, dr_inh,
gamma = float(num_neurons_exc) / float(num_neurons_inh)

# Assert that network is balanced.
assert gamma * g_factor > 1, "Network not balanced, increase g_factor"
if gamma * g_factor <= 1:
raise AssertionError("Network not balanced, increase g_factor")

# Set timescales of neurons.
dv_exc = 1 * dr_exc # Dynamics of membrane potential as fast as rate.
Expand Down Expand Up @@ -275,7 +284,7 @@ def func(weight):
mean_inp = _mean_input(num_neurons_exc, gamma,
g_factor, weight, rate, bias)
std_inp = _std_input(num_neurons_exc, gamma,
g_factor, weight, rate)
g_factor, weight, rate)
alpha = _alpha(vth_exc, mean_inp, std_inp, dv_exc, du_inh)
beta = _beta(vth_exc, mean_inp, std_inp, dv_exc, du_inh)

Expand Down
38 changes: 19 additions & 19 deletions tutorials/in_depth/three_factor_learning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,29 +185,29 @@ def generate_post_spikes(pre_spike_times, num_steps, spike_prob_post):
"""generates specific post synaptic spikes to
demonstrate potentiation and depression.
"""
pre_synaptic_spikes = np.where(pre_spike_times==1)[1]
pre_synaptic_spikes = np.where(pre_spike_times == 1)[1]

spike_raster_post = np.zeros((len(spike_prob_post), num_steps))

for ts in range(num_steps):
for pre_ts in pre_synaptic_spikes:
if ts in range(pre_ts, pre_ts+20):
if ts in range(pre_ts, pre_ts + 20):
if np.random.rand(1) < spike_prob_post[0]:
spike_raster_post[0][ts] = 1

for ts in range(num_steps):
for pre_ts in pre_synaptic_spikes:
if ts in range(pre_ts-12, pre_ts-2):
if ts in range(pre_ts - 12, pre_ts - 2):
if np.random.rand(1) < spike_prob_post[1]:
spike_raster_post[1][ts] = 1

return spike_raster_post


def plot_spikes(spikes, figsize, legend, colors, title, num_steps):
offsets = list(range(1, len(spikes) + 1))
num_x_ticks = np.arange(0, num_steps+1, 25)
num_x_ticks = np.arange(0, num_steps + 1, 25)

plt.figure(figsize=figsize)

plt.eventplot(positions=spikes,
Expand All @@ -223,7 +223,7 @@ def plot_spikes(spikes, figsize, legend, colors, title, num_steps):
plt.grid(which='minor', color='lightgrey', linestyle=':', linewidth=0.5)
plt.grid(which='major', color='lightgray', linewidth=0.8)
plt.minorticks_on()

plt.yticks(ticks=offsets, labels=legend)

plt.show()
Expand All @@ -232,26 +232,26 @@ def plot_spikes(spikes, figsize, legend, colors, title, num_steps):
def plot_time_series(time, time_series, ylabel, title, figsize, color):
plt.figure(figsize=figsize)
plt.step(time, time_series, color=color)

plt.title(title)
plt.xlabel("Time steps")
plt.grid(which='minor', color='lightgrey', linestyle=':', linewidth=0.5)
plt.grid(which='major', color='lightgray', linewidth=0.8)
plt.minorticks_on()

plt.ylabel(ylabel)

plt.show()


def plot_time_series_subplots(time, time_series_y1, time_series_y2, ylabel,
title, figsize, color, legend,
leg_loc="upper left"):
plt.figure(figsize=figsize)

plt.step(time, time_series_y1, label=legend[0], color=color[0])
plt.step(time, time_series_y2, label=legend[1], color=color[1])

plt.title(title)
plt.xlabel("Time steps")
plt.ylabel(ylabel)
Expand All @@ -261,20 +261,20 @@ def plot_time_series_subplots(time, time_series_y1, time_series_y2, ylabel,
plt.xlim(0, len(time_series_y1))

plt.legend(loc=leg_loc)

plt.show()


def plot_spikes_time_series(time, time_series, spikes, figsize, legend,
colors, title, num_steps):

offsets = list(range(1, len(spikes) + 1))
num_x_ticks = np.arange(0, num_steps+1, 25)
num_x_ticks = np.arange(0, num_steps + 1, 25)

plt.figure(figsize=figsize)

plt.subplot(211)
plt.eventplot(positions=spikes,
plt.eventplot(positions=spikes,
lineoffsets=offsets,
linelength=0.9,
colors=colors)
Expand All @@ -287,13 +287,13 @@ def plot_spikes_time_series(time, time_series, spikes, figsize, legend,
plt.grid(which='minor', color='lightgrey', linestyle=':', linewidth=0.5)
plt.grid(which='major', color='lightgray', linewidth=0.8)
plt.minorticks_on()

plt.yticks(ticks=offsets, labels=legend)
plt.tight_layout(pad=3.0)

plt.subplot(212)
plt.step(time, time_series, color=colors)

plt.title(title[0])
plt.xlabel("Time steps")
plt.grid(which='minor', color='lightgrey', linestyle=':', linewidth=0.5)
Expand All @@ -302,5 +302,5 @@ def plot_spikes_time_series(time, time_series, spikes, figsize, legend,
plt.margins(x=0)

plt.ylabel("Trace Value")

plt.show()

0 comments on commit a695cdb

Please sign in to comment.