Skip to content

Commit

Permalink
Updating apis for connections and operation registeration (#280)
Browse files Browse the repository at this point in the history
Updating apis for connections and operation registration
  • Loading branch information
chaoming0625 authored Oct 18, 2022
2 parents e5a17c8 + 71e1308 commit 87c0b86
Show file tree
Hide file tree
Showing 16 changed files with 529 additions and 369 deletions.
69 changes: 26 additions & 43 deletions brainpy/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def build_conn(self):
import brainpy as bp
class MyConnector(bp.conn.TwoEndConnector):
def build_mat(self, pre_size, post_size):
def build_mat(self, ):
return conn_matrix
def build_csr(self, pre_size, post_size):
def build_csr(self, ):
return post_ids, inptr
def build_coo(self, pre_size, post_size):
def build_coo(self, ):
return pre_ids, post_ids
"""
Expand Down Expand Up @@ -196,8 +196,6 @@ def check(self, structures: Union[Tuple, List, str]):
raise ConnectorError(f'Unknown synapse structure "{n}". '
f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')



def _return_by_mat(self, structures, mat, all_data: dict):
assert mat.ndim == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
Expand Down Expand Up @@ -332,70 +330,56 @@ def build_conn(self):
"""
pass

def require(self, *sizes_or_structures):
sizes_or_structures = list(sizes_or_structures)
pre_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
post_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
structures = sizes_or_structures
if isinstance(post_size, str):
structures.insert(0, post_size)
post_size = None
if isinstance(pre_size, str):
structures.insert(0, pre_size)
pre_size = None

version2_style = (pre_size is not None) and (post_size is not None)
if not version2_style:
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__(pre_size, post_size) '
f'before requiring connection data.')
if pre_size is None:
pre_size = self.pre_size
if post_size is None:
post_size = self.post_size
def require(self, *structures):
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__() '
f'before requiring connection data.')

self.check(structures)
if self.is_version2_style:
if len(structures) == 1:
if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
return self.build_csr(pre_size, post_size)
r = self.build_csr()
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'):
return self.build_mat(pre_size, post_size)
return bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)[0]
return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE)
elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)[1]
return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE)
elif len(structures) == 2:
if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)
r = self.build_coo()
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)

conn_data = dict(csr=None, ij=None, mat=None)
if not hasattr(self.build_coo, 'not_customized'):
conn_data['ij'] = self.build_coo(pre_size, post_size)
conn_data['ij'] = self.build_coo()
elif not hasattr(self.build_csr, 'not_customized'):
conn_data['csr'] = self.build_csr(pre_size, post_size)
conn_data['csr'] = self.build_csr()
elif not hasattr(self.build_mat, 'not_customized'):
conn_data['mat'] = self.build_mat(pre_size, post_size)
conn_data['mat'] = self.build_mat()

else:
conn_data = self.build_conn()
return self.make_returns(structures, conn_data)

def requires(self, *sizes_or_structures):
return self.require(*sizes_or_structures)
def requires(self, *structures):
return self.require(*structures)

@tools.not_customized
def build_mat(self, pre_size=None, post_size=None):
def build_mat(self):
pass

@tools.not_customized
def build_csr(self, pre_size=None, post_size=None):
def build_csr(self):
pass

@tools.not_customized
def build_coo(self, pre_size=None, post_size=None):
def build_coo(self):
pass


Expand Down Expand Up @@ -425,7 +409,6 @@ def __call__(self, pre_size, post_size=None):
else:
post_size = tuple(post_size)
self.pre_size, self.post_size = pre_size, post_size

self.pre_num = tools.size2num(self.pre_size)
self.post_num = tools.size2num(self.post_size)
return self
Expand Down
33 changes: 12 additions & 21 deletions brainpy/connect/custom_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from brainpy import tools
from brainpy.errors import ConnectorError
from .base import *
from .utils import *

__all__ = [
'MatConn',
Expand All @@ -34,11 +33,9 @@ def __call__(self, pre_size, post_size):
assert self.post_num == tools.size2num(post_size)
return self

def build_mat(self, pre_size=None, post_size=None):
pre_num = get_pre_num(self, pre_size)
post_num = get_post_num(self, post_size)
assert self.conn_mat.shape[0] == pre_num
assert self.conn_mat.shape[1] == post_num
def build_mat(self):
assert self.conn_mat.shape[0] == self.pre_num
assert self.conn_mat.shape[1] == self.post_num
return self.conn_mat


Expand Down Expand Up @@ -68,14 +65,12 @@ def __call__(self, pre_size, post_size):
f'the maximum id ({self.max_post}) of self.post_ids.')
return self

def build_coo(self, pre_size=None, post_size=None):
pre_num = get_pre_num(self, pre_size)
post_num = get_post_num(self, post_size)
if pre_num <= self.max_pre:
raise ConnectorError(f'pre_num ({pre_num}) should be greater than '
def build_coo(self):
if self.pre_num <= self.max_pre:
raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
f'the maximum id ({self.max_pre}) of self.pre_ids.')
if post_num <= self.max_post:
raise ConnectorError(f'post_num ({post_num}) should be greater than '
if self.post_num <= self.max_post:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.pre_ids, self.post_ids

Expand All @@ -91,16 +86,12 @@ def __init__(self, indices, inptr):
self.pre_num = self.inptr.size - 1
self.max_post = bm.max(self.indices)

def build_csr(self, pre_size=None, post_size=None):
pre_size = get_pre_size(self, pre_size)
post_size = get_post_size(self, post_size)
pre_num = np.prod(pre_size)
post_num = np.prod(post_size)
if pre_num != self.pre_num:
def build_csr(self):
if self.pre_num != self.pre_num:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
f'the shape of the sparse matrix.')
if post_num <= self.max_post:
raise ConnectorError(f'post_num ({post_num}) should be greater than '
if self.post_num <= self.max_post:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.indices, self.inptr

Expand Down
Loading

0 comments on commit 87c0b86

Please sign in to comment.