diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 3e4dfd3..3c512d7 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -16,6 +16,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install antlr4-python3-runtime==4.9.1
+ pip3 install termcolor
- name: Run all the Python tests
run: python3 tests/test_all.py
pep8:
@@ -29,4 +30,3 @@ jobs:
with:
arguments: >-
--exclude=.svn,CVS,.bzr,.hg,.git,zzantlr
- --ignore=E121,E123,E126,E133,E226,E241,E242,E704,W503,W504,W505,W191,E101,E128
diff --git a/RASP_support/DrawCompFlow.py b/RASP_support/DrawCompFlow.py
index 6679f2a..fc5b0f2 100644
--- a/RASP_support/DrawCompFlow.py
+++ b/RASP_support/DrawCompFlow.py
@@ -1,5 +1,5 @@
from .FunctionalSupport import Unfinished, guarded_contains, base_tokens, \
- tokens_asis
+ tokens_asis
from .Support import clean_val
import os
import string
@@ -24,589 +24,589 @@
def windows_path_cleaner(s):
- if os.name == "nt": # is windows
- validchars = "-_.() "+string.ascii_letters+string.digits
+ if os.name == "nt": # is windows
+ validchars = "-_.() "+string.ascii_letters+string.digits
- def fix(c):
- return c if c in validchars else "."
- return "".join([fix(c) for c in s])
- else:
- return s
+ def fix(c):
+ return c if c in validchars else "."
+ return "".join([fix(c) for c in s])
+ else:
+ return s
def colour_scheme(row_type):
- if row_type == INPUT:
- return 'gray', 'gray', 'gray'
- if row_type == QVAR:
- return 'palegreen4', 'mediumseagreen', 'palegreen1'
- elif row_type == KVAR:
- return 'deepskyblue3', 'darkturquoise', 'darkslategray1'
- elif row_type == VVAR:
- return 'palevioletred3', 'palevioletred2', 'lightpink'
- elif row_type == VREAL:
- return 'plum4', 'plum3', 'thistle2'
- elif row_type == RES:
- return 'lightsalmon3', 'burlywood', 'burlywood1'
- else:
- raise Exception("unknown row type: "+str(row_type))
+ if row_type == INPUT:
+ return 'gray', 'gray', 'gray'
+ if row_type == QVAR:
+ return 'palegreen4', 'mediumseagreen', 'palegreen1'
+ elif row_type == KVAR:
+ return 'deepskyblue3', 'darkturquoise', 'darkslategray1'
+ elif row_type == VVAR:
+ return 'palevioletred3', 'palevioletred2', 'lightpink'
+ elif row_type == VREAL:
+ return 'plum4', 'plum3', 'thistle2'
+ elif row_type == RES:
+ return 'lightsalmon3', 'burlywood', 'burlywood1'
+ else:
+ raise Exception("unknown row type: "+str(row_type))
QVAR, KVAR, VVAR, VREAL, RES, INPUT = [
- "QVAR", "KVAR", "VVAR", "VREAL", "RES", "INPUT"]
+ "QVAR", "KVAR", "VVAR", "VREAL", "RES", "INPUT"]
POSS_ROWS = [QVAR, KVAR, VVAR, VREAL, RES, INPUT]
ROW_NAMES = {QVAR: "Me", KVAR: "Other", VVAR: "X",
- VREAL: "f(X)", RES: "FF", INPUT: ""}
+ VREAL: "f(X)", RES: "FF", INPUT: ""}
def UnfinishedFunc(f):
- setattr(Unfinished, f.__name__, f)
+ setattr(Unfinished, f.__name__, f)
@UnfinishedFunc
def last_val(self):
- return self.last_res.get_vals()
+ return self.last_res.get_vals()
def makeQKStable(qvars, kvars, select, ref_in_g):
- qvars = [q.last_val() for q in qvars]
- kvars = [k.last_val() for k in kvars]
- select = select.last_val()
- q_val_len, k_val_len = len(select), len(select[0])
-
- qvars_skip = len(kvars)
- kvars_skip = len(qvars)
- _, _, qvars_colour = colour_scheme(QVAR)
- _, _, kvars_colour = colour_scheme(KVAR)
- # select has qvars along the rows and kvars along the columns, so we'll do
- # the same. i.e. top rows will just be the kvars and first columns will
- # just be the qvars.
- # if (not qvars) and (not kvars):
- # # no qvars or kvars -> full select -> dont waste space drawing.
- # num_rows, num_columns = 0, 0
- # pass
- # else:
- # num_rows = qvars_skip+(len(qvars[0]) if qvars else 1)
- # num_columns = kvars_skip+(len(kvars[0]) if kvars else 1)
- num_rows = qvars_skip+q_val_len
- num_columns = kvars_skip+k_val_len
-
- select_cells = {i: [CellVals('', head_color, j, i)
- for j in range(num_columns)]
- for i in range(num_rows)}
-
- for i, seq in enumerate(kvars):
- for j, v in enumerate(seq):
- vals = CellVals(v, kvars_colour, i, j+kvars_skip)
- select_cells[i][j + kvars_skip] = vals
- for j, seq in enumerate(qvars):
- for i, v in enumerate(seq):
- vals = CellVals(v, qvars_colour, i+qvars_skip, j)
- select_cells[i + qvars_skip][j] = vals
-
- for i in range(num_rows-qvars_skip): # i goes over the q_var values
- for j in range(num_columns-kvars_skip): # j goes over the k_var values
- v = select[i][j]
- colour = select_on_colour if v else select_off_colour
- select_cells[i+qvars_skip][j+kvars_skip] = CellVals(
- v, colour, i+qvars_skip, j+kvars_skip, select_internal=True)
-
- # TODO: make an ugly little q\k triangle thingy in the top corner
- return GridTable(select_cells, ref_in_g)
+ qvars = [q.last_val() for q in qvars]
+ kvars = [k.last_val() for k in kvars]
+ select = select.last_val()
+ q_val_len, k_val_len = len(select), len(select[0])
+
+ qvars_skip = len(kvars)
+ kvars_skip = len(qvars)
+ _, _, qvars_colour = colour_scheme(QVAR)
+ _, _, kvars_colour = colour_scheme(KVAR)
+ # select has qvars along the rows and kvars along the columns, so we'll do
+ # the same. i.e. top rows will just be the kvars and first columns will
+ # just be the qvars.
+ # if (not qvars) and (not kvars):
+ # # no qvars or kvars -> full select -> dont waste space drawing.
+ # num_rows, num_columns = 0, 0
+ # pass
+ # else:
+ # num_rows = qvars_skip+(len(qvars[0]) if qvars else 1)
+ # num_columns = kvars_skip+(len(kvars[0]) if kvars else 1)
+ num_rows = qvars_skip+q_val_len
+ num_columns = kvars_skip+k_val_len
+
+ select_cells = {i: [CellVals('', head_color, j, i)
+ for j in range(num_columns)]
+ for i in range(num_rows)}
+
+ for i, seq in enumerate(kvars):
+ for j, v in enumerate(seq):
+ vals = CellVals(v, kvars_colour, i, j+kvars_skip)
+ select_cells[i][j + kvars_skip] = vals
+ for j, seq in enumerate(qvars):
+ for i, v in enumerate(seq):
+ vals = CellVals(v, qvars_colour, i+qvars_skip, j)
+ select_cells[i + qvars_skip][j] = vals
+
+ for i in range(num_rows-qvars_skip): # i goes over the q_var values
+ for j in range(num_columns-kvars_skip): # j goes over the k_var values
+ v = select[i][j]
+ colour = select_on_colour if v else select_off_colour
+ select_cells[i+qvars_skip][j+kvars_skip] = CellVals(
+ v, colour, i+qvars_skip, j+kvars_skip, select_internal=True)
+
+ # TODO: make an ugly little q\k triangle thingy in the top corner
+ return GridTable(select_cells, ref_in_g)
class CellVals:
- def __init__(self, val, colour, i_row, i_col, select_internal=False,
- known_portstr=None):
- def mystr(v):
- if isinstance(v, bool):
- if select_internal:
- return ' ' if v else ' ' # color gives it all!
- else:
- return 'T' if v else 'F'
- if isinstance(v, float):
- v = clean_val(v, 3)
- if isinstance(v, int) and len(str(v)) == 1:
- v = " "+str(v) # for pretty square selectors
- return str(v).replace("<", "&lt;").replace(">", "&gt;")
- self.val = mystr(val)
- self.colour = colour
- if None is known_portstr:
- self.portstr = "_col"+str(i_col)+"_row"+str(i_row)
- else:
- self.portstr = known_portstr
-
- def __str__(self):
- return '<td port="'+self.portstr+'" bgcolor="'+self.colour+'">' + self.val+'</td>'
+ def __init__(self, val, colour, i_row, i_col, select_internal=False,
+ known_portstr=None):
+ def mystr(v):
+ if isinstance(v, bool):
+ if select_internal:
+ return ' ' if v else ' ' # color gives it all!
+ else:
+ return 'T' if v else 'F'
+ if isinstance(v, float):
+ v = clean_val(v, 3)
+ if isinstance(v, int) and len(str(v)) == 1:
+ v = " "+str(v) # for pretty square selectors
+ return str(v).replace("<", "&lt;").replace(">", "&gt;")
+ self.val = mystr(val)
+ self.colour = colour
+ if None is known_portstr:
+ self.portstr = "_col"+str(i_col)+"_row"+str(i_row)
+ else:
+ self.portstr = known_portstr
+
+ def __str__(self):
+ return '<td port="'+self.portstr+'" bgcolor="'+self.colour+'">' + self.val+'</td>'
class GridTable:
- def __init__(self, cellvals, ref_in_g):
- self.ref_in_g = ref_in_g
- self.cellvals = cellvals
- self.numcols = len(cellvals.get(0, []))
- self.numrows = len(cellvals)
- self.empty = 0 in [self.numcols, self.numrows]
+ def __init__(self, cellvals, ref_in_g):
+ self.ref_in_g = ref_in_g
+ self.cellvals = cellvals
+ self.numcols = len(cellvals.get(0, []))
+ self.numrows = len(cellvals)
+ self.empty = 0 in [self.numcols, self.numrows]
- def to_str(self, transposed=False):
- ii = sorted(list(self.cellvals.keys()))
- rows = [self.cellvals[i] for i in ii]
+ def to_str(self, transposed=False):
+ ii = sorted(list(self.cellvals.keys()))
+ rows = [self.cellvals[i] for i in ii]
- def cells2row(cells):
- return '<tr>'+''.join(map(str, cells))+'</tr>'
- return '<<table border="0" cellborder="1" cellspacing="0">' \
- + ''.join(map(cells2row, rows)) + '</table>>'
+ def cells2row(cells):
+ return '<tr>'+''.join(map(str, cells))+'</tr>'
+ return '<<table border="0" cellborder="1" cellspacing="0">' \
+ + ''.join(map(cells2row, rows)) + '</table>>'
- def bottom_left_portstr(self):
- return self.access_portstr(0, -1)
+ def bottom_left_portstr(self):
+ return self.access_portstr(0, -1)
- def bottom_right_portstr(self):
- return self.access_portstr(-1, -1)
+ def bottom_right_portstr(self):
+ return self.access_portstr(-1, -1)
- def top_left_portstr(self):
- return self.access_portstr(0, 0)
+ def top_left_portstr(self):
+ return self.access_portstr(0, 0)
- def top_right_portstr(self):
- return self.access_portstr(-1, 0)
+ def top_right_portstr(self):
+ return self.access_portstr(-1, 0)
- def top_access_portstr(self, i_col):
- return self.access_portstr(i_col, 0)
+ def top_access_portstr(self, i_col):
+ return self.access_portstr(i_col, 0)
- def bottom_access_portstr(self, i_col):
- return self.access_portstr(i_col, -1)
+ def bottom_access_portstr(self, i_col):
+ return self.access_portstr(i_col, -1)
- def access_portstr(self, i_col, i_row):
- return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row)
+ def access_portstr(self, i_col, i_row):
+ return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row)
- def internal_portstr(self, i_col, i_row):
- if i_col < 0:
- i_col = self.numcols + i_col
- if i_row < 0:
- i_row = self.numrows + i_row
- return "_col"+str(i_col)+"_row"+str(i_row)
+ def internal_portstr(self, i_col, i_row):
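+ # Maps (column, row) to the port name used for each cell ("_colI_rowJ");
+ # negative indices count from the end, as in regular Python indexing.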
+ if i_col < 0:
+ i_col = self.numcols + i_col
+ if i_row < 0:
+ i_row = self.numrows + i_row
+ return "_col"+str(i_col)+"_row"+str(i_row)
- def add_to_graph(self, g):
- if self.empty:
- pass
- else:
- g.node(name=self.ref_in_g, shape='none',
- margin='0', label=self.to_str())
+ def add_to_graph(self, g):
+ if self.empty:
+ pass
+ else:
+ g.node(name=self.ref_in_g, shape='none',
+ margin='0', label=self.to_str())
class Table:
- def __init__(self, seqs_by_rowtype, ref_in_g, rowtype_order=[]):
- self.ref_in_g = ref_in_g
- # consistent presentation, and v useful for feedforward clarity
- self.rows = []
- self.seq_index = {}
- if len(rowtype_order) > 1:
- self.add_rowtype_cell = True
- else:
- errnote = "table got multiple row types but no order for them"
- assert len(seqs_by_rowtype.keys()) == 1, errnote
- rowtype_order = list(seqs_by_rowtype.keys())
- self.add_rowtype_cell = not (rowtype_order[0] == RES)
- self.note_res_dependencies = len(seqs_by_rowtype.get(RES, [])) > 1
- self.leading_metadata_offset = 1 + self.add_rowtype_cell
- for rt in rowtype_order:
- seqs = sorted(seqs_by_rowtype[rt],
- key=lambda seq: seq.creation_order_id)
- for i, seq in enumerate(seqs):
- # each one appends to self.rows.
- self.n = self.add_row(seq, rt)
- # self.n stores length of a single row, they will all be the
- # same, just easiest to get like this
- # add_row has to happen one at a time b/c they care about
- # length of self.rows at time of addition (to get ports right)
- self.empty = len(self.rows) == 0
- if self.empty:
- self.n = 0
- # (len(rowtype_order)==1 and rowtype_order[0]==QVAR)
- self.transpose = False
- # no need to twist Q, just making the table under anyway
- # transpose affects the port accesses, but think about that later
-
- def to_str(self):
- rows = self.rows if not self.transpose else list(zip(*self.rows))
-
- def cells2row(cells):
- return '<tr>'+''.join(cells)+'</tr>'
- return '<<table border="0" cellborder="1" cellspacing="0">' \
- + ''.join(map(cells2row, rows)) + '</table>>'
-
- def bottom_left_portstr(self):
- return self.access_portstr(0, -1)
-
- def bottom_right_portstr(self):
- return self.access_portstr(-1, -1)
-
- def top_left_portstr(self):
- return self.access_portstr(0, 0)
-
- def top_right_portstr(self):
- return self.access_portstr(-1, 0)
-
- def top_access_portstr(self, i_col, skip_meta=False):
- return self.access_portstr(i_col, 0, skip_meta=skip_meta)
-
- def bottom_access_portstr(self, i_col, skip_meta=False):
- return self.access_portstr(i_col, -1, skip_meta=skip_meta)
-
- def access_portstr(self, i_col, i_row, skip_meta=False):
- return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row,
- skip_meta=skip_meta)
-
- def internal_portstr(self, i_col, i_row, skip_meta=False):
- if skip_meta and (i_col >= 0): # before flip things for reverse column
- # access
- i_col += self.leading_metadata_offset
- if i_col < 0:
- i_col = (self.n) + i_col
- if i_row < 0:
- i_row = len(self.rows) + i_row
- return "_col"+str(i_col)+"_row"+str(i_row)
-
- def add_row(self, seq, row_type):
- def add_cell(val, colour):
- res = CellVals(val, colour, -1, -1,
- known_portstr=self.internal_portstr(len(cells),
- len(self.rows)))
- cells.append(str(res))
-
- def add_strong_line():
- # after failing to inject css styles in graphviz,
- # seeing that their suggestion only creates lines
- # (if at all? unclear) of width 1
- # (same as the border already there) and it wont make multiple VRs,
- # and realising their suggestion also does nothing,
- # refer to hack at the top of this priceless page:
- # http://jkorpela.fi/html/cellborder.html
- cells.append('<td bgcolor="black"></td>')
-
- qkvr_colour, name_colour, data_colour = colour_scheme(row_type)
- cells = [] # has to be created in advance, and not just be all the
- # results of add_cell, because add_cell cares about current length of
- # 'cells'
- if self.add_rowtype_cell:
- add_cell(ROW_NAMES[row_type], qkvr_colour)
- add_cell(seq.name, name_colour)
- for v in seq.last_val():
- add_cell(v, data_colour)
- if self.note_res_dependencies:
- self.seq_index[seq] = len(self.rows)
- add_strong_line()
- add_cell("("+str(self.seq_index[seq])+")", indices_colour)
- add_cell(self.dependencies_str(seq, row_type), comment_colour)
- self.rows.append(cells)
- return len(cells)
-
- def dependencies_str(self, seq, row_type):
- if not row_type == RES:
- return ""
- return "from ("+", ".join(str(self.seq_index[m]) for m in
- seq.get_nonminor_parent_sequences()) + ")"
-
- def add_to_graph(self, g):
- if self.empty:
- # g.node(name=self.ref_in_g,label="empty table")
- pass
- else:
- g.node(name=self.ref_in_g, shape='none',
- margin='0', label=self.to_str())
+ def __init__(self, seqs_by_rowtype, ref_in_g, rowtype_order=[]):
+ self.ref_in_g = ref_in_g
+ # consistent presentation, and v useful for feedforward clarity
+ self.rows = []
+ self.seq_index = {}
+ if len(rowtype_order) > 1:
+ self.add_rowtype_cell = True
+ else:
+ errnote = "table got multiple row types but no order for them"
+ assert len(seqs_by_rowtype.keys()) == 1, errnote
+ rowtype_order = list(seqs_by_rowtype.keys())
+ self.add_rowtype_cell = not (rowtype_order[0] == RES)
+ self.note_res_dependencies = len(seqs_by_rowtype.get(RES, [])) > 1
+ self.leading_metadata_offset = 1 + self.add_rowtype_cell
+ for rt in rowtype_order:
+ seqs = sorted(seqs_by_rowtype[rt],
+ key=lambda seq: seq.creation_order_id)
+ for i, seq in enumerate(seqs):
+ # each one appends to self.rows.
+ self.n = self.add_row(seq, rt)
+ # self.n stores length of a single row, they will all be the
+ # same, just easiest to get like this
+ # add_row has to happen one at a time b/c they care about
+ # length of self.rows at time of addition (to get ports right)
+ self.empty = len(self.rows) == 0
+ if self.empty:
+ self.n = 0
+ # (len(rowtype_order)==1 and rowtype_order[0]==QVAR)
+ self.transpose = False
+ # no need to twist Q, just making the table under anyway
+ # transpose affects the port accesses, but think about that later
+
+ def to_str(self):
+ rows = self.rows if not self.transpose else list(zip(*self.rows))
+
+ def cells2row(cells):
+ return '<tr>'+''.join(cells)+'</tr>'
+ return '<<table border="0" cellborder="1" cellspacing="0">' \
+ + ''.join(map(cells2row, rows)) + '</table>>'
+
+ def bottom_left_portstr(self):
+ return self.access_portstr(0, -1)
+
+ def bottom_right_portstr(self):
+ return self.access_portstr(-1, -1)
+
+ def top_left_portstr(self):
+ return self.access_portstr(0, 0)
+
+ def top_right_portstr(self):
+ return self.access_portstr(-1, 0)
+
+ def top_access_portstr(self, i_col, skip_meta=False):
+ return self.access_portstr(i_col, 0, skip_meta=skip_meta)
+
+ def bottom_access_portstr(self, i_col, skip_meta=False):
+ return self.access_portstr(i_col, -1, skip_meta=skip_meta)
+
+ def access_portstr(self, i_col, i_row, skip_meta=False):
+ return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row,
+ skip_meta=skip_meta)
+
+ def internal_portstr(self, i_col, i_row, skip_meta=False):
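+ # Same port naming as GridTable.internal_portstr, but skip_meta offsets past
+ # the leading row-type and name cells so data columns line up with values.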
+ if skip_meta and (i_col >= 0): # before flip things for reverse column
+ # access
+ i_col += self.leading_metadata_offset
+ if i_col < 0:
+ i_col = (self.n) + i_col
+ if i_row < 0:
+ i_row = len(self.rows) + i_row
+ return "_col"+str(i_col)+"_row"+str(i_row)
+
+ def add_row(self, seq, row_type):
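+ # Appends one row to self.rows: optional row-type cell, the sequence name,
+ # then its values (plus index/dependency cells when dependencies are noted).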
+ def add_cell(val, colour):
+ res = CellVals(val, colour, -1, -1,
+ known_portstr=self.internal_portstr(len(cells),
+ len(self.rows)))
+ cells.append(str(res))
+
+ def add_strong_line():
+ # after failing to inject css styles in graphviz,
+ # seeing that their suggestion only creates lines
+ # (if at all? unclear) of width 1
+ # (same as the border already there) and it wont make multiple VRs,
+ # and realising their suggestion also does nothing,
+ # refer to hack at the top of this priceless page:
+ # http://jkorpela.fi/html/cellborder.html
+ cells.append('<td bgcolor="black"></td>')
+
+ qkvr_colour, name_colour, data_colour = colour_scheme(row_type)
+ cells = [] # has to be created in advance, and not just be all the
+ # results of add_cell, because add_cell cares about current length of
+ # 'cells'
+ if self.add_rowtype_cell:
+ add_cell(ROW_NAMES[row_type], qkvr_colour)
+ add_cell(seq.name, name_colour)
+ for v in seq.last_val():
+ add_cell(v, data_colour)
+ if self.note_res_dependencies:
+ self.seq_index[seq] = len(self.rows)
+ add_strong_line()
+ add_cell("("+str(self.seq_index[seq])+")", indices_colour)
+ add_cell(self.dependencies_str(seq, row_type), comment_colour)
+ self.rows.append(cells)
+ return len(cells)
+
+ def dependencies_str(self, seq, row_type):
+ if not row_type == RES:
+ return ""
+ return "from ("+", ".join(str(self.seq_index[m]) for m in
+ seq.get_nonminor_parent_sequences()) + ")"
+
+ def add_to_graph(self, g):
+ if self.empty:
+ # g.node(name=self.ref_in_g,label="empty table")
+ pass
+ else:
+ g.node(name=self.ref_in_g, shape='none',
+ margin='0', label=self.to_str())
def place_above(g, node1, node2):
- g.edge(node1.bottom_left_portstr(), node2.top_left_portstr(),
- style="invis")
- g.edge(node1.bottom_right_portstr(),
- node2.top_right_portstr(), style="invis")
+ g.edge(node1.bottom_left_portstr(), node2.top_left_portstr(),
+ style="invis")
+ g.edge(node1.bottom_right_portstr(),
+ node2.top_right_portstr(), style="invis")
def connect(g, top_table, bottom_table, select_vals):
- # connects top_table as k and bottom_table as q
- if top_table.empty or bottom_table.empty:
- return # not doing this for now
- place_above(g, top_table, bottom_table)
- # just to position them one on top of the other, even if select is empty
- for q_i in select_vals:
- for k_i, b in enumerate(select_vals[q_i]):
- if b:
- # have to add 2 cause first 2 are data type and row name
- g.edge(top_table.bottom_access_portstr(k_i, skip_meta=True),
- bottom_table.top_access_portstr(q_i, skip_meta=True),
- arrowhead='none')
+ # connects top_table as k and bottom_table as q
+ if top_table.empty or bottom_table.empty:
+ return # not doing this for now
+ place_above(g, top_table, bottom_table)
+ # just to position them one on top of the other, even if select is empty
+ for q_i in select_vals:
+ for k_i, b in enumerate(select_vals[q_i]):
+ if b:
+ # have to add 2 cause first 2 are data type and row name
+ g.edge(top_table.bottom_access_portstr(k_i, skip_meta=True),
+ bottom_table.top_access_portstr(q_i, skip_meta=True),
+ arrowhead='none')
class SubHead:
- def __init__(self, name, seq):
- vvars = seq.get_immediate_parent_sequences()
- if not seq.definitely_uses_identity_function:
- vreal = seq.pre_aggregate_comp()
- vreal(seq.last_w) # run it on same w to fill with right results
- vreals = [vreal]
- else:
- vreals = []
-
- self.name = name
- self.vvars_table = Table(
- {VVAR: vvars, VREAL: vreals}, self.name+"_vvars",
- rowtype_order=[VVAR, VREAL])
- self.res_table = Table({RES: [seq]}, self.name+"_res")
- self.default = "default: " + \
- str(seq.default) if seq.default is not None else ""
- # self.vreals_table = ## ? add partly processed vals, useful for eg
- # conditioned_contains?
-
- def add_to_graph(self, g):
- self.vvars_table.add_to_graph(g)
- self.res_table.add_to_graph(g)
- if self.default:
- g.node(self.name+"_default", shape='rectangle', label=self.default)
- g.edge(self.name+"_default", self.res_table.top_left_portstr(),
- arrowhead='none')
-
- def add_edges(self, g, select_vals):
- connect(g, self.vvars_table, self.res_table, select_vals)
-
- def bottom_left_portstr(self):
- return self.res_table.bottom_left_portstr()
-
- def bottom_right_portstr(self):
- return self.res_table.bottom_right_portstr()
-
- def top_left_portstr(self):
- return self.vvars_table.top_left_portstr()
-
- def top_right_portstr(self):
- return self.vvars_table.top_right_portstr()
+ def __init__(self, name, seq):
+ vvars = seq.get_immediate_parent_sequences()
+ if not seq.definitely_uses_identity_function:
+ vreal = seq.pre_aggregate_comp()
+ vreal(seq.last_w) # run it on same w to fill with right results
+ vreals = [vreal]
+ else:
+ vreals = []
+
+ self.name = name
+ self.vvars_table = Table(
+ {VVAR: vvars, VREAL: vreals}, self.name+"_vvars",
+ rowtype_order=[VVAR, VREAL])
+ self.res_table = Table({RES: [seq]}, self.name+"_res")
+ self.default = "default: " + \
+ str(seq.default) if seq.default is not None else ""
+ # self.vreals_table = ## ? add partly processed vals, useful for eg
+ # conditioned_contains?
+
+ def add_to_graph(self, g):
+ self.vvars_table.add_to_graph(g)
+ self.res_table.add_to_graph(g)
+ if self.default:
+ g.node(self.name+"_default", shape='rectangle', label=self.default)
+ g.edge(self.name+"_default", self.res_table.top_left_portstr(),
+ arrowhead='none')
+
+ def add_edges(self, g, select_vals):
+ connect(g, self.vvars_table, self.res_table, select_vals)
+
+ def bottom_left_portstr(self):
+ return self.res_table.bottom_left_portstr()
+
+ def bottom_right_portstr(self):
+ return self.res_table.bottom_right_portstr()
+
+ def top_left_portstr(self):
+ return self.vvars_table.top_left_portstr()
+
+ def top_right_portstr(self):
+ return self.vvars_table.top_right_portstr()
class Head:
- def __init__(self, name, head_primitives, i):
- self.name = name
- self.i = i
- self.head_primitives = head_primitives
- select = self.head_primitives.select
- q_vars, k_vars = select.q_vars, select.k_vars
- q_vars = sorted(list(set(q_vars)), key=lambda a: a.creation_order_id)
- k_vars = sorted(list(set(k_vars)), key=lambda a: a.creation_order_id)
- self.kq_table = Table({QVAR: q_vars, KVAR: k_vars},
- self.name+"_qvars", rowtype_order=[KVAR, QVAR])
- # self.k_table = Table({KVAR:k_vars},self.name+"_kvars")
- self.select_result_table = makeQKStable(
- q_vars, k_vars, select, self.name+"_select")
- # self.select_table = SelectTable(self.head_primitives.select,
- # self.name+"_select")
- self.subheads = [SubHead(self.name+"_subcomp_"+str(i), seq)
- for i, seq in
- enumerate(self.head_primitives.sequences)]
-
- def add_to_graph(self, g):
- with g.subgraph(name=self.name) as head:
- def headlabel():
- # return self.head_primitives.select.name
- return 'head '+str(self.i) +\
- "\n("+self.head_primitives.select.name+")"
- head.attr(fillcolor=head_color, label=headlabel(),
- fontcolor='black', style='filled')
- with head.subgraph(name=self.name+"_select_parts") as sel:
- sel.attr(rankdir="LR", label="", style="invis", rank="same")
- if True: # not (self.kq_table.empty):
- self.select_result_table.add_to_graph(sel)
- self.kq_table.add_to_graph(sel)
- # sel.edge(self.kq_table.bottom_right_portstr(),
- # self.select_result_table.bottom_left_portstr(),style="invis")
-
- [s.add_to_graph(head) for s in self.subheads]
-
- def add_organising_edges(self, g):
- if self.kq_table.empty:
- return
- for s in self.subheads:
- place_above(g, self.select_result_table, s)
-
- def bottom_left_portstr(self):
- return self.subheads[0].bottom_left_portstr()
-
- def bottom_right_portstr(self):
- return self.subheads[-1].bottom_right_portstr()
-
- def top_left_portstr(self):
- if not (self.kq_table.empty):
- return self.kq_table.top_left_portstr()
- else: # no kq (and so no select either) table. go into subheads
- return self.subheads[0].top_left_portstr()
-
- def top_right_portstr(self):
- if not (self.kq_table.empty):
- return self.kq_table.top_right_portstr()
- else:
- return self.subheads[-1].top_right_portstr()
-
- def add_edges(self, g):
- select_vals = self.head_primitives.select.last_val()
- # connect(g,self.k_table,self.q_table,select_vals)
- for s in self.subheads:
- s.add_edges(g, select_vals)
- self.add_organising_edges(g)
+ def __init__(self, name, head_primitives, i):
+ self.name = name
+ self.i = i
+ self.head_primitives = head_primitives
+ select = self.head_primitives.select
+ q_vars, k_vars = select.q_vars, select.k_vars
+ q_vars = sorted(list(set(q_vars)), key=lambda a: a.creation_order_id)
+ k_vars = sorted(list(set(k_vars)), key=lambda a: a.creation_order_id)
+ self.kq_table = Table({QVAR: q_vars, KVAR: k_vars},
+ self.name+"_qvars", rowtype_order=[KVAR, QVAR])
+ # self.k_table = Table({KVAR:k_vars},self.name+"_kvars")
+ self.select_result_table = makeQKStable(
+ q_vars, k_vars, select, self.name+"_select")
+ # self.select_table = SelectTable(self.head_primitives.select,
+ # self.name+"_select")
+ self.subheads = [SubHead(self.name+"_subcomp_"+str(i), seq)
+ for i, seq in
+ enumerate(self.head_primitives.sequences)]
+
+ def add_to_graph(self, g):
+ with g.subgraph(name=self.name) as head:
+ def headlabel():
+ # return self.head_primitives.select.name
+ return 'head '+str(self.i) +\
+ "\n("+self.head_primitives.select.name+")"
+ head.attr(fillcolor=head_color, label=headlabel(),
+ fontcolor='black', style='filled')
+ with head.subgraph(name=self.name+"_select_parts") as sel:
+ sel.attr(rankdir="LR", label="", style="invis", rank="same")
+ if True: # not (self.kq_table.empty):
+ self.select_result_table.add_to_graph(sel)
+ self.kq_table.add_to_graph(sel)
+ # sel.edge(self.kq_table.bottom_right_portstr(),
+ # self.select_result_table.bottom_left_portstr(),style="invis")
+
+ [s.add_to_graph(head) for s in self.subheads]
+
+ def add_organising_edges(self, g):
+ if self.kq_table.empty:
+ return
+ for s in self.subheads:
+ place_above(g, self.select_result_table, s)
+
+ def bottom_left_portstr(self):
+ return self.subheads[0].bottom_left_portstr()
+
+ def bottom_right_portstr(self):
+ return self.subheads[-1].bottom_right_portstr()
+
+ def top_left_portstr(self):
+ if not (self.kq_table.empty):
+ return self.kq_table.top_left_portstr()
+ else: # no kq (and so no select either) table. go into subheads
+ return self.subheads[0].top_left_portstr()
+
+ def top_right_portstr(self):
+ if not (self.kq_table.empty):
+ return self.kq_table.top_right_portstr()
+ else:
+ return self.subheads[-1].top_right_portstr()
+
+ def add_edges(self, g):
+ select_vals = self.head_primitives.select.last_val()
+ # connect(g,self.k_table,self.q_table,select_vals)
+ for s in self.subheads:
+ s.add_edges(g, select_vals)
+ self.add_organising_edges(g)
def contains_tokens(mvs):
- return next((True for mv in mvs if guarded_contains(base_tokens, mv)),
- False)
+ return next((True for mv in mvs if guarded_contains(base_tokens, mv)),
+ False)
def just_base_sequence_fix(d_ffs, ff_parents):
- # when there are no parents and only one ff, then we are actually just
- # looking at the indices/tokens by themselves. in this case, putting that
- # ff in as a parent (with no child) makes the layer draw it properly
- if not ff_parents and len(d_ffs) == 1:
- return ff_parents, d_ffs
- return d_ffs, ff_parents
+ # when there are no parents and only one ff, then we are actually just
+ # looking at the indices/tokens by themselves. in this case, putting that
+ # ff in as a parent (with no child) makes the layer draw it properly
+ if not ff_parents and len(d_ffs) == 1:
+ return ff_parents, d_ffs
+ return d_ffs, ff_parents
class Layer:
- def __init__(self, depth, d_heads, d_ffs, add_tokens_on_ff=False):
- self.heads = []
- self.depth = depth
- self.name = self.layer_cluster_name(depth)
- for i, h in enumerate(d_heads):
- self.heads.append(Head(self.name+"_head"+str(i), h, i))
- ff_parents = []
- for ff in d_ffs:
- ff_parents += ff.get_nonminor_parent_sequences()
- ff_parents = list(set(ff_parents))
- ff_parents = [p for p in ff_parents if not guarded_contains(d_ffs, p)]
- d_ffs, ff_parents = just_base_sequence_fix(d_ffs, ff_parents)
- rows_by_type = {RES: d_ffs, VVAR: ff_parents}
- rowtype_order = [VVAR, RES]
- if add_tokens_on_ff and not contains_tokens(ff_parents):
- rows_by_type[INPUT] = [tokens_asis]
- rowtype_order = [INPUT] + rowtype_order
- self.ff_table = Table(rows_by_type, self.name+"_ffs", rowtype_order)
-
- def bottom_object(self):
- if not self.ff_table.empty:
- return self.ff_table
- else:
- return self.heads[-1]
-
- def top_object(self):
- if self.heads:
- return self.heads[0]
- else:
- return self.ff_table
-
- def bottom_left_portstr(self):
- return self.bottom_object().bottom_left_portstr()
-
- def bottom_right_portstr(self):
- return self.bottom_object().bottom_right_portstr()
-
- def top_left_portstr(self):
- return self.top_object().top_left_portstr()
-
- def top_right_portstr(self):
- return self.top_object().top_right_portstr()
-
- def add_to_graph(self, g):
- with g.subgraph(name=self.name) as lg:
- lg.attr(fillcolor=layer_color, label='layer '+str(self.depth),
- fontcolor='black', style='filled')
- for h in self.heads:
- h.add_to_graph(lg)
- self.ff_table.add_to_graph(lg)
-
- def add_organising_edges(self, g):
- if self.ff_table.empty:
- return
- for h in self.heads:
- place_above(g, h, self.ff_table)
-
- def add_edges(self, g):
- for h in self.heads:
- h.add_edges(g)
- self.add_organising_edges(g)
-
- def layer_cluster_name(self, depth):
- return 'cluster_l'+str(depth) # graphviz needs
- # cluster names to start with 'cluster'
+ def __init__(self, depth, d_heads, d_ffs, add_tokens_on_ff=False):
+ self.heads = []
+ self.depth = depth
+ self.name = self.layer_cluster_name(depth)
+ for i, h in enumerate(d_heads):
+ self.heads.append(Head(self.name+"_head"+str(i), h, i))
+ ff_parents = []
+ for ff in d_ffs:
+ ff_parents += ff.get_nonminor_parent_sequences()
+ ff_parents = list(set(ff_parents))
+ ff_parents = [p for p in ff_parents if not guarded_contains(d_ffs, p)]
+ d_ffs, ff_parents = just_base_sequence_fix(d_ffs, ff_parents)
+ rows_by_type = {RES: d_ffs, VVAR: ff_parents}
+ rowtype_order = [VVAR, RES]
+ if add_tokens_on_ff and not contains_tokens(ff_parents):
+ rows_by_type[INPUT] = [tokens_asis]
+ rowtype_order = [INPUT] + rowtype_order
+ self.ff_table = Table(rows_by_type, self.name+"_ffs", rowtype_order)
+
+ def bottom_object(self):
+ if not self.ff_table.empty:
+ return self.ff_table
+ else:
+ return self.heads[-1]
+
+ def top_object(self):
+ if self.heads:
+ return self.heads[0]
+ else:
+ return self.ff_table
+
+ def bottom_left_portstr(self):
+ return self.bottom_object().bottom_left_portstr()
+
+ def bottom_right_portstr(self):
+ return self.bottom_object().bottom_right_portstr()
+
+ def top_left_portstr(self):
+ return self.top_object().top_left_portstr()
+
+ def top_right_portstr(self):
+ return self.top_object().top_right_portstr()
+
+ def add_to_graph(self, g):
+ with g.subgraph(name=self.name) as lg:
+ lg.attr(fillcolor=layer_color, label='layer ' + str(self.depth),
+ fontcolor='black', style='filled')
+ for h in self.heads:
+ h.add_to_graph(lg)
+ self.ff_table.add_to_graph(lg)
+
+ def add_organising_edges(self, g):
+ if self.ff_table.empty:
+ return
+ for h in self.heads:
+ place_above(g, h, self.ff_table)
+
+ def add_edges(self, g):
+ for h in self.heads:
+ h.add_edges(g)
+ self.add_organising_edges(g)
+
+ def layer_cluster_name(self, depth):
+ return 'cluster_l'+str(depth) # graphviz needs
+ # cluster names to start with 'cluster'
class CompFlow:
- def __init__(self, all_heads, all_ffs, force_vertical_layers,
- add_tokens_on_ff=False):
- self.force_vertical_layers = force_vertical_layers
- self.add_tokens_on_ff = add_tokens_on_ff
- self.make_all_layers(all_heads, all_ffs)
-
- def make_all_layers(self, all_heads, all_ffs):
- self.layers = []
- ff_depths = [seq.scheduled_comp_depth for seq in all_ffs]
- head_depths = [h.comp_depth for h in all_heads]
- depths = sorted(list(set(ff_depths+head_depths)))
- for d in depths:
- d_heads = [h for h in all_heads if h.comp_depth == d]
- d_heads = sorted(d_heads, key=lambda h: h.select.creation_order_id)
- # only important for determinism to help debug
- d_ffs = [f for f in all_ffs if f.scheduled_comp_depth == d]
- self.layers.append(Layer(d, d_heads, d_ffs, self.add_tokens_on_ff))
-
- def add_all_layers(self, g):
- [layer.add_to_graph(g) for layer in self.layers]
-
- def add_organising_edges(self, g):
- if self.force_vertical_layers:
- for l1, l2 in zip(self.layers, self.layers[1:]):
- place_above(g, l1, l2)
-
- def add_edges(self, g):
- self.add_organising_edges(g)
- [layer.add_edges(g) for layer in self.layers]
+ def __init__(self, all_heads, all_ffs, force_vertical_layers,
+ add_tokens_on_ff=False):
+ self.force_vertical_layers = force_vertical_layers
+ self.add_tokens_on_ff = add_tokens_on_ff
+ self.make_all_layers(all_heads, all_ffs)
+
+ def make_all_layers(self, all_heads, all_ffs):
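+ # Groups heads and feedforwards by scheduled computation depth and builds
+ # one Layer per depth, in increasing order.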
+ self.layers = []
+ ff_depths = [seq.scheduled_comp_depth for seq in all_ffs]
+ head_depths = [h.comp_depth for h in all_heads]
+ depths = sorted(list(set(ff_depths+head_depths)))
+ for d in depths:
+ d_heads = [h for h in all_heads if h.comp_depth == d]
+ d_heads = sorted(d_heads, key=lambda h: h.select.creation_order_id)
+ # only important for determinism to help debug
+ d_ffs = [f for f in all_ffs if f.scheduled_comp_depth == d]
+ self.layers.append(Layer(d, d_heads, d_ffs, self.add_tokens_on_ff))
+
+ def add_all_layers(self, g):
+ [layer.add_to_graph(g) for layer in self.layers]
+
+ def add_organising_edges(self, g):
+ if self.force_vertical_layers:
+ for l1, l2 in zip(self.layers, self.layers[1:]):
+ place_above(g, l1, l2)
+
+ def add_edges(self, g):
+ self.add_organising_edges(g)
+ [layer.add_edges(g) for layer in self.layers]
@UnfinishedFunc
def draw_comp_flow(self, w, filename=None,
- keep_dot=False, show=True,
- force_vertical_layers=True, add_tokens_on_ff=False):
- if w is not None:
- self.call(w) # execute seq (and all its ancestors) on the given input
- if not self.last_w == w:
- print("evaluating input failed")
- return
- else: # if w == None, assume seq has already been executed on some input.
- w = self.last_w
- if None is filename:
- name = self.name
- filename = os.path.join("comp_flows", windows_path_cleaner(
- name+"("+(str(w) if not isinstance(w, str) else "\""+w+"\"")+")"))
- self.mark_all_minor_ancestors()
- self.make_display_names_for_all_parents(skip_minors=True)
-
- all_heads, all_ffs = self.get_all_ancestor_heads_and_ffs(
- remove_minors=True)
- # this scheduling also marks the analysis parent selects
- compflow = CompFlow(all_heads, all_ffs,
- force_vertical_layers=force_vertical_layers,
- add_tokens_on_ff=add_tokens_on_ff)
-
- # only import graphviz *inside* this function -
- # that way RASP can run even if graphviz setup fails
- # (though it will not be able to draw computation flows without it)
- from graphviz import Digraph
- g = Digraph('g')
- # with curved lines it fusses over separating score edges
- # and makes weirdly curved ones that start overlapping with the sequences
- # :(
- g.attr(splines='polyline')
- compflow.add_all_layers(g)
- compflow.add_edges(g)
- g.render(filename=filename)
- if show:
- g.view()
- if not keep_dot:
- os.remove(filename)
+ keep_dot=False, show=True,
+ force_vertical_layers=True, add_tokens_on_ff=False):
+ if w is not None:
+ self.call(w) # execute seq (and all its ancestors) on the given input
+ if not self.last_w == w:
+ print("evaluating input failed")
+ return
+ else: # if w == None, assume seq has already been executed on some input.
+ w = self.last_w
+ if None is filename:
+ name = self.name
+ filename = os.path.join("comp_flows", windows_path_cleaner(
+ name+"("+(str(w) if not isinstance(w, str) else "\""+w+"\"")+")"))
+ self.mark_all_minor_ancestors()
+ self.make_display_names_for_all_parents(skip_minors=True)
+
+ all_heads, all_ffs = self.get_all_ancestor_heads_and_ffs(
+ remove_minors=True)
+ # this scheduling also marks the analysis parent selects
+ compflow = CompFlow(all_heads, all_ffs,
+ force_vertical_layers=force_vertical_layers,
+ add_tokens_on_ff=add_tokens_on_ff)
+
+ # only import graphviz *inside* this function -
+ # that way RASP can run even if graphviz setup fails
+ # (though it will not be able to draw computation flows without it)
+ from graphviz import Digraph
+ g = Digraph('g')
+ # with curved lines it fusses over separating score edges
+ # and makes weirdly curved ones that start overlapping with the sequences
+ # :(
+ g.attr(splines='polyline')
+ compflow.add_all_layers(g)
+ compflow.add_edges(g)
+ g.render(filename=filename)
+ if show:
+ g.view()
+ if not keep_dot:
+ os.remove(filename)
dummyimport = None
diff --git a/RASP_support/Environment.py b/RASP_support/Environment.py
index 0dddc1e..e8896a8 100644
--- a/RASP_support/Environment.py
+++ b/RASP_support/Environment.py
@@ -1,97 +1,97 @@
from .FunctionalSupport import Unfinished, RASPTypeError, tokens_asis, \
- tokens_str, tokens_int, tokens_bool, tokens_float, indices
+ tokens_str, tokens_int, tokens_bool, tokens_float, indices
from .Evaluator import RASPFunction
class UndefinedVariable(Exception):
- def __init__(self, varname):
- super().__init__("Error: Undefined variable: "+varname)
+ def __init__(self, varname):
+ super().__init__("Error: Undefined variable: "+varname)
class ReservedName(Exception):
- def __init__(self, varname):
- super().__init__("Error: Cannot set reserved name: "+varname)
+ def __init__(self, varname):
+ super().__init__("Error: Cannot set reserved name: "+varname)
class Environment:
- def __init__(self, parent_env=None, name=None, stealing_env=None):
- self.variables = {}
- self.name = name
- self.parent_env = parent_env
- self.stealing_env = stealing_env
- self.base_setup() # nested envs can have them too. makes life simpler,
- # instead of checking if they have the constant_variables etc in get.
- # bit heavier on memory but no one's going to use this language for big
- # nested stuff anyway
- self.storing_in_constants = False
+ def __init__(self, parent_env=None, name=None, stealing_env=None):
+ self.variables = {}
+ self.name = name
+ self.parent_env = parent_env
+ self.stealing_env = stealing_env
+ self.base_setup() # nested envs can have them too. makes life simpler,
+ # instead of checking if they have the constant_variables etc in get.
+ # bit heavier on memory but no one's going to use this language for big
+ # nested stuff anyway
+ self.storing_in_constants = False
- def base_setup(self):
- self.constant_variables = {"tokens_asis": tokens_asis,
- "tokens_str": tokens_str,
- "tokens_int": tokens_int,
- "tokens_bool": tokens_bool,
- "tokens_float": tokens_float,
- "indices": indices,
- "True": True,
- "False": False}
- self.reserved_words = ["if", "else", "not", "and", "or", "out", "def",
- "return", "range", "for", "in", "zip", "len",
- "get"] + list(self.constant_variables.keys())
+ def base_setup(self):
+ self.constant_variables = {"tokens_asis": tokens_asis,
+ "tokens_str": tokens_str,
+ "tokens_int": tokens_int,
+ "tokens_bool": tokens_bool,
+ "tokens_float": tokens_float,
+ "indices": indices,
+ "True": True,
+ "False": False}
+ self.reserved_words = ["if", "else", "not", "and", "or", "out", "def",
+ "return", "range", "for", "in", "zip", "len",
+ "get"] + list(self.constant_variables.keys())
- def snapshot(self):
- res = Environment(parent_env=self.parent_env,
- name=self.name, stealing_env=self.stealing_env)
+ def snapshot(self):
+ res = Environment(parent_env=self.parent_env,
+ name=self.name, stealing_env=self.stealing_env)
- def carefulcopy(val):
- if isinstance(val, Unfinished) or isinstance(val, RASPFunction):
- return val # non mutable, at least not through rasp commands
- elif isinstance(val, float) or isinstance(val, int) \
- or isinstance(val, str) or isinstance(val, bool):
- return val # non mutable
- elif isinstance(val, list):
- return [carefulcopy(v) for v in val]
- else:
- raise RASPTypeError("environment contains element that is not "
- + "unfinished, rasp function, float, int,"
- + "string, bool, or list? :", val)
- res.constant_variables = {d: carefulcopy(
- self.constant_variables[d]) for d in self.constant_variables}
- res.variables = {d: carefulcopy(
- self.variables[d]) for d in self.variables}
- return res
+ def carefulcopy(val):
+ if isinstance(val, Unfinished) or isinstance(val, RASPFunction):
+ return val # non mutable, at least not through rasp commands
+ elif (isinstance(val, float) or isinstance(val, int) or
+ isinstance(val, str) or isinstance(val, bool)):
+ return val # non mutable
+ elif isinstance(val, list):
+ return [carefulcopy(v) for v in val]
+ else:
+ raise RASPTypeError("environment contains element that is " +
+ "not unfinished, rasp function, float, " +
+ "int, string, bool, or list? :", val)
+ res.constant_variables = {d: carefulcopy(
+ self.constant_variables[d]) for d in self.constant_variables}
+ res.variables = {d: carefulcopy(
+ self.variables[d]) for d in self.variables}
+ return res
- def make_nested(self, names_vars=[]):
- res = Environment(self, name=str(self.name)+"'")
- for n, v in names_vars:
- res.set_variable(n, v)
- return res
+ def make_nested(self, names_vars=[]):
+ res = Environment(self, name=str(self.name) + "'")
+ for n, v in names_vars:
+ res.set_variable(n, v)
+ return res
- def get_variable(self, name):
- if name in self.constant_variables:
- return self.constant_variables[name]
- if name in self.variables:
- return self.variables[name]
- if self.parent_env is not None:
- return self.parent_env.get_variable(name)
- raise UndefinedVariable(name)
+ def get_variable(self, name):
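+ # Lookup order: constants, then this environment's variables, then parents.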
+ if name in self.constant_variables:
+ return self.constant_variables[name]
+ if name in self.variables:
+ return self.variables[name]
+ if self.parent_env is not None:
+ return self.parent_env.get_variable(name)
+ raise UndefinedVariable(name)
- def _set_checked_variable(self, name, val):
- if self.storing_in_constants:
- self.constant_variables[name] = val
- self.reserved_words.append(name)
- else:
- self.variables[name] = val
+ def _set_checked_variable(self, name, val):
+ if self.storing_in_constants:
+ self.constant_variables[name] = val
+ self.reserved_words.append(name)
+ else:
+ self.variables[name] = val
- def set_variable(self, name, val):
- if name in self.reserved_words:
- raise ReservedName(name)
+ def set_variable(self, name, val):
+ if name in self.reserved_words:
+ raise ReservedName(name)
- self._set_checked_variable(name, val)
- if self.stealing_env is not None:
- if name.startswith("_") or name == "out": # things we don't want
- # to steal
- return
- self.stealing_env.set_variable(name, val)
+ self._set_checked_variable(name, val)
+ if self.stealing_env is not None:
+ if name.startswith("_") or name == "out": # things we don't want
+ # to steal
+ return
+ self.stealing_env.set_variable(name, val)
- def set_out(self, val):
- self.variables["out"] = val
+ def set_out(self, val):
+ self.variables["out"] = val
diff --git a/RASP_support/Evaluator.py b/RASP_support/Evaluator.py
index e92f0e2..d30f6d2 100644
--- a/RASP_support/Evaluator.py
+++ b/RASP_support/Evaluator.py
@@ -1,6 +1,6 @@
from .FunctionalSupport import select, zipmap, aggregate, \
- or_selects, and_selects, not_select, indices, \
- Unfinished, UnfinishedSequence, UnfinishedSelect
+ or_selects, and_selects, not_select, indices, \
+ Unfinished, UnfinishedSequence, UnfinishedSelect
from .Sugar import tplor, tpland, tplnot, toseq, full_s
from .Support import RASPTypeError, RASPError
from collections.abc import Iterable
@@ -11,761 +11,766 @@
def strdesc(o, desc_cap=None):
- if isinstance(o, Unfinished):
- return o.name
- if isinstance(o, list):
- res = "["+", ".join([strdesc(v) for v in o])+"]"
- if desc_cap is not None and len(res) > desc_cap:
- return "(list)"
- else:
- return res
- if isinstance(o, dict):
- res = "{"+", ".join((strdesc(k)+": "+strdesc(o[k])) for k in o)+"}"
- if desc_cap is not None and len(res) > desc_cap:
- return "(dict)"
- else:
- return res
- else:
- if isinstance(o, str):
- return "\""+o+"\""
- else:
- return str(o)
+ if isinstance(o, Unfinished):
+ return o.name
+ if isinstance(o, list):
+ res = "[" + ", ".join([strdesc(v) for v in o]) + "]"
+ if desc_cap is not None and len(res) > desc_cap:
+ return "(list)"
+ else:
+ return res
+ if isinstance(o, dict):
+ res = "{" + \
+ ", ".join((strdesc(k) + ": " + strdesc(o[k])) for k in o) + "}"
+ if desc_cap is not None and len(res) > desc_cap:
+ return "(dict)"
+ else:
+ return res
+ else:
+ if isinstance(o, str):
+ return "\"" + o + "\""
+ else:
+ return str(o)
class RASPValueError(RASPError):
- def __init__(self, *a):
- super().__init__(*a)
+ def __init__(self, *a):
+ super().__init__(*a)
DEBUG = False
def debprint(*a, **kw):
- if DEBUG:
- print(*a, **kw)
+ if DEBUG:
+ print(*a, **kw)
def ast_text(ast): # just so don't have to go remembering this somewhere
- # consider seeing if can make it add spaces between the tokens when doing
- # this tho
- return ast.getText()
+ # consider seeing if can make it add spaces between the tokens when doing
+ # this tho
+ return ast.getText()
def isatom(v):
- # the legal atoms
- return True in [isinstance(v, t) for t in [int, float, str, bool]]
+ # the legal atoms
+ return True in [isinstance(v, t) for t in [int, float, str, bool]]
def name_general_type(v):
- if isinstance(v, list):
- return "list"
- if isinstance(v, dict):
- return "dict"
- if isinstance(v, UnfinishedSequence):
- return ENCODER_NAME
- if isinstance(v, UnfinishedSelect):
- return "selector"
- if isinstance(v, RASPFunction):
- return "function"
- if isatom(v):
- return "atom"
- return "??"
+ if isinstance(v, list):
+ return "list"
+ if isinstance(v, dict):
+ return "dict"
+ if isinstance(v, UnfinishedSequence):
+ return ENCODER_NAME
+ if isinstance(v, UnfinishedSelect):
+ return "selector"
+ if isinstance(v, RASPFunction):
+ return "function"
+ if isatom(v):
+ return "atom"
+ return "??"
class ArgsError(Exception):
- def __init__(self, name, expected, got):
- super().__init__("wrong number of args for "+name +
- "- expected: "+str(expected)+", got: "+str(got)+".")
+ def __init__(self, name, expected, got):
+ super().__init__("wrong number of args for " + name +
+ "- expected: " + str(expected) + ", got: " +
+ str(got) + ".")
class NamedVal:
- def __init__(self, name, val):
- self.name = name
- self.val = val
+ def __init__(self, name, val):
+ self.name = name
+ self.val = val
class NamedValList:
- def __init__(self, namedvals):
- self.nvs = namedvals
+ def __init__(self, namedvals):
+ self.nvs = namedvals
class JustVal:
- def __init__(self, val):
- self.val = val
+ def __init__(self, val):
+ self.val = val
class RASPFunction:
- def __init__(self, name, enclosing_env, argnames, statement_trees,
- returnexpr, creator_name):
- self.name = name # just for debug purposes
- self.enclosing_env = enclosing_env
- self.argnames = argnames
- self.statement_trees = statement_trees
- self.returnexpr = returnexpr
- self.creator = creator_name
-
- def __str__(self):
- return self.creator + " function: " + self.name \
- + "(" + ", ".join(self.argnames) + ")"
-
- def call(self, *args):
- top_eval = args[-1]
- args = args[:-1]
- # nesting, because function shouldn't affect the enclosing environment
- env = self.enclosing_env.make_nested([])
- if not len(args) == len(self.argnames):
- raise ArgsError(self.name, len(self.argnames), len(args))
- for n, v in zip(self.argnames, args):
- env.set_variable(n, v)
- evaluator = Evaluator(env, top_eval.repl)
- for at in self.statement_trees:
- evaluator.evaluate(at)
- res = evaluator.evaluateExprsList(self.returnexpr)
- return res[0] if len(res) == 1 else res
+ def __init__(self, name, enclosing_env, argnames, statement_trees,
+ returnexpr, creator_name):
+ self.name = name # just for debug purposes
+ self.enclosing_env = enclosing_env
+ self.argnames = argnames
+ self.statement_trees = statement_trees
+ self.returnexpr = returnexpr
+ self.creator = creator_name
+
+ def __str__(self):
+ return self.creator + " function: " + self.name \
+ + "(" + ", ".join(self.argnames) + ")"
+
+ def call(self, *args):
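+ # The final positional argument is the calling Evaluator; the body runs in
+ # a fresh environment nested inside the one the function was defined in.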
+ top_eval = args[-1]
+ args = args[:-1]
+ # nesting, because function shouldn't affect the enclosing environment
+ env = self.enclosing_env.make_nested([])
+ if not len(args) == len(self.argnames):
+ raise ArgsError(self.name, len(self.argnames), len(args))
+ for n, v in zip(self.argnames, args):
+ env.set_variable(n, v)
+ evaluator = Evaluator(env, top_eval.repl)
+ for at in self.statement_trees:
+ evaluator.evaluate(at)
+ res = evaluator.evaluateExprsList(self.returnexpr)
+ return res[0] if len(res) == 1 else res
class Evaluator:
- def __init__(self, env, repl):
- self.env = env
- self.sequence_running_example = repl.sequence_running_example
- self.backup_example = None
- # allows evaluating something that maybe doesn't necessarily work with
- # the main running example, but we just want to see what happens on
- # it - e.g. so we can do draw(tokens_int+1,[1,2]) without error even
- # while the main example is still "hello"
- self.repl = repl
-
- def evaluate(self, ast):
- if ast.expr():
- return self.evaluateExpr(ast.expr(), from_top=True)
- if ast.assign():
- return self.assign(ast.assign())
- if ast.funcDef():
- return self.funcDef(ast.funcDef())
- if ast.draw():
- return self.draw(ast.draw())
- if ast.forLoop():
- return self.forLoop(ast.forLoop())
- if ast.loadFile():
- return self.repl.loadFile(ast.loadFile(), self.env)
-
- # more to come
- raise NotImplementedError
-
- def draw(self, ast):
- # TODO: make at least some rudimentary comparisons of selectors somehow
- # to merge heads idk?????? maybe keep trace of operations used to
- # create them and those with exact same parent s-ops and operations
- # can get in? would still find eg select(0,0,==) and select(1,1,==)
- # different, but its better than nothing at all
- example = self.evaluateExpr(
- ast.inputseq) if ast.inputseq else self.sequence_running_example
- prev_backup = self.backup_example
- self.backup_example = example
- unf = self.evaluateExpr(ast.unf)
- if not isinstance(unf, UnfinishedSequence):
- raise RASPTypeError("draw expects unfinished sequence, got:", unf)
- unf.draw_comp_flow(example)
- res = unf.call(example)
- res.created_from_input = example
- self.backup_example = prev_backup
- return JustVal(res)
-
- def assign(self, ast):
- def set_val_and_name(val, name):
- self.env.set_variable(name, val)
- if isinstance(val, Unfinished):
- val.setname(name) # completely irrelevant really for the REPL,
- # but will help maintain sanity when printing computation flows
- return NamedVal(name, val)
-
- varnames = self._names_list(ast.var)
- values = self.evaluateExprsList(ast.val)
- if len(values) == 1:
- values = values[0]
-
- if len(varnames) == 1:
- return set_val_and_name(values, varnames[0])
- else:
- if not len(varnames) == len(values):
- raise RASPTypeError("expected", len(
- varnames), "values, but got:", len(values))
- reslist = []
- for v, name in zip(values, varnames):
- reslist.append(set_val_and_name(v, name))
- return NamedValList(reslist)
-
- def _names_list(self, ast):
- idsList = self._get_first_cont_list(ast)
- return [i.text for i in idsList]
-
- def _set_iterator_and_vals(self, iterator_names, iterator_vals):
- if len(iterator_names) == 1:
- self.env.set_variable(iterator_names[0], iterator_vals)
- elif isinstance(iterator_vals, Iterable) \
- and (len(iterator_vals) == len(iterator_names)):
- for n, v in zip(iterator_names, iterator_vals):
- self.env.set_variable(n, v)
- else:
- if not isinstance(iterator_vals, Iterable):
- raise RASPTypeError(
- "iterating with multiple iterator names, but got single"
- + " iterator value:", iterator_vals)
- else:
- # should work out by logic of last failed elif
- errnote = "something wrong with Evaluator logic"
- assert not (len(iterator_vals) == len(iterator_names)), errnote
- raise RASPTypeError("iterating with", len(iterator_names),
- "names but got", len(iterator_vals),
- "values (", iterator_vals, ")")
-
- def _evaluateDictComp(self, ast):
- ast = ast.dictcomp
- d = self.evaluateExpr(ast.iterable)
- if not (isinstance(d, list) or isinstance(d, dict)):
- raise RASPTypeError(
- "dict comprehension should have got a list or dict to loop "
- + "over, but got:", d)
- res = {}
- iterator_names = self._names_list(ast.iterator)
- for vals in d:
- orig_env = self.env
- self.env = self.env.make_nested()
- self._set_iterator_and_vals(iterator_names, vals)
- key = self.make_dict_key(ast.key)
- res[key] = self.evaluateExpr(ast.val)
- self.env = orig_env
- return res
-
- def _evaluateListComp(self, ast):
- ast = ast.listcomp
- ll = self.evaluateExpr(ast.iterable)
- if not (isinstance(ll, list) or isinstance(ll, dict)):
- raise RASPTypeError(
- "list comprehension should have got a list or dict to loop "
- + "over, but got:", ll)
- res = []
- iterator_names = self._names_list(ast.iterator)
- for vals in ll:
- orig_env = self.env
- self.env = self.env.make_nested()
- # sets inside the now-nested env -don't want to keep
- # the internal iterators after finishing this list comp
- self._set_iterator_and_vals(iterator_names, vals)
- res.append(self.evaluateExpr(ast.val))
- self.env = orig_env
- return res
-
- def forLoop(self, ast):
- iterator_names = self._names_list(ast.iterator)
- iterable = self.evaluateExpr(ast.iterable)
- if not (isinstance(iterable, list) or isinstance(iterable, dict)):
- raise RASPTypeError(
- "for loop needs to iterate over a list or dict, but got:",
- iterable)
- statements = self._get_first_cont_list(ast.mainbody)
- for vals in iterable:
- self._set_iterator_and_vals(iterator_names, vals)
- for s in statements:
- self.evaluate(s)
- return JustVal(None)
-
- def _get_first_cont_list(self, ast):
- res = []
- while ast:
- if ast.first:
- res.append(ast.first)
- # sometimes there's no first cause it's just eating a comment
- ast = ast.cont
- return res
-
- def funcDef(self, ast):
- funcname = ast.name.text
- argname_trees = self._get_first_cont_list(ast.arguments)
- argnames = [a.text for a in argname_trees]
- statement_trees = self._get_first_cont_list(ast.mainbody)
- returnexpr = ast.retstatement.res
- res = RASPFunction(funcname, self.env, argnames,
- statement_trees, returnexpr, self.env.name)
- self.env.set_variable(funcname, res)
- return NamedVal(funcname, res)
-
- def _evaluateUnaryExpr(self, ast):
- uexpr = self.evaluateExpr(ast.uexpr)
- uop = ast.uop.text
- if uop == "not":
- if isinstance(uexpr, UnfinishedSequence):
- return tplnot(uexpr)
- elif isinstance(uexpr, UnfinishedSelect):
- return not_select(uexpr)
- else:
- return not uexpr
- if uop == "-":
- return -uexpr
- if uop == "+":
- return +uexpr
- if uop == "round":
- return round(uexpr)
- if uop == "indicator":
- if isinstance(uexpr, UnfinishedSequence):
- name = "I("+uexpr.name+")"
- zipmapped = zipmap(uexpr, lambda a: 1 if a else 0, name=name)
- return zipmapped.allow_suppressing_display()
- # naming res makes interpreter think it is important, i.e.,
- # must always be displayed. but here it has only been named for
- # clarity, so correct it using .allow_suppressing_display()
-
- raise RASPTypeError(
- "indicator operator expects "+ENCODER_NAME+", got:", uexpr)
- raise NotImplementedError
-
- def _evaluateRange(self, ast):
- valsList = self.evaluateExprsList(ast.rangevals)
- if not len(valsList) in [1, 2, 3]:
- raise RASPTypeError(
- "wrong number of inputs to range, expected: 1, 2, or 3, got:",
- len(valsList))
- for v in valsList:
- if not isinstance(v, int):
- raise RASPTypeError(
- "range expects all integer inputs, but got:",
- strdesc(valsList))
- return list(range(*valsList))
-
- def _index_into_dict(self, d, index):
- def invalid_key_error(i):
- return RASPTypeError(
- f"index into dict has to be {ENCODER_NAME} or atom"
- + " (i.e., string, int, float, bool), got:", strdesc(i))
-
- def missing_key_error(i):
- return RASPValueError("index [", strdesc(i), "] not in dict.")
-
- dname, indexname = d.name, index.name
- d, index = d.val, index.val
-
- if isinstance(index, UnfinishedSequence):
- d = deepcopy(d)
- def apply_d(i):
- if i not in d:
- raise missing_key_error(i)
- return d[i]
- name = f"{dname}[{indexname}]"
- return zipmap((index,), apply_d, name=name)
- elif not isatom(index):
- raise invalid_key_error(index)
- if index not in d:
- raise missing_key_error(index)
- else:
- return d[index]
-
- def _index_into_list_or_str(self, ll, index):
- lname, indexname = ll.name, index.name
- ll, index = ll.val, index.val
- ltype = "list" if isinstance(ll, list) else "string"
-
- def invalid_key_error(i):
- return RASPTypeError(f"index into {ltype} has to be",
- f"{ENCODER_NAME} or integer, got:",
- strdesc(index))
-
- def check_and_raise_key_error(i):
- if i >= len(ll) or (-i) > len(ll):
- raise RASPValueError("index", index, "out of range for", ltype,
- "of length", len(ll))
-
- if isinstance(index, UnfinishedSequence):
- ll = deepcopy(ll)
- def apply_l(i):
- check_and_raise_key_error(i)
- return ll[i]
- name = f"{lname}[{indexname}]"
- return zipmap((index,), apply_l, name=name)
- elif not isinstance(index, int):
- raise invalid_key_error(index)
- check_and_raise_key_error(index)
- return ll[index]
-
- def _index_into_sequence(self, s, index):
- sname, indexname = s.name, index.name
- s, index = s.val, index.val
- if isinstance(index, int):
- if index >= 0:
- sel = select(toseq(index), indices, lambda q,
- k: q == k, name="load from "+str(index))
- else:
- length = self.env.get_variable("length")
- real_index = length + index
- real_index.setname(length.name+str(index))
- sel = select(real_index, indices, lambda q,
- k: q == k, name="load from "+str(index))
- agg = aggregate(sel, s, name=s.name+"["+str(index)+"]")
- return agg.allow_suppressing_display()
- else:
- raise RASPValueError(
- "index into sequence has to be integer, got:", strdesc(index))
-
- def _evaluateIndexing(self, ast):
- indexable = self.evaluateExpr(ast.indexable, get_name=True)
- index = self.evaluateExpr(ast.index, get_name=True)
-
-
- if isinstance(indexable.val, list) or isinstance(indexable.val, str):
- return self._index_into_list_or_str(indexable, index)
- elif isinstance(indexable.val, dict):
- return self._index_into_dict(indexable, index)
- elif isinstance(indexable.val, UnfinishedSequence):
- return self._index_into_sequence(indexable, index)
- else:
- raise RASPTypeError("can only index into a list, dict, string, or"
- + " sequence, but instead got:",
- strdesc(indexable.val))
-
- def _evaluateSelectExpr(self, ast):
- key = self.evaluateExpr(ast.key)
- query = self.evaluateExpr(ast.query)
- sop = ast.selop.text
- key = toseq(key) # in case got an atom in one of these,
- query = toseq(query) # e.g. selecting 0th index: indices @= 0
- if sop == "<":
- return select(query, key, lambda q, k: q > k)
- if sop == ">":
- return select(query, key, lambda q, k: q < k)
- if sop == "==":
- return select(query, key, lambda q, k: q == k)
- if sop == "!=":
- return select(query, key, lambda q, k: not (q == k))
- if sop == "<=":
- return select(query, key, lambda q, k: q >= k)
- if sop == ">=":
- return select(query, key, lambda q, k: q <= k)
-
- def _evaluateBinaryExpr(self, ast):
- def has_sequence(left, right):
- return isinstance(left, UnfinishedSequence) \
- or isinstance(right, UnfinishedSequence)
-
- def has_selector(left, right):
- return isinstance(left, UnfinishedSelect) \
- or isinstance(right, UnfinishedSelect)
-
- def both_selectors(left, right):
- return isinstance(left, UnfinishedSelect) \
- and isinstance(right, UnfinishedSelect)
- left = self.evaluateExpr(ast.left)
- right = self.evaluateExpr(ast.right)
- bop = ast.bop.text
- bad_pair = RASPTypeError(
- "Cannot apply and/or between selector and non-selector")
- if bop == "and":
- if has_sequence(left, right):
- if has_selector(left, right):
- raise bad_pair
- return tpland(left, right)
- elif has_selector(left, right):
- if not both_selectors(left, right):
- raise bad_pair
- return and_selects(left, right)
- else:
- return (left and right)
- elif bop == "or":
- if has_sequence(left, right):
- if has_selector(left, right):
- raise bad_pair
- return tplor(left, right)
- elif has_selector(left, right):
- if not both_selectors(left, right):
- raise bad_pair
- return or_selects(left, right)
- else:
- return (left or right)
- if has_selector(left, right):
- raise RASPTypeError("Cannot apply", bop, "to selector(s)")
- elif bop == "+":
- return left + right
- elif bop == "-":
- return left - right
- elif bop == "*":
- return left * right
- elif bop == "/":
- return left/right
- elif bop == "^":
- return pow(left, right)
- elif bop == '%':
- return left % right
- elif bop == "==":
- return left == right
- elif bop == "<=":
- return left <= right
- elif bop == ">=":
- return left >= right
- elif bop == "<":
- return left < right
- elif bop == ">":
- return left > right
- # more, like modulo and power and all the other operators, to come
- raise NotImplementedError
-
- def _evaluateStandalone(self, ast):
- if ast.anint:
- return int(ast.anint.text)
- if ast.afloat:
- return float(ast.afloat.text)
- if ast.astring:
- return ast.astring.text[1:-1]
- raise NotImplementedError
-
- def _evaluateTernaryExpr(self, ast):
- cond = self.evaluateExpr(ast.cond)
- if isinstance(cond, Unfinished):
- res1 = self.evaluateExpr(ast.res1)
- res2 = self.evaluateExpr(ast.res2)
- cond, res1, res2 = tuple(map(toseq, (cond, res1, res2)))
- return zipmap((cond, res1, res2), lambda c, r1, r2: r1
- if c else r2, name=res1.name+" if "+cond.name
- + " else " + res2.name).allow_suppressing_display()
- else:
- return self.evaluateExpr(ast.res1) if cond \
- else self.evaluateExpr(ast.res2)
- # lazy eval when cond is non-unfinished allows legal loops over
- # actual atoms
-
- def _evaluateAggregateExpr(self, ast):
- sel = self.evaluateExpr(ast.sel)
- seq = self.evaluateExpr(ast.seq)
- seq = toseq(seq) # just in case its an atom
- default = self.evaluateExpr(ast.default) if ast.default else None
-
- if not isinstance(sel, UnfinishedSelect):
- raise RASPTypeError("Expected selector, got:", strdesc(sel))
- if not isinstance(seq, UnfinishedSequence):
- raise RASPTypeError("Expected sequence, got:", strdesc(seq))
- if isinstance(default, Unfinished):
- raise RASPTypeError("Expected atom, got:", strdesc(default))
- return aggregate(sel, seq, default=default)
-
- def _evaluateZip(self, ast):
- list_exps = self._get_first_cont_list(ast.lists)
- lists = [self.evaluateExpr(e) for e in list_exps]
- if not lists:
- raise RASPTypeError("zip needs at least one list")
- for i, l in enumerate(lists):
- if not isinstance(l, list):
- raise RASPTypeError(
- "attempting to zip lists, but", i+1,
- "-th element is not list:", strdesc(l))
- n = len(lists[0])
- for i, l in enumerate(lists):
- if not len(l) == n:
- raise RASPTypeError("attempting to zip lists of length",
- n, ", but", i+1, "-th list has length",
- len(l))
- # keep everything lists, no tuples/lists mixing here, all the same to
- # rasp (no stuff like append etc)
- return [list(v) for v in zip(*lists)]
-
- def make_dict_key(self, ast):
- res = self.evaluateExpr(ast)
- if not isatom(res):
- raise RASPTypeError(
- "dictionary keys can only be atoms, but instead got:",
- strdesc(res))
- return res
-
- def _evaluateDict(self, ast):
- named_exprs_list = self._get_first_cont_list(ast.dictContents)
- return {self.make_dict_key(e.key): self.evaluateExpr(e.val)
- for e in named_exprs_list}
-
- def _evaluateList(self, ast):
- exprs_list = self._get_first_cont_list(ast.listContents)
- return [self.evaluateExpr(e) for e in exprs_list]
-
- def _evaluateApplication(self, ast, unf):
- input_vals = self._get_first_cont_list(ast.inputexprs)
- if not len(input_vals) == 1:
- raise ArgsError("evaluate unfinished", 1, len(input_vals))
- input_val = self.evaluateExpr(input_vals[0])
- if not isinstance(unf, Unfinished):
- raise RASPTypeError("Applying unfinished expects to apply",
- ENCODER_NAME, "or selector, got:",
- strdesc(unf))
- if not isinstance(input_val, Iterable):
- raise RASPTypeError(
- "Applying unfinished expects iterable input, got:",
- strdesc(input_val))
- res = unf.call(input_val)
- res.created_from_input = input_val
- return res
-
- def _evaluateRASPFunction(self, ast, raspfun):
- args_trees = self._get_first_cont_list(ast.inputexprs)
- args = tuple(self.evaluateExpr(t) for t in args_trees) + (self,)
- real_args = args[:-1]
- res = raspfun.call(*args)
- if isinstance(res, Unfinished):
- res.setname(
- raspfun.name+"("+" , ".join(strdesc(a, desc_cap=20)
- for a in real_args)+")")
- return res
-
- def _evaluateContains(self, ast):
- contained = self.evaluateExpr(ast.contained)
- container = self.evaluateExpr(ast.container)
- container_name = ast.container.var.text if ast.container.var \
- else str(container)
- if isinstance(contained, UnfinishedSequence):
- if not isinstance(container, list):
- raise RASPTypeError("\"["+ENCODER_NAME+"] in X\" expects X to"
- + "be list of atoms, but got non-list:",
- strdesc(container))
- for v in container:
- if not isatom(v):
- raise RASPTypeError("\"["+ENCODER_NAME+"] in X\" expects X"
- + "to be list of atoms, but got list "
- + "with values:", strdesc(container))
- return zipmap(contained, lambda c: c in container,
- name=contained.name + " in "
- + container_name).allow_suppressing_display()
- elif isatom(contained): # contained is now an atom
- if isinstance(container, list):
- return contained in container
- elif isinstance(container, UnfinishedSequence):
- indicator = zipmap(container, lambda v: int(v == contained))
- return aggregate(full_s, indicator) > 0
- else:
- raise RASPTypeError(
- "\"[atom] in X\" expects X to be list or " + ENCODER_NAME
- + ", but got:", strdesc(container))
- if isinstance(contained, UnfinishedSelect) or isinstance(contained,
- RASPFunction):
- obj_name = "select" if isinstance(
- contained, UnfinishedSelect) else "function"
- raise RASPTypeError("don't check if", obj_name,
- "is contained in list/dict: unless exact same "
- + "instance, unable to check equivalence of",
- obj_name + "s")
- else:
- raise RASPTypeError("\"A in X\" expects A to be",
- ENCODER_NAME, "or atom, but got A:",
- strdesc(contained))
-
- def _evaluateLen(self, ast):
- singleList = self.evaluateExpr(ast.singleList)
- if not isinstance(singleList, list) or isinstance(singleList, dict):
- raise RASPTypeError(
- "attempting to compute length of non-list:",
- strdesc(singleList))
- return len(singleList)
-
- def evaluateExprsList(self, ast):
- exprsList = self._get_first_cont_list(ast)
- return [self.evaluateExpr(v) for v in exprsList]
-
- def _test_res(self, res):
- if isinstance(res, Unfinished):
- def succeeds_with(exampe):
- try:
- res.call(example, just_pass_exception_up=True)
- except Exception:
- return False
- else:
- return True
- succeeds_with_backup = (self.backup_example is not None) and \
- succeeds_with(self.backup_example)
- if succeeds_with_backup:
- return
- succeeds_with_main = succeeds_with(self.sequence_running_example)
- if succeeds_with_main:
- return
- example = self.sequence_running_example if self.backup_example \
- is None else self.backup_example
- res.call(example, just_pass_exception_up=True)
-
- def evaluateExpr(self, ast, from_top=False, get_name=False):
- def format_return(res, resname="out",
- is_application_of_unfinished=False):
- ast.evaled_value = res
- # run a quick test of the result (by attempting to evaluate it on
- # an example) to make sure there hasn't been some weird type
- # problem, so it shouts even before someone actively tries to
- # evaluate it
- self._test_res(res)
-
- if is_application_of_unfinished:
- return JustVal(res)
- else:
- self.env.set_out(res)
- if from_top or get_name:
- # this is when an expression has been evaled
- return NamedVal(resname, res)
- else:
- return res
- if ast.bracketed: # in parentheses - get out of them
- return self.evaluateExpr(ast.bracketed, from_top=from_top)
- if ast.var: # calling single variable
- varname = ast.var.text
- return format_return(self.env.get_variable(varname),
- resname=varname)
- if ast.standalone:
- return format_return(self._evaluateStandalone(ast.standalone))
- if ast.bop:
- return format_return(self._evaluateBinaryExpr(ast))
- if ast.uop:
- return format_return(self._evaluateUnaryExpr(ast))
- if ast.cond:
- return format_return(self._evaluateTernaryExpr(ast))
- if ast.aggregate:
- return format_return(self._evaluateAggregateExpr(ast.aggregate))
- if ast.unfORfun:
-
- # before evaluating the unfORfun expression,
- # consider that it may be an unf that would not work
- # with the current running example, and allow that it may have
- # been sent in with an example for which it will work
- prev_backup = self.backup_example
- input_vals = self._get_first_cont_list(ast.inputexprs)
- if len(input_vals) == 1:
- self.backup_example = self.evaluateExpr(input_vals[0])
-
- unfORfun = self.evaluateExpr(ast.unfORfun)
-
- self.backup_example = prev_backup
-
- if isinstance(unfORfun, Unfinished):
- return format_return(self._evaluateApplication(ast, unfORfun),
- is_application_of_unfinished=True)
- elif isinstance(unfORfun, RASPFunction):
- return format_return(self._evaluateRASPFunction(ast, unfORfun))
- if ast.selop:
- return format_return(self._evaluateSelectExpr(ast))
- if ast.aList():
- return format_return(self._evaluateList(ast.aList()))
- if ast.aDict():
- return format_return(self._evaluateDict(ast.aDict()))
- if ast.indexable: # indexing into a list, dict, or s-op
- return format_return(self._evaluateIndexing(ast))
- if ast.rangevals:
- return format_return(self._evaluateRange(ast))
- if ast.listcomp:
- return format_return(self._evaluateListComp(ast))
- if ast.dictcomp:
- return format_return(self._evaluateDictComp(ast))
- if ast.container:
- return format_return(self._evaluateContains(ast))
- if ast.lists:
- return format_return(self._evaluateZip(ast))
- if ast.singleList:
- return format_return(self._evaluateLen(ast))
- raise NotImplementedError
+ def __init__(self, env, repl):
+ self.env = env
+ self.sequence_running_example = repl.sequence_running_example
+ self.backup_example = None
+ # allows evaluating something that doesn't necessarily work with the
+ # main running example, when we just want to see what happens on a
+ # different input - e.g. so we can do draw(tokens_int+1,[1,2]) without
+ # error even while the main example is still "hello"
+ self.repl = repl
+
+ def evaluate(self, ast):
+ if ast.expr():
+ return self.evaluateExpr(ast.expr(), from_top=True)
+ if ast.assign():
+ return self.assign(ast.assign())
+ if ast.funcDef():
+ return self.funcDef(ast.funcDef())
+ if ast.draw():
+ return self.draw(ast.draw())
+ if ast.forLoop():
+ return self.forLoop(ast.forLoop())
+ if ast.loadFile():
+ return self.repl.loadFile(ast.loadFile(), self.env)
+
+ # more to come
+ raise NotImplementedError
+
+ def draw(self, ast):
+ # TODO: add at least some rudimentary comparison of selectors so that
+ # heads can be merged. maybe keep a trace of the operations used to
+ # create them, and treat selectors with the exact same parent s-ops and
+ # operations as equal? that would still consider e.g. select(0,0,==)
+ # and select(1,1,==) different, but it's better than nothing at all
+ example = self.evaluateExpr(
+ ast.inputseq) if ast.inputseq else self.sequence_running_example
+ prev_backup = self.backup_example
+ self.backup_example = example
+ unf = self.evaluateExpr(ast.unf)
+ if not isinstance(unf, UnfinishedSequence):
+ raise RASPTypeError("draw expects unfinished sequence, got:", unf)
+ unf.draw_comp_flow(example)
+ res = unf.call(example)
+ res.created_from_input = example
+ self.backup_example = prev_backup
+ return JustVal(res)
+
+ def assign(self, ast):
+ def set_val_and_name(val, name):
+ self.env.set_variable(name, val)
+ if isinstance(val, Unfinished):
+ val.setname(name) # completely irrelevant really for the REPL,
+ # but will help maintain sanity when printing computation flows
+ return NamedVal(name, val)
+
+ varnames = self._names_list(ast.var)
+ values = self.evaluateExprsList(ast.val)
+ if len(values) == 1:
+ values = values[0]
+
+ if len(varnames) == 1:
+ return set_val_and_name(values, varnames[0])
+ else:
+ if not len(varnames) == len(values):
+ raise RASPTypeError("expected", len(varnames),
+ "values, but got:", len(values))
+ reslist = []
+ for v, name in zip(values, varnames):
+ reslist.append(set_val_and_name(v, name))
+ return NamedValList(reslist)
+
+ def _names_list(self, ast):
+ idsList = self._get_first_cont_list(ast)
+ return [i.text for i in idsList]
+
+ def _set_iterator_and_vals(self, iterator_names, iterator_vals):
+ if len(iterator_names) == 1:
+ self.env.set_variable(iterator_names[0], iterator_vals)
+ elif (isinstance(iterator_vals, Iterable) and
+ (len(iterator_vals) == len(iterator_names))):
+ for n, v in zip(iterator_names, iterator_vals):
+ self.env.set_variable(n, v)
+ else:
+ if not isinstance(iterator_vals, Iterable):
+ raise RASPTypeError(
+ "iterating with multiple iterator names, but got single" +
+ " iterator value:", iterator_vals)
+ else:
+ # should work out by logic of last failed elif
+ errnote = "something wrong with Evaluator logic"
+ assert not (len(iterator_vals) == len(iterator_names)), errnote
+ raise RASPTypeError("iterating with", len(iterator_names),
+ "names but got", len(iterator_vals),
+ "values (", iterator_vals, ")")
+
+ def _evaluateDictComp(self, ast):
+ ast = ast.dictcomp
+ d = self.evaluateExpr(ast.iterable)
+ if not (isinstance(d, list) or isinstance(d, dict)):
+ raise RASPTypeError(
+ "dict comprehension should have got a list or dict to loop " +
+ "over, but got:", d)
+ res = {}
+ iterator_names = self._names_list(ast.iterator)
+ for vals in d:
+ orig_env = self.env
+ self.env = self.env.make_nested()
+ self._set_iterator_and_vals(iterator_names, vals)
+ key = self.make_dict_key(ast.key)
+ res[key] = self.evaluateExpr(ast.val)
+ self.env = orig_env
+ return res
+
+ def _evaluateListComp(self, ast):
+ ast = ast.listcomp
+ ll = self.evaluateExpr(ast.iterable)
+ if not (isinstance(ll, list) or isinstance(ll, dict)):
+ raise RASPTypeError(
+ "list comprehension should have got a list or dict to loop " +
+ "over, but got:", ll)
+ res = []
+ iterator_names = self._names_list(ast.iterator)
+ for vals in ll:
+ orig_env = self.env
+ self.env = self.env.make_nested()
+ # set inside the now-nested env - we don't want to keep
+ # the internal iterators after finishing this list comp
+ self._set_iterator_and_vals(iterator_names, vals)
+ res.append(self.evaluateExpr(ast.val))
+ self.env = orig_env
+ return res
+
+ def forLoop(self, ast):
+ iterator_names = self._names_list(ast.iterator)
+ iterable = self.evaluateExpr(ast.iterable)
+ if not (isinstance(iterable, list) or isinstance(iterable, dict)):
+ raise RASPTypeError(
+ "for loop needs to iterate over a list or dict, but got:",
+ iterable)
+ statements = self._get_first_cont_list(ast.mainbody)
+ for vals in iterable:
+ self._set_iterator_and_vals(iterator_names, vals)
+ for s in statements:
+ self.evaluate(s)
+ return JustVal(None)
+
+ def _get_first_cont_list(self, ast):
+ res = []
+ while ast:
+ if ast.first:
+ res.append(ast.first)
+ # sometimes there's no first because the node is just eating a comment
+ ast = ast.cont
+ return res
+
+ def funcDef(self, ast):
+ funcname = ast.name.text
+ argname_trees = self._get_first_cont_list(ast.arguments)
+ argnames = [a.text for a in argname_trees]
+ statement_trees = self._get_first_cont_list(ast.mainbody)
+ returnexpr = ast.retstatement.res
+ res = RASPFunction(funcname, self.env, argnames,
+ statement_trees, returnexpr, self.env.name)
+ self.env.set_variable(funcname, res)
+ return NamedVal(funcname, res)
+
+ def _evaluateUnaryExpr(self, ast):
+ uexpr = self.evaluateExpr(ast.uexpr)
+ uop = ast.uop.text
+ if uop == "not":
+ if isinstance(uexpr, UnfinishedSequence):
+ return tplnot(uexpr)
+ elif isinstance(uexpr, UnfinishedSelect):
+ return not_select(uexpr)
+ else:
+ return not uexpr
+ if uop == "-":
+ return -uexpr
+ if uop == "+":
+ return +uexpr
+ if uop == "round":
+ return round(uexpr)
+ if uop == "indicator":
+ if isinstance(uexpr, UnfinishedSequence):
+ name = "I("+uexpr.name+")"
+ zipmapped = zipmap(uexpr, lambda a: 1 if a else 0, name=name)
+ return zipmapped.allow_suppressing_display()
+ # naming res makes interpreter think it is important, i.e.,
+ # must always be displayed. but here it has only been named for
+ # clarity, so correct it using .allow_suppressing_display()
+
+ raise RASPTypeError(
+ "indicator operator expects " + ENCODER_NAME + ", got:", uexpr)
+ raise NotImplementedError
+
+ def _evaluateRange(self, ast):
+ valsList = self.evaluateExprsList(ast.rangevals)
+ if not len(valsList) in [1, 2, 3]:
+ raise RASPTypeError(
+ "wrong number of inputs to range, expected: 1, 2, or 3, got:",
+ len(valsList))
+ for v in valsList:
+ if not isinstance(v, int):
+ raise RASPTypeError(
+ "range expects all integer inputs, but got:",
+ strdesc(valsList))
+ return list(range(*valsList))
+
+ def _index_into_dict(self, d, index):
+ def invalid_key_error(i):
+ return RASPTypeError(
+ f"index into dict has to be {ENCODER_NAME} or atom" +
+ " (i.e., string, int, float, bool), got:", strdesc(i))
+
+ def missing_key_error(i):
+ return RASPValueError("index [", strdesc(i), "] not in dict.")
+
+ dname, indexname = d.name, index.name
+ d, index = d.val, index.val
+
+ if isinstance(index, UnfinishedSequence):
+ d = deepcopy(d)
+
+ def apply_d(i):
+ if i not in d:
+ raise missing_key_error(i)
+ return d[i]
+
+ name = f"{dname}[{indexname}]"
+ return zipmap((index,), apply_d, name=name)
+ elif not isatom(index):
+ raise invalid_key_error(index)
+ if index not in d:
+ raise missing_key_error(index)
+ else:
+ return d[index]
+
+ def _index_into_list_or_str(self, ll, index):
+ lname, indexname = ll.name, index.name
+ ll, index = ll.val, index.val
+ ltype = "list" if isinstance(ll, list) else "string"
+
+ def invalid_key_error(i):
+ return RASPTypeError(f"index into {ltype} has to be",
+ f"{ENCODER_NAME} or integer, got:",
+ strdesc(index))
+
+ def check_and_raise_key_error(i):
+ if i >= len(ll) or (-i) > len(ll):
+ raise RASPValueError("index", index, "out of range for", ltype,
+ "of length", len(ll))
+
+ if isinstance(index, UnfinishedSequence):
+ ll = deepcopy(ll)
+
+ def apply_l(i):
+ check_and_raise_key_error(i)
+ return ll[i]
+
+ name = f"{lname}[{indexname}]"
+ return zipmap((index,), apply_l, name=name)
+ elif not isinstance(index, int):
+ raise invalid_key_error(index)
+ check_and_raise_key_error(index)
+ return ll[index]
+
+ def _index_into_sequence(self, s, index):
+ sname, indexname = s.name, index.name
+ s, index = s.val, index.val
+ if isinstance(index, int):
+ if index >= 0:
+ sel = select(toseq(index), indices,
+ lambda q, k: q == k,
+ name="load from "+str(index))
+ else:
+ length = self.env.get_variable("length")
+ real_index = length + index
+ real_index.setname(length.name+str(index))
+ sel = select(real_index, indices,
+ lambda q, k: q == k,
+ name="load from "+str(index))
+ agg = aggregate(sel, s, name=s.name+"["+str(index)+"]")
+ return agg.allow_suppressing_display()
+ else:
+ raise RASPValueError(
+ "index into sequence has to be integer, got:", strdesc(index))
+
+ def _evaluateIndexing(self, ast):
+ indexable = self.evaluateExpr(ast.indexable, get_name=True)
+ index = self.evaluateExpr(ast.index, get_name=True)
+
+ if isinstance(indexable.val, list) or isinstance(indexable.val, str):
+ return self._index_into_list_or_str(indexable, index)
+ elif isinstance(indexable.val, dict):
+ return self._index_into_dict(indexable, index)
+ elif isinstance(indexable.val, UnfinishedSequence):
+ return self._index_into_sequence(indexable, index)
+ else:
+ raise RASPTypeError("can only index into a list, dict, string," +
+ " or sequence, but instead got:",
+ strdesc(indexable.val))
+
+ def _evaluateSelectExpr(self, ast):
+ key = self.evaluateExpr(ast.key)
+ query = self.evaluateExpr(ast.query)
+ sop = ast.selop.text
+ key = toseq(key) # in case got an atom in one of these,
+ query = toseq(query) # e.g. selecting 0th index: indices @= 0
+ if sop == "<":
+ return select(query, key, lambda q, k: q > k)
+ if sop == ">":
+ return select(query, key, lambda q, k: q < k)
+ if sop == "==":
+ return select(query, key, lambda q, k: q == k)
+ if sop == "!=":
+ return select(query, key, lambda q, k: not (q == k))
+ if sop == "<=":
+ return select(query, key, lambda q, k: q >= k)
+ if sop == ">=":
+ return select(query, key, lambda q, k: q <= k)
+
+ def _evaluateBinaryExpr(self, ast):
+ def has_sequence(left, right):
+ return isinstance(left, UnfinishedSequence) \
+ or isinstance(right, UnfinishedSequence)
+
+ def has_selector(left, right):
+ return isinstance(left, UnfinishedSelect) \
+ or isinstance(right, UnfinishedSelect)
+
+ def both_selectors(left, right):
+ return isinstance(left, UnfinishedSelect) \
+ and isinstance(right, UnfinishedSelect)
+ left = self.evaluateExpr(ast.left)
+ right = self.evaluateExpr(ast.right)
+ bop = ast.bop.text
+ bad_pair = RASPTypeError(
+ "Cannot apply and/or between selector and non-selector")
+ if bop == "and":
+ if has_sequence(left, right):
+ if has_selector(left, right):
+ raise bad_pair
+ return tpland(left, right)
+ elif has_selector(left, right):
+ if not both_selectors(left, right):
+ raise bad_pair
+ return and_selects(left, right)
+ else:
+ return (left and right)
+ elif bop == "or":
+ if has_sequence(left, right):
+ if has_selector(left, right):
+ raise bad_pair
+ return tplor(left, right)
+ elif has_selector(left, right):
+ if not both_selectors(left, right):
+ raise bad_pair
+ return or_selects(left, right)
+ else:
+ return (left or right)
+ if has_selector(left, right):
+ raise RASPTypeError("Cannot apply", bop, "to selector(s)")
+ elif bop == "+":
+ return left + right
+ elif bop == "-":
+ return left - right
+ elif bop == "*":
+ return left * right
+ elif bop == "/":
+ return left/right
+ elif bop == "^":
+ return pow(left, right)
+ elif bop == '%':
+ return left % right
+ elif bop == "==":
+ return left == right
+ elif bop == "<=":
+ return left <= right
+ elif bop == ">=":
+ return left >= right
+ elif bop == "<":
+ return left < right
+ elif bop == ">":
+ return left > right
+ # further operators to be added here as needed
+ raise NotImplementedError
+
+ def _evaluateStandalone(self, ast):
+ if ast.anint:
+ return int(ast.anint.text)
+ if ast.afloat:
+ return float(ast.afloat.text)
+ if ast.astring:
+ return ast.astring.text[1: -1]
+ raise NotImplementedError
+
+ def _evaluateTernaryExpr(self, ast):
+ cond = self.evaluateExpr(ast.cond)
+ if isinstance(cond, Unfinished):
+ res1 = self.evaluateExpr(ast.res1)
+ res2 = self.evaluateExpr(ast.res2)
+ cond, res1, res2 = tuple(map(toseq, (cond, res1, res2)))
+ name = res1.name + " if " + cond.name + " else " + res2.name
+ return zipmap((cond, res1, res2),
+ lambda c, r1, r2: r1 if c else r2,
+ name=name).allow_suppressing_display()
+ else:
+ return self.evaluateExpr(ast.res1) if cond \
+ else self.evaluateExpr(ast.res2)
+ # lazy eval when cond is non-unfinished allows legal loops over
+ # actual atoms
+
+ def _evaluateAggregateExpr(self, ast):
+ sel = self.evaluateExpr(ast.sel)
+ seq = self.evaluateExpr(ast.seq)
+ seq = toseq(seq) # just in case its an atom
+ default = self.evaluateExpr(ast.default) if ast.default else None
+
+ if not isinstance(sel, UnfinishedSelect):
+ raise RASPTypeError("Expected selector, got:", strdesc(sel))
+ if not isinstance(seq, UnfinishedSequence):
+ raise RASPTypeError("Expected sequence, got:", strdesc(seq))
+ if isinstance(default, Unfinished):
+ raise RASPTypeError("Expected atom, got:", strdesc(default))
+ return aggregate(sel, seq, default=default)
+
+ def _evaluateZip(self, ast):
+ list_exps = self._get_first_cont_list(ast.lists)
+ lists = [self.evaluateExpr(e) for e in list_exps]
+ if not lists:
+ raise RASPTypeError("zip needs at least one list")
+ for i, l in enumerate(lists):
+ if not isinstance(l, list):
+ raise RASPTypeError(
+ "attempting to zip lists, but", i+1,
+ "-th element is not list:", strdesc(l))
+ n = len(lists[0])
+ for i, l in enumerate(lists):
+ if not len(l) == n:
+ raise RASPTypeError("attempting to zip lists of length",
+ n, ", but", i+1, "-th list has length",
+ len(l))
+ # keep everything lists, no tuples/lists mixing here, all the same to
+ # rasp (no stuff like append etc)
+ return [list(v) for v in zip(*lists)]
+
+ def make_dict_key(self, ast):
+ res = self.evaluateExpr(ast)
+ if not isatom(res):
+ raise RASPTypeError(
+ "dictionary keys can only be atoms, but instead got:",
+ strdesc(res))
+ return res
+
+ def _evaluateDict(self, ast):
+ named_exprs_list = self._get_first_cont_list(ast.dictContents)
+ return {self.make_dict_key(e.key): self.evaluateExpr(e.val)
+ for e in named_exprs_list}
+
+ def _evaluateList(self, ast):
+ exprs_list = self._get_first_cont_list(ast.listContents)
+ return [self.evaluateExpr(e) for e in exprs_list]
+
+ def _evaluateApplication(self, ast, unf):
+ input_vals = self._get_first_cont_list(ast.inputexprs)
+ if not len(input_vals) == 1:
+ raise ArgsError("evaluate unfinished", 1, len(input_vals))
+ input_val = self.evaluateExpr(input_vals[0])
+ if not isinstance(unf, Unfinished):
+ raise RASPTypeError("Applying unfinished expects to apply",
+ ENCODER_NAME, "or selector, got:",
+ strdesc(unf))
+ if not isinstance(input_val, Iterable):
+ raise RASPTypeError(
+ "Applying unfinished expects iterable input, got:",
+ strdesc(input_val))
+ res = unf.call(input_val)
+ res.created_from_input = input_val
+ return res
+
+ def _evaluateRASPFunction(self, ast, raspfun):
+ args_trees = self._get_first_cont_list(ast.inputexprs)
+ args = tuple(self.evaluateExpr(t) for t in args_trees) + (self,)
+ real_args = args[:-1]
+ res = raspfun.call(*args)
+ if isinstance(res, Unfinished):
+ res.setname(raspfun.name + "(" +
+ " , ".join(strdesc(a, desc_cap=20) for
+ a in real_args) + ")")
+ return res
+
+ def _evaluateContains(self, ast):
+ contained = self.evaluateExpr(ast.contained)
+ container = self.evaluateExpr(ast.container)
+ container_name = ast.container.var.text if ast.container.var \
+ else str(container)
+ if isinstance(contained, UnfinishedSequence):
+ if not isinstance(container, list):
+ raise RASPTypeError(f"\"[{ENCODER_NAME}] in X\" expects X " +
+ "to be list of atoms, but got non-list:",
+ strdesc(container))
+ for v in container:
+ if not isatom(v):
+ raise RASPTypeError("\"[" + ENCODER_NAME + "] in X\" " +
+ "expects X to be list of atoms, but " +
+ "got list with values:",
+ strdesc(container))
+ return zipmap(contained, lambda c: c in container,
+ name=contained.name + " in " +
+ container_name).allow_suppressing_display()
+ elif isatom(contained): # contained is now an atom
+ if isinstance(container, list):
+ return contained in container
+ elif isinstance(container, UnfinishedSequence):
+ indicator = zipmap(container, lambda v: int(v == contained))
+ return aggregate(full_s, indicator) > 0
+ else:
+ raise RASPTypeError(
+ f"\"[atom] in X\" expects X to be list or {ENCODER_NAME}" +
+ ", but got:", strdesc(container))
+ if isinstance(contained, (UnfinishedSelect, RASPFunction)):
+ obj_name = "select" if isinstance(contained, UnfinishedSelect) \
+ else "function"
+ raise RASPTypeError(f"don't check if {obj_name} is contained" +
+ " in list/dict: unless exact same instance, " +
+ f"unable to check equivalence of {obj_name}s")
+ else:
+ raise RASPTypeError("\"A in X\" expects A to be",
+ ENCODER_NAME, "or atom, but got A:",
+ strdesc(contained))
+
+ def _evaluateLen(self, ast):
+ singleList = self.evaluateExpr(ast.singleList)
+ if not (isinstance(singleList, list) or isinstance(singleList, dict)):
+ raise RASPTypeError(
+ "attempting to compute length of non-list:",
+ strdesc(singleList))
+ return len(singleList)
+
+ def evaluateExprsList(self, ast):
+ exprsList = self._get_first_cont_list(ast)
+ return [self.evaluateExpr(v) for v in exprsList]
+
+ def _test_res(self, res):
+ if isinstance(res, Unfinished):
+ def succeeds_with(example):
+ try:
+ res.call(example, just_pass_exception_up=True)
+ except Exception:
+ return False
+ else:
+ return True
+ succeeds_with_backup = (self.backup_example is not None) and \
+ succeeds_with(self.backup_example)
+ if succeeds_with_backup:
+ return
+ succeeds_with_main = succeeds_with(self.sequence_running_example)
+ if succeeds_with_main:
+ return
+ example = self.sequence_running_example if self.backup_example \
+ is None else self.backup_example
+ res.call(example, just_pass_exception_up=True)
+
+ def evaluateExpr(self, ast, from_top=False, get_name=False):
+ def format_return(res, resname="out",
+ is_application_of_unfinished=False):
+ ast.evaled_value = res
+ # run a quick test of the result (by attempting to evaluate it on
+ # an example) to make sure there hasn't been some weird type
+ # problem, so it shouts even before someone actively tries to
+ # evaluate it
+ self._test_res(res)
+
+ if is_application_of_unfinished:
+ return JustVal(res)
+ else:
+ self.env.set_out(res)
+ if from_top or get_name:
+ # this is when an expression has been evaled
+ return NamedVal(resname, res)
+ else:
+ return res
+ if ast.bracketed: # in parentheses - get out of them
+ return self.evaluateExpr(ast.bracketed, from_top=from_top)
+ if ast.var: # calling single variable
+ varname = ast.var.text
+ return format_return(self.env.get_variable(varname),
+ resname=varname)
+ if ast.standalone:
+ return format_return(self._evaluateStandalone(ast.standalone))
+ if ast.bop:
+ return format_return(self._evaluateBinaryExpr(ast))
+ if ast.uop:
+ return format_return(self._evaluateUnaryExpr(ast))
+ if ast.cond:
+ return format_return(self._evaluateTernaryExpr(ast))
+ if ast.aggregate:
+ return format_return(self._evaluateAggregateExpr(ast.aggregate))
+ if ast.unfORfun:
+
+ # before evaluating the unfORfun expression,
+ # consider that it may be an unf that would not work
+ # with the current running example, and allow that it may have
+ # been sent in with an example for which it will work
+ prev_backup = self.backup_example
+ input_vals = self._get_first_cont_list(ast.inputexprs)
+ if len(input_vals) == 1:
+ self.backup_example = self.evaluateExpr(input_vals[0])
+
+ unfORfun = self.evaluateExpr(ast.unfORfun)
+
+ self.backup_example = prev_backup
+
+ if isinstance(unfORfun, Unfinished):
+ return format_return(self._evaluateApplication(ast, unfORfun),
+ is_application_of_unfinished=True)
+ elif isinstance(unfORfun, RASPFunction):
+ return format_return(self._evaluateRASPFunction(ast, unfORfun))
+ if ast.selop:
+ return format_return(self._evaluateSelectExpr(ast))
+ if ast.aList():
+ return format_return(self._evaluateList(ast.aList()))
+ if ast.aDict():
+ return format_return(self._evaluateDict(ast.aDict()))
+ if ast.indexable: # indexing into a list, dict, or s-op
+ return format_return(self._evaluateIndexing(ast))
+ if ast.rangevals:
+ return format_return(self._evaluateRange(ast))
+ if ast.listcomp:
+ return format_return(self._evaluateListComp(ast))
+ if ast.dictcomp:
+ return format_return(self._evaluateDictComp(ast))
+ if ast.container:
+ return format_return(self._evaluateContains(ast))
+ if ast.lists:
+ return format_return(self._evaluateZip(ast))
+ if ast.singleList:
+ return format_return(self._evaluateLen(ast))
+ raise NotImplementedError
# new ast getText function for expressions
def new_getText(self): # original getText function stored as self._getText
- if hasattr(self, "evaled_value") and isatom(self.evaled_value):
- return str(self.evaled_value)
- else:
- return self._getText()
+ if hasattr(self, "evaled_value") and isatom(self.evaled_value):
+ return str(self.evaled_value)
+ else:
+ return self._getText()
RASPParser.ExprContext._getText = RASPParser.ExprContext.getText
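
Side note on the Evaluator changes above: several methods rely on the
_get_first_cont_list helper, which flattens the parser's first/cont
linked-list contexts into a plain Python list while skipping comment-only
nodes. A minimal sketch of that pattern, using a hypothetical Node class as
a stand-in for the ANTLR context objects (illustration only, not part of
the patch):

class Node:
    def __init__(self, first=None, cont=None):
        self.first = first  # payload; None when the node only held a comment
        self.cont = cont    # next link in the chain, or None at the end

def get_first_cont_list(ast):
    res = []
    while ast:
        if ast.first:
            res.append(ast.first)
        ast = ast.cont
    return res

chain = Node("a", Node(None, Node("b")))
assert get_first_cont_list(chain) == ["a", "b"]
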
diff --git a/RASP_support/FunctionalSupport.py b/RASP_support/FunctionalSupport.py
index 0bfef3b..98247fa 100644
--- a/RASP_support/FunctionalSupport.py
+++ b/RASP_support/FunctionalSupport.py
@@ -24,27 +24,27 @@
class NextId:
- def __init__(self):
- self.i = 0
+ def __init__(self):
+ self.i = 0
- def get_next(self):
- self.i += 1
- return self.i
+ def get_next(self):
+ self.i += 1
+ return self.i
unique_id_maker = NextId()
def creation_order_id():
- return unique_id_maker.get_next()
+ return unique_id_maker.get_next()
class AlreadyPrintedTheException:
- def __init__(self):
- self.b = False
+ def __init__(self):
+ self.b = False
- def __bool__(self):
- return self.b
+ def __bool__(self):
+ return self.b
global_printed = AlreadyPrintedTheException()
@@ -53,518 +53,525 @@ def __bool__(self):
class Unfinished:
- def __init__(self, parents_tuple, parents2self, name=plain_unfinished_name,
- is_toplevel_input=False, min_poss_depth=-1):
- self.parents_tuple = parents_tuple
- self.parents2self = parents2self
- self.last_w = None
- self.last_res = None
- self.is_toplevel_input = is_toplevel_input
- self.setname(name if not self.is_toplevel_input else "input")
- self.creation_order_id = creation_order_id()
- self.min_poss_depth = min_poss_depth
- self._real_parents = None
- self._full_parents = None
- self._sorted_full_parents = None
-
- def setname(self, name, always_display_when_named=True):
- if name is not None:
- if len(name) > name_maxlen:
- if isinstance(self, UnfinishedSequence):
- name = plain_unfinished_sequence_name
- elif isinstance(self, UnfinishedSelect):
- name = plain_unfinished_select_name
- else:
- name = plain_unfinished_name
- self.name = name
- # if you set something's name, you probably want to see it
- self.always_display = always_display_when_named
- # return self to allow chaining with other calls and throwing straight
- # into a return statement etc
- return self
-
- def get_parents(self):
- if None is self._real_parents:
- real_parents_part1 = [
- p for p in self.parents_tuple if is_real_unfinished(p)]
- other_parents = [
- p for p in self.parents_tuple if not is_real_unfinished(p)]
- res = real_parents_part1
- for p in other_parents:
- # recursion: branch back through all the parents of the unf,
- # always stopping wherever hit something 'real' ie a select or
- # a sequence
- res += p.get_parents()
- # nothing is made from more than one select...
- assert len(
- [p for p in res if isinstance(p, UnfinishedSelect)]) <= 1
- self._real_parents = set(res)
- # in case someone messes with the list eg popping through it
- return copy(self._real_parents)
-
- def _flat_compute_full_parents(self):
- # TODO: take advantage of anywhere full_parents have already been
- # computed, tho, otherwise no point in doing the recursion ever
- explored = set()
- not_explored = set([self])
- while not_explored:
- p = not_explored.pop()
- if p in explored:
- # this may happen due to also adding things directly to
- # explored sometimes
- continue
- if None is not p._full_parents:
- # note that _full_parents include self
- explored.update(p._full_parents)
- else:
- new_parents = p.get_parents()
- explored.add(p)
- not_explored.update(new_parents)
- return explored
-
- def _recursive_compute_full_parents(self):
- res = self.get_parents() # get_parents returns a copy
- res.update([self]) # full parents include self
- for p in self.get_parents():
- res.update(p.get_full_parents(recurse=True, trusted=True))
- return res
-
- def _sort_full_parents(self):
- if None is self._sorted_full_parents:
- self._sorted_full_parents = sorted(
- self._full_parents, key=lambda unf: unf.creation_order_id)
-
- def get_full_parents(self, recurse=False, just_compute=False,
- trusted=False):
- # Note: full_parents include self
- if None is self._full_parents:
- if recurse:
- self._full_parents = self._recursive_compute_full_parents()
- else:
- self._full_parents = self._flat_compute_full_parents()
- # avoids recursion, and so avoids passing the max recursion
- # depth
-
- # but now having done that we would like to store the result
- # for all parents so we can take advantage of it in the future
- for p in self.get_sorted_full_parents():
- p.get_full_parents(recurse=True, just_compute=True)
- # have them all compute their full parents so they are
- # ready for the future, but only do this in sorted order,
- # so recursion is always shallow. (always gets shorted with
- # self._full_parents, which is being computed here for each
- # unfinished starting from the top of the computation
- # graph)
- if not just_compute:
- if trusted:
- # functions where you have checked they don't modify the
- # returned result can be marked as trusted and get the true
- # _full_parents
- return self._full_parents
- else:
- # otherwise they get a copy
- return copy(self._full_parents)
-
- def get_sorted_full_parents(self):
- # could have just made get_full_parents give a sorted result, but
- # wanted a function where name is already clear that result will be
- # sorted, to avoid weird bugs in future. (especially that being not
- # sorted will only affect performance, and possibly break recursion
- # depth)
-
- if None is self._sorted_full_parents:
- if None is self._full_parents:
- self.get_full_parents(just_compute=True)
- self._sort_full_parents()
- return copy(self._sorted_full_parents)
-
- def call(self, w, topcall=True, just_pass_exception_up=False):
- if (not isinstance(w, Iterable)) or (not w):
- raise RASPTypeError(
- "RASP sequences/selectors expect non-empty iterables, got: "
- + str(w))
- global_printed.b = False
- if w == self.last_w:
- return self.last_res # don't print same calculation multiple times
-
- else:
- if self.is_toplevel_input:
- res = w
- self.last_w, self.last_res = w, w
- else:
- try:
- if topcall:
- # before doing the main call, evaluate all parents
- # (in order of dependencies, attainable by using
- # creation_order_id attribute), this avoids a deep
- # recursion: every element that is evaluated only has
- # to go back as far as its own 'real' (i.e., s-op or
- # selector) parents to hit something that has already
- # been evaluated, and then those will not recurse
- # further back as they use memoization
- for unf in self.get_sorted_full_parents():
- # evaluate
- unf.call(w, topcall=False,
- just_pass_exception_up=just_pass_exception_up)
-
- j_p_e_u = just_pass_exception_up
- args = tuple(p.call(w,
- topcall=False,
- just_pass_exception_up=j_p_e_u)
- for p in self.parents_tuple)
- res = self.parents2self(*args)
- except Exception as e:
- if just_pass_exception_up:
- raise e
- if isinstance(e, RASPTypeError):
- raise e
- if not global_printed.b:
- seperator = "=" * 63
- print(colored(f"{seperator}\n{seperator}", error_color))
- error_msg = f"evaluation failed in: [ {self.name} ]" +\
- f" with exception:\n {e}"
- print(colored(error_msg, error_color))
- print(colored(seperator, error_color))
- print(colored("parent values are:", error_color))
- for p in self.parents_tuple:
- print(colored(
- f"=============\n{p.name}\n{p.last_res}",
- error_color))
- print(colored(f"{seperator}\n{seperator}", error_color))
- a, b, tb = sys.exc_info()
- tt = traceback.extract_tb(tb)
- last_call = max([i for i, t in enumerate(tt)
- if "in call" in str(t)])
- traceback_msg = \
- ''.join(traceback.format_list(tt[last_call+1:]))
- print(colored(traceback_msg, error_color))
-
- # traceback.print_exception(a,b,tb)
-
- global_printed.b = True
-
- if debug or not topcall:
- raise
- else:
- return "EVALUATION FAILURE"
-
- self.last_w, self.last_res = w, res
- return res
+ def __init__(self, parents_tuple, parents2self, name=plain_unfinished_name,
+ is_toplevel_input=False, min_poss_depth=-1):
+ self.parents_tuple = parents_tuple
+ self.parents2self = parents2self
+ self.last_w = None
+ self.last_res = None
+ self.is_toplevel_input = is_toplevel_input
+ self.setname(name if not self.is_toplevel_input else "input")
+ self.creation_order_id = creation_order_id()
+ self.min_poss_depth = min_poss_depth
+ self._real_parents = None
+ self._full_parents = None
+ self._sorted_full_parents = None
+
+ def setname(self, name, always_display_when_named=True):
+ if name is not None:
+ if len(name) > name_maxlen:
+ if isinstance(self, UnfinishedSequence):
+ name = plain_unfinished_sequence_name
+ elif isinstance(self, UnfinishedSelect):
+ name = plain_unfinished_select_name
+ else:
+ name = plain_unfinished_name
+ self.name = name
+ # if you set something's name, you probably want to see it
+ self.always_display = always_display_when_named
+ # return self to allow chaining with other calls and throwing straight
+ # into a return statement etc
+ return self
+
+ def get_parents(self):
+ if None is self._real_parents:
+ real_parents_part1 = [
+ p for p in self.parents_tuple if is_real_unfinished(p)]
+ other_parents = [
+ p for p in self.parents_tuple if not is_real_unfinished(p)]
+ res = real_parents_part1
+ for p in other_parents:
+ # recursion: branch back through all the parents of the unf,
+ # always stopping wherever hit something 'real' ie a select or
+ # a sequence
+ res += p.get_parents()
+ # nothing is made from more than one select...
+ assert len(
+ [p for p in res if isinstance(p, UnfinishedSelect)]) <= 1
+ self._real_parents = set(res)
+ # in case someone messes with the list eg popping through it
+ return copy(self._real_parents)
+
+ def _flat_compute_full_parents(self):
+ # TODO: take advantage of anywhere full_parents have already been
+ # computed, tho, otherwise no point in doing the recursion ever
+ explored = set()
+ not_explored = set([self])
+ while not_explored:
+ p = not_explored.pop()
+ if p in explored:
+ # this may happen due to also adding things directly to
+ # explored sometimes
+ continue
+ if None is not p._full_parents:
+ # note that _full_parents include self
+ explored.update(p._full_parents)
+ else:
+ new_parents = p.get_parents()
+ explored.add(p)
+ not_explored.update(new_parents)
+ return explored
+
+ def _recursive_compute_full_parents(self):
+ res = self.get_parents() # get_parents returns a copy
+ res.update([self]) # full parents include self
+ for p in self.get_parents():
+ res.update(p.get_full_parents(recurse=True, trusted=True))
+ return res
+
+ def _sort_full_parents(self):
+ if None is self._sorted_full_parents:
+ self._sorted_full_parents = sorted(
+ self._full_parents, key=lambda unf: unf.creation_order_id)
+
+ def get_full_parents(self, recurse=False, just_compute=False,
+ trusted=False):
+ # Note: full_parents include self
+ if None is self._full_parents:
+ if recurse:
+ self._full_parents = self._recursive_compute_full_parents()
+ else:
+ self._full_parents = self._flat_compute_full_parents()
+ # avoids recursion, and so avoids passing the max recursion
+ # depth
+
+ # but now having done that we would like to store the result
+ # for all parents so we can take advantage of it in the future
+ for p in self.get_sorted_full_parents():
+ p.get_full_parents(recurse=True, just_compute=True)
+ # have them all compute their full parents so they are
+ # ready for the future, but only do this in sorted order,
+ # so recursion is always shallow. (always gets shorted with
+ # self._full_parents, which is being computed here for each
+ # unfinished starting from the top of the computation
+ # graph)
+ if not just_compute:
+ if trusted:
+ # functions where you have checked they don't modify the
+ # returned result can be marked as trusted and get the true
+ # _full_parents
+ return self._full_parents
+ else:
+ # otherwise they get a copy
+ return copy(self._full_parents)
+
+ def get_sorted_full_parents(self):
+ # could have just made get_full_parents give a sorted result, but
+ # wanted a function where name is already clear that result will be
+ # sorted, to avoid weird bugs in future. (especially that being not
+ # sorted will only affect performance, and possibly break recursion
+ # depth)
+
+ if None is self._sorted_full_parents:
+ if None is self._full_parents:
+ self.get_full_parents(just_compute=True)
+ self._sort_full_parents()
+ return copy(self._sorted_full_parents)
+
+ def call(self, w, topcall=True, just_pass_exception_up=False):
+ if (not isinstance(w, Iterable)) or (not w):
+ raise RASPTypeError(
+ "RASP sequences/selectors expect non-empty iterables, got: " +
+ str(w))
+ global_printed.b = False
+ if w == self.last_w:
+ return self.last_res # don't print same calculation multiple times
+
+ else:
+ if self.is_toplevel_input:
+ res = w
+ self.last_w, self.last_res = w, w
+ else:
+ try:
+ j_p_e_u = just_pass_exception_up
+ if topcall:
+ # before doing the main call, evaluate all parents
+ # (in order of dependencies, attainable by using
+ # creation_order_id attribute), this avoids a deep
+ # recursion: every element that is evaluated only has
+ # to go back as far as its own 'real' (i.e., s-op or
+ # selector) parents to hit something that has already
+ # been evaluated, and then those will not recurse
+ # further back as they use memoization
+ for unf in self.get_sorted_full_parents():
+ # evaluate
+ unf.call(w, topcall=False,
+ just_pass_exception_up=j_p_e_u)
+ args = tuple(p.call(w, topcall=False,
+ just_pass_exception_up=j_p_e_u)
+ for p in self.parents_tuple)
+ res = self.parents2self(*args)
+ except Exception as e:
+ if just_pass_exception_up:
+ raise e
+ if isinstance(e, RASPTypeError):
+ raise e
+ if not global_printed.b:
+ separator = "=" * 63
+ print(colored(f"{separator}\n{separator}",
+ error_color))
+ error_msg = f"evaluation failed in: [ {self.name} ]" +\
+ f" with exception:\n {e}"
+ print(colored(error_msg, error_color))
+ print(colored(separator, error_color))
+ print(colored("parent values are:", error_color))
+ for p in self.parents_tuple:
+ print(colored(
+ f"=============\n{p.name}\n{p.last_res}",
+ error_color))
+ print(colored(f"{separator}\n{separator}",
+ error_color))
+ a, b, tb = sys.exc_info()
+ tt = traceback.extract_tb(tb)
+ last_call = max([i for i, t in enumerate(tt)
+ if "in call" in str(t)])
+ traceback_msg = \
+ ''.join(traceback.format_list(tt[last_call + 1:]))
+ print(colored(traceback_msg, error_color))
+
+ # traceback.print_exception(a,b,tb)
+
+ global_printed.b = True
+
+ if debug or not topcall:
+ raise
+ else:
+ return "EVALUATION FAILURE"
+
+ self.last_w, self.last_res = w, res
+ return res
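
Aside (illustration only, not part of the patch): the comment in call()
above describes evaluating all full parents in creation order before
combining them, so each node only ever reaches back to already-memoised
parents and recursion stays shallow. A rough sketch of that idea, with a
hypothetical Node class standing in for Unfinished (the real code also
checks the cached input via last_w, which this sketch omits):

class Node:
    counter = 0

    def __init__(self, parents, fn=None):
        self.parents = parents      # tuple of parent Nodes
        self.fn = fn                # combines parent values; None for input
        Node.counter += 1
        self.order = Node.counter   # stands in for creation_order_id
        self.cache = None           # stands in for last_res memoisation

    def ancestors(self):
        seen, stack = set(), [self]
        while stack:
            n = stack.pop()
            if n not in seen:
                seen.add(n)
                stack.extend(n.parents)
        return seen

    def call(self, w):
        # evaluate every ancestor in creation order first: by the time a
        # node runs, its parents are already cached, so recursion is shallow
        for n in sorted(self.ancestors(), key=lambda n: n.order):
            if n.cache is None:
                vals = (p.cache for p in n.parents)
                n.cache = w if not n.parents else n.fn(*vals)
        return self.cache

inp = Node(())                                  # like the toplevel _input
doubled = Node((inp,), lambda x: [2 * v for v in x])
assert doubled.call([1, 2, 3]) == [2, 4, 6]
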
class UnfinishedSequence(Unfinished):
- def __init__(self, parents_tuple, parents2self,
- name=plain_unfinished_sequence_name,
- elementwise_function=None, default=None, min_poss_depth=0,
- from_zipmap=False, output_index=-1,
- definitely_uses_identity_function=False):
- # min_poss_depth=0 starts all of the base sequences (eg indices) off
- # right.
-
- # might have got none from some default value, fix it before continuing
- # because later things eg DrawCompFlow will expect name to be str
- if name is None:
- name = plain_unfinished_sequence_name
- super(UnfinishedSequence, self).__init__(parents_tuple,
- parents2self, name=name,
- min_poss_depth=min_poss_depth)
- # can be inferred (by seeing if there are parent selects), but this is
- # simple enough. helpful for rendering comp flow visualisations
- self.from_zipmap = from_zipmap
- # useful for analysis later
- self.elementwise_function = elementwise_function
- self.output_index = output_index
- # useful for analysis later
- self.default = default
- self.definitely_uses_identity_function = \
- definitely_uses_identity_function
- self.never_display = False
- self._constant = False
-
- def __str__(self):
- id = str(self.creation_order_id)
- return "UnfinishedSequence object, name: " + self.name + " id: " + id
-
- def mark_as_constant(self):
- self._constant = True
- return self
-
- def is_constant(self):
- return self._constant
+ def __init__(self, parents_tuple, parents2self,
+ name=plain_unfinished_sequence_name,
+ elementwise_function=None, default=None, min_poss_depth=0,
+ from_zipmap=False, output_index=-1,
+ definitely_uses_identity_function=False):
+ # min_poss_depth=0 starts all of the base sequences (eg indices) off
+ # right.
+
+ # might have got none from some default value, fix it before continuing
+ # because later things eg DrawCompFlow will expect name to be str
+ if name is None:
+ name = plain_unfinished_sequence_name
+ super(UnfinishedSequence, self).__init__(parents_tuple,
+ parents2self, name=name,
+ min_poss_depth=min_poss_depth)
+ # can be inferred (by seeing if there are parent selects), but this is
+ # simple enough. helpful for rendering comp flow visualisations
+ self.from_zipmap = from_zipmap
+ # useful for analysis later
+ self.elementwise_function = elementwise_function
+ self.output_index = output_index
+ # useful for analysis later
+ self.default = default
+ self.definitely_uses_identity_function = \
+ definitely_uses_identity_function
+ self.never_display = False
+ self._constant = False
+
+ def __str__(self):
+ id = str(self.creation_order_id)
+ return "UnfinishedSequence object, name: " + self.name + " id: " + id
+
+ def mark_as_constant(self):
+ self._constant = True
+ return self
+
+ def is_constant(self):
+ return self._constant
class UnfinishedSelect(Unfinished):
- def __init__(self, parents_tuple, parents2self,
- name=plain_unfinished_select_name, compare_string=None,
- min_poss_depth=-1, q_vars=None, k_vars=None,
- orig_selector=None): # selects should be told their depth,
- # -1 will warn of problems properly
- if name is None: # as in unfinishedsequence, some other function might
- # have passed in a None somewhere
- name = plain_unfinished_select_name # so fix before a print goes
- # wrong
- super(UnfinishedSelect, self).__init__(parents_tuple,
- parents2self, name=name,
- min_poss_depth=min_poss_depth)
- self.compare_string = str(
- self.creation_order_id) if compare_string is None \
- else compare_string
- # they're not really optional i just dont want to add more mess to the
- # func
- assert None not in [q_vars, k_vars]
- self.q_vars = q_vars # don't actually need them, but useful for
- self.k_vars = k_vars # drawing comp flow
- # use compare string for comparison/uniqueness rather than overloading
- # __eq__ of unfinishedselect, to avoid breaking things in unknown
- # locations, and to be able to put selects in dictionaries and stuff
- # (overloading __eq__ makes an object unhasheable unless i guess you
- # overload the hash too?). need these comparisons for optimisations in
- # analysis eg if two selects are identical they can be same head
- self.orig_selector = orig_selector # for comfortable compositions of
- # selectors
-
- def __str__(self):
- id = str(self.creation_order_id)
- return "UnfinishedSelect object, name: " + self.name + " id: " + id
+ def __init__(self, parents_tuple, parents2self,
+ name=plain_unfinished_select_name, compare_string=None,
+ min_poss_depth=-1, q_vars=None, k_vars=None,
+ orig_selector=None): # selects should be told their depth,
+ # -1 will warn of problems properly
+ if name is None: # as in unfinishedsequence, some other function might
+ # have passed in a None somewhere
+ name = plain_unfinished_select_name # so fix before a print goes
+ # wrong
+ super(UnfinishedSelect, self).__init__(parents_tuple,
+ parents2self, name=name,
+ min_poss_depth=min_poss_depth)
+ self.compare_string = str(
+ self.creation_order_id) if compare_string is None \
+ else compare_string
+ # they're not really optional i just dont want to add more mess to the
+ # func
+ assert None not in [q_vars, k_vars]
+ self.q_vars = q_vars # don't actually need them, but useful for
+ self.k_vars = k_vars # drawing comp flow
+ # use compare string for comparison/uniqueness rather than overloading
+ # __eq__ of unfinishedselect, to avoid breaking things in unknown
+ # locations, and to be able to put selects in dictionaries and stuff
+ # (overloading __eq__ makes an object unhasheable unless i guess you
+ # overload the hash too?). need these comparisons for optimisations in
+ # analysis eg if two selects are identical they can be same head
+ self.orig_selector = orig_selector # for comfortable compositions of
+ # selectors
+
+ def __str__(self):
+ id = str(self.creation_order_id)
+ return "UnfinishedSelect object, name: " + self.name + " id: " + id
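
Aside (illustration only, not part of the patch): the comment above prefers
compare_string over overloading __eq__ because, in Python 3, a class that
defines __eq__ without also defining __hash__ becomes unhashable, so its
instances could no longer live in sets or serve as dict keys:

class ByValue:
    def __init__(self, s):
        self.s = s

    def __eq__(self, other):
        return isinstance(other, ByValue) and self.s == other.s

# hash(ByValue("x"))  # TypeError: unhashable type: 'ByValue'
# adding a matching __hash__ (e.g. returning hash(self.s)) would restore
# hashability, at the cost of the caveats the comment alludes to
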
# as opposed to intermediate unfinisheds like tuples of sequences
def is_real_unfinished(unf):
- return isinstance(unf, UnfinishedSequence) \
- or isinstance(unf, UnfinishedSelect)
+ return isinstance(unf, UnfinishedSequence) \
+ or isinstance(unf, UnfinishedSelect)
# some tiny bit of sugar that fits here:
def is_sequence_of_unfinishedseqs(seqs):
- if not isinstance(seqs, Iterable):
- return False
- return False not in [isinstance(seq, UnfinishedSequence) for seq in seqs]
+ if not isinstance(seqs, Iterable):
+ return False
+ return False not in [isinstance(seq, UnfinishedSequence) for seq in seqs]
class BareBonesFunctionalSupportException(Exception):
- def __init__(self, m):
- Exception.__init__(self, m)
+ def __init__(self, m):
+ Exception.__init__(self, m)
def to_tuple_of_unfinishedseqs(seqs):
- if is_sequence_of_unfinishedseqs(seqs):
- return tuple(seqs)
- if isinstance(seqs, UnfinishedSequence):
- return (seqs,)
- print(colored(f"seqs: {seqs}", general_color))
- raise BareBonesFunctionalSupportException(
- "input to select/aggregate not an unfinished sequence or sequence of"
- + " unfinished sequences")
+ if is_sequence_of_unfinishedseqs(seqs):
+ return tuple(seqs)
+ if isinstance(seqs, UnfinishedSequence):
+ return (seqs,)
+ print(colored(f"seqs: {seqs}", general_color))
+ raise BareBonesFunctionalSupportException(
+		"input to select/aggregate not an unfinished sequence or sequence " +
+		"of unfinished sequences")
def tup2tup(*x):
- return tuple([*x])
+ return tuple([*x])
class UnfinishedSequencesTuple(Unfinished):
- def __init__(self, parents_tuple, parents2self=None):
- # sequence tuples only exist in here, user doesn't 'see' them. can have
- # lots of default values they're just a convenience for me
- if parents2self is None: # just sticking a bunch of unfinished
- # sequences together into one thing for reasons
- parents2self = tup2tup
- parents_tuple = to_tuple_of_unfinishedseqs(parents_tuple)
- assert is_sequence_of_unfinishedseqs(
- parents_tuple) and isinstance(parents_tuple, tuple)
- # else - probably creating several sequences at once from one aggregate
- super(UnfinishedSequencesTuple, self).__init__(
- parents_tuple, parents2self, name="plain unfinished tuple")
-
- def __add__(self, other):
- assert isinstance(other, UnfinishedSequencesTuple)
- assert self.parents2self is tup2tup
- assert other.parents2self is tup2tup
- return UnfinishedSequencesTuple(self.parents_tuple+other.parents_tuple)
+ def __init__(self, parents_tuple, parents2self=None):
+ # sequence tuples only exist in here, user doesn't 'see' them. can have
+ # lots of default values they're just a convenience for me
+ if parents2self is None: # just sticking a bunch of unfinished
+ # sequences together into one thing for reasons
+ parents2self = tup2tup
+ parents_tuple = to_tuple_of_unfinishedseqs(parents_tuple)
+ assert is_sequence_of_unfinishedseqs(
+ parents_tuple) and isinstance(parents_tuple, tuple)
+ # else - probably creating several sequences at once from one aggregate
+ super(UnfinishedSequencesTuple, self).__init__(
+ parents_tuple, parents2self, name="plain unfinished tuple")
+
+ def __add__(self, other):
+ assert isinstance(other, UnfinishedSequencesTuple)
+ assert self.parents2self is tup2tup
+ assert other.parents2self is tup2tup
+ return UnfinishedSequencesTuple(self.parents_tuple +
+ other.parents_tuple)
_input = Unfinished((), None, is_toplevel_input=True)
# and now, the actual exposed functions
indices = UnfinishedSequence((_input,), lambda w: Sequence(
- list(range(len(w)))), name=plain_indices)
+ list(range(len(w)))), name=plain_indices)
tokens_str = UnfinishedSequence((_input,), lambda w: Sequence(
- list(map(str, w))), name=plain_tokens+"_str")
+ list(map(str, w))), name=plain_tokens + "_str")
tokens_int = UnfinishedSequence((_input,), lambda w: Sequence(
- list(map(int, w))), name=plain_tokens+"_int")
+ list(map(int, w))), name=plain_tokens + "_int")
tokens_float = UnfinishedSequence((_input,), lambda w: Sequence(
- list(map(float, w))), name=plain_tokens+"_float")
+ list(map(float, w))), name=plain_tokens + "_float")
tokens_bool = UnfinishedSequence((_input,), lambda w: Sequence(
- list(map(bool, w))), name=plain_tokens+"_bool")
+ list(map(bool, w))), name=plain_tokens + "_bool")
tokens_asis = UnfinishedSequence(
- (_input,), lambda w: Sequence(w), name=plain_tokens+"_asis")
+ (_input,), lambda w: Sequence(w), name=plain_tokens + "_asis")
base_tokens = [tokens_str, tokens_int, tokens_float, tokens_bool, tokens_asis]
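+# e.g. indices.call("hey") evaluates to a Sequence holding [0, 1, 2], and
+# tokens_int.call("123") to one holding [1, 2, 3] (call() runs the
+# computation on a concrete input, as the REPL does for its examples).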
def _min_poss_depth(unfs):
- if isinstance(unfs, Unfinished): # got single unfinished and not iterable
- # of them
- unfs = [unfs]
- # max b/c cant go less deep than deepest
- return max([u.min_poss_depth for u in unfs]+[0])
- # add that 0 thing so list is never empty and max complains.
+ if isinstance(unfs, Unfinished): # got single unfinished and not iterable
+ # of them
+ unfs = [unfs]
+ # max b/c cant go less deep than deepest
+ return max([u.min_poss_depth for u in unfs] + [0])
+ # add that 0 thing so list is never empty and max complains.
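+# e.g. _min_poss_depth([]) is 0, while for several unfinisheds it is simply
+# the largest of their individual min_poss_depth values.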
def tupleise(v):
- if isinstance(v, tuple) or isinstance(v, list):
- return tuple(v)
- return (v,)
+ if isinstance(v, tuple) or isinstance(v, list):
+ return tuple(v)
+ return (v,)
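+# e.g. tupleise("x") gives ("x",) and tupleise([a, b]) gives (a, b), so a
+# single sequence can be passed wherever a tuple of sequences is expected.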
def select(q_vars, k_vars, selector, name=None, compare_string=None):
- if None is name:
- name = "plain select"
- # potentially here check the qvars all reference the same input sequence as
- # each other and same for the kvars, technically dont *have* to but is
- # helpful for the user so consider maybe adding a tiny bit of mess here
- # (including markings inside sequences and selectors so they know which
- # index they're gathering to and from) to allow it
-
- # we're ok with getting a single q or k var, not in a tuple,
- # but important to fix it before '+' on two UnfinishedSequences
- # (as opposed to two tuples) sends everything sideways
- q_vars = tupleise(q_vars)
- k_vars = tupleise(k_vars)
-
- # attn layer is one after values it needs to be calculated
- new_depth = _min_poss_depth(q_vars+k_vars)+1
- res = UnfinishedSelect((_input, # need input seq length to create select
- # of correct size
- UnfinishedSequencesTuple(q_vars),
- UnfinishedSequencesTuple(k_vars)),
- lambda input_seq, qv, kv: _select(
- len(input_seq), qv, kv, selector),
- name=name, compare_string=compare_string,
- min_poss_depth=new_depth, q_vars=q_vars,
- k_vars=k_vars, orig_selector=selector)
- return res
+ if None is name:
+ name = "plain select"
+ # potentially here check the qvars all reference the same input sequence as
+ # each other and same for the kvars, technically dont *have* to but is
+ # helpful for the user so consider maybe adding a tiny bit of mess here
+ # (including markings inside sequences and selectors so they know which
+ # index they're gathering to and from) to allow it
+
+ # we're ok with getting a single q or k var, not in a tuple,
+ # but important to fix it before '+' on two UnfinishedSequences
+ # (as opposed to two tuples) sends everything sideways
+ q_vars = tupleise(q_vars)
+ k_vars = tupleise(k_vars)
+
+ # attn layer is one after values it needs to be calculated
+ new_depth = _min_poss_depth(q_vars + k_vars) + 1
+ res = UnfinishedSelect((_input, # need input seq length to create select
+ # of correct size
+ UnfinishedSequencesTuple(q_vars),
+ UnfinishedSequencesTuple(k_vars)),
+ lambda input_seq, qv, kv: _select(
+ len(input_seq), qv, kv, selector),
+ name=name, compare_string=compare_string,
+ min_poss_depth=new_depth, q_vars=q_vars,
+ k_vars=k_vars, orig_selector=selector)
+ return res
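+# illustrative usage ('prev' and 'same_tok' are hypothetical names):
+#   prev = select(indices, indices, lambda q, k: k < q)
+#   same_tok = select(tokens_str, tokens_str, lambda q, k: q == k)
+# the selector is called once per (query, key) position pair, receiving the
+# q-var values followed by the k-var values.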
def _compose_selects(select1, select2, compose_op=None, name=None,
- compare_string=None):
- nq1 = len(select1.q_vars)
- nq2 = len(select2.q_vars)+nq1
- nk1 = len(select1.k_vars)+nq2
-
- def new_selector(*qqkk):
- q1 = qqkk[:nq1]
- q2 = qqkk[nq1:nq2]
- k1 = qqkk[nq2:nk1]
- k2 = qqkk[nk1:]
- return compose_op(select1.orig_selector(*q1, *k1),
- select2.orig_selector(*q2, *k2))
- return select(select1.q_vars+select2.q_vars,
- select1.k_vars+select2.k_vars,
- new_selector, name=name, compare_string=compare_string)
+ compare_string=None):
+ nq1 = len(select1.q_vars)
+ nq2 = len(select2.q_vars) + nq1
+ nk1 = len(select1.k_vars) + nq2
+
+ def new_selector(*qqkk):
+ q1 = qqkk[:nq1]
+ q2 = qqkk[nq1:nq2]
+ k1 = qqkk[nq2:nk1]
+ k2 = qqkk[nk1:]
+ return compose_op(select1.orig_selector(*q1, *k1),
+ select2.orig_selector(*q2, *k2))
+ return select(select1.q_vars + select2.q_vars,
+ select1.k_vars + select2.k_vars,
+ new_selector, name=name, compare_string=compare_string)
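+# e.g. with one q-var and one k-var per select, nq1=1, nq2=2, nk1=3, so
+# new_selector(q1, q2, k1, k2) routes (q1, k1) to select1's selector and
+# (q2, k2) to select2's selector.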
def _compose_select(select1, compose_op=None, name=None, compare_string=None):
- def new_selector(*qk):
- return compose_op(select1.orig_selector(*qk))
- return select(select1.q_vars,
- select1.k_vars,
- new_selector, name=name, compare_string=compare_string)
+ def new_selector(*qk):
+ return compose_op(select1.orig_selector(*qk))
+ return select(select1.q_vars,
+ select1.k_vars,
+ new_selector, name=name, compare_string=compare_string)
def not_select(select, name=None, compare_string=None):
- return _compose_select(select, lambda a: not a,
- name=name, compare_string=compare_string)
+ return _compose_select(select, lambda a: not a,
+ name=name, compare_string=compare_string)
def and_selects(select1, select2, name=None, compare_string=None):
- return _compose_selects(select1, select2, lambda a, b: a and b,
- name=name, compare_string=compare_string)
+ return _compose_selects(select1, select2, lambda a, b: a and b,
+ name=name, compare_string=compare_string)
def or_selects(select1, select2, name=None, compare_string=None):
- return _compose_selects(select1, select2, lambda a, b: a or b,
- name=name, compare_string=compare_string)
+ return _compose_selects(select1, select2, lambda a, b: a or b,
+ name=name, compare_string=compare_string)
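+# e.g., for two selects s1 and s2 (hypothetical names) built via select()
+# above: and_selects(s1, s2) attends only where both do, or_selects(s1, s2)
+# where at least one does, and not_select(s1) flips s1's attention pattern.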
def format_output(parents_tuple, parents2res, name, elementwise_function=None,
- default=None, min_poss_depth=0, from_zipmap=False,
- definitely_uses_identity_function=False):
- def_uses = definitely_uses_identity_function
- return UnfinishedSequence(parents_tuple, parents2res,
- elementwise_function=elementwise_function,
- default=default, name=name,
- min_poss_depth=min_poss_depth,
- from_zipmap=from_zipmap,
- definitely_uses_identity_function=def_uses)
+ default=None, min_poss_depth=0, from_zipmap=False,
+ definitely_uses_identity_function=False):
+ def_uses = definitely_uses_identity_function
+ return UnfinishedSequence(parents_tuple, parents2res,
+ elementwise_function=elementwise_function,
+ default=default, name=name,
+ min_poss_depth=min_poss_depth,
+ from_zipmap=from_zipmap,
+ definitely_uses_identity_function=def_uses)
def get_identity_function(num_params):
- def identity1(a):
- return a
+ def identity1(a):
+ return a
- def identityx(*a):
- return a
- return identity1 if num_params == 1 else identityx
+ def identityx(*a):
+ return a
+ return identity1 if num_params == 1 else identityx
def zipmap(sequences_tuple, elementwise_function,
- name=plain_unfinished_sequence_name):
- sequences_tuple = tupleise(sequences_tuple)
- unfinished_parents_tuple = UnfinishedSequencesTuple(
- sequences_tuple) # this also takes care of turning the
- # value in sequences_tuple to indeed a tuple of sequences and not eg a
- # single sequence which will cause weird behaviour later
-
- parents_tuple = (_input, unfinished_parents_tuple)
- def parents2res(w, vt): return _zipmap(len(w), vt, elementwise_function)
- # feedforward doesn't increase layer
- min_poss_depth = _min_poss_depth(sequences_tuple)
- # new assumption, to be revised later: can do arbitrary zipmap even before
- # first feed-forward, i.e. in build up to first attention. truth is can do
- # 'simple' zipmap towards first attention (no xor, but yes things like
- # 'and' or 'indicator for ==' or whatever) based on initial linear
- # translation done for Q,K in attention (not deep enough for xor, but deep
- # enough for simple stuff) alongside use of initial embedding. honestly
- # literally can just put everything in initial embedding if need it so bad
- # its the first layer and its zipmap its only a function of the token and
- # indices, so long as its not computing any weird combination between them
- # you can do it in the embedding
- # if len(sequences_tuple)>0:
- # min_poss_depth = max(min_poss_depth,1) # except for the very specific
- # # case where it is the very first thing to be done, in which case we do
- # # have to go through one layer to get to the first feedforward.
- # # the 'if' is there to rule out increasing when doing a feedforward on
- # # nothing, ie, when making a constant. constants are allowed to be
- # # created on layer 0, they're part of the embedding or the weights that
- # # will use them later or whatever, it's fine
-
- # at least as deep as needed MVs, but no deeper cause FF
- # (which happens at end of layer)
- return format_output(parents_tuple, parents2res, name,
- min_poss_depth=min_poss_depth,
- elementwise_function=elementwise_function,
- from_zipmap=True)
+ name=plain_unfinished_sequence_name):
+ sequences_tuple = tupleise(sequences_tuple)
+ unfinished_parents_tuple = UnfinishedSequencesTuple(
+ sequences_tuple) # this also takes care of turning the
+ # value in sequences_tuple to indeed a tuple of sequences and not eg a
+ # single sequence which will cause weird behaviour later
+
+ parents_tuple = (_input, unfinished_parents_tuple)
+
+ def parents2res(w, vt):
+ return _zipmap(len(w), vt, elementwise_function)
+
+ # feedforward doesn't increase layer
+ min_poss_depth = _min_poss_depth(sequences_tuple)
+ # new assumption, to be revised later: can do arbitrary zipmap even before
+ # first feed-forward, i.e. in build up to first attention. truth is can do
+ # 'simple' zipmap towards first attention (no xor, but yes things like
+ # 'and' or 'indicator for ==' or whatever) based on initial linear
+ # translation done for Q,K in attention (not deep enough for xor, but deep
+ # enough for simple stuff) alongside use of initial embedding. honestly
+ # literally can just put everything in initial embedding if need it so bad
+ # its the first layer and its zipmap its only a function of the token and
+ # indices, so long as its not computing any weird combination between them
+ # you can do it in the embedding
+ # if len(sequences_tuple)>0:
+ # min_poss_depth = max(min_poss_depth,1) # except for the very specific
+ # # case where it is the very first thing to be done, in which case we do
+ # # have to go through one layer to get to the first feedforward.
+ # # the 'if' is there to rule out increasing when doing a feedforward on
+ # # nothing, ie, when making a constant. constants are allowed to be
+ # # created on layer 0, they're part of the embedding or the weights that
+ # # will use them later or whatever, it's fine
+
+ # at least as deep as needed MVs, but no deeper cause FF
+ # (which happens at end of layer)
+ return format_output(parents_tuple, parents2res, name,
+ min_poss_depth=min_poss_depth,
+ elementwise_function=elementwise_function,
+ from_zipmap=True)
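+# illustrative usage (hypothetical names 'is_vowel' and 'sums'):
+#   is_vowel = zipmap(tokens_str, lambda t: t in "aeiou")
+#   sums = zipmap((tokens_int, indices), lambda t, i: t + i)
+# the elementwise function receives one value per input sequence at each
+# position.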
def aggregate(select, sequences_tuple, elementwise_function=None,
- default=None, name=plain_unfinished_sequence_name):
- sequences_tuple = tupleise(sequences_tuple)
- definitely_uses_identity_function = None is elementwise_function
- if definitely_uses_identity_function:
- elementwise_function = get_identity_function(len(sequences_tuple))
- unfinished_parents_tuple = UnfinishedSequencesTuple(sequences_tuple)
- parents_tuple = (select, unfinished_parents_tuple)
- def parents2res(s, vt): return _aggregate(
- s, vt, elementwise_function, default=default)
- def_uses = definitely_uses_identity_function
-
- # at least as deep as needed attention and at least one deeper than needed
- # MVs
- return format_output(parents_tuple, parents2res, name,
- elementwise_function=elementwise_function,
- default=default,
- min_poss_depth=max(_min_poss_depth(
- sequences_tuple)+1, select.min_poss_depth),
- definitely_uses_identity_function=def_uses)
+ default=None, name=plain_unfinished_sequence_name):
+ sequences_tuple = tupleise(sequences_tuple)
+ definitely_uses_identity_function = None is elementwise_function
+ if definitely_uses_identity_function:
+ elementwise_function = get_identity_function(len(sequences_tuple))
+ unfinished_parents_tuple = UnfinishedSequencesTuple(sequences_tuple)
+ parents_tuple = (select, unfinished_parents_tuple)
+
+ def parents2res(s, vt):
+ return _aggregate(s, vt, elementwise_function, default=default)
+
+ def_uses = definitely_uses_identity_function
+
+ # at least as deep as needed attention and at least one deeper than needed
+ # MVs
+ min_poss_depth = max(_min_poss_depth(sequences_tuple) + 1,
+ select.min_poss_depth)
+ return format_output(parents_tuple, parents2res, name,
+ elementwise_function=elementwise_function,
+ default=default,
+ min_poss_depth=min_poss_depth,
+ definitely_uses_identity_function=def_uses)
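+# illustrative usage (hypothetical names, reusing select() from above):
+#   prev = select(indices, indices, lambda q, k: k < q)
+#   pooled = aggregate(prev, tokens_int, default=0)
+# each query position combines the tokens_int values at the key positions it
+# attends to; the actual pooling is done by the underlying _aggregate helper.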
# up to here was just plain transformer 'assembly'. any addition is a lie
@@ -572,18 +579,18 @@ def parents2res(s, vt): return _aggregate(
def UnfinishedSequenceFunc(f):
- setattr(UnfinishedSequence, f.__name__, f)
+ setattr(UnfinishedSequence, f.__name__, f)
def UnfinishedFunc(f):
- setattr(Unfinished, f.__name__, f)
+ setattr(Unfinished, f.__name__, f)
@UnfinishedSequenceFunc
def allow_suppressing_display(self):
- self.always_display = False
- return self # return self to allow chaining with other calls and throwing
- # straight into a return statement etc
+ self.always_display = False
+ return self # return self to allow chaining with other calls and throwing
+ # straight into a return statement etc
# later, we will overload == for unfinished sequences, such that it always
# returns another unfinished sequence. unfortunately this creates the following
@@ -595,15 +602,15 @@ def allow_suppressing_display(self):
def guarded_compare(seq1, seq2):
- if isinstance(seq1, UnfinishedSequence) \
- or isinstance(seq2, UnfinishedSequence):
- return seq1 is seq2
- return seq1 == seq2
+ if isinstance(seq1, UnfinishedSequence) \
+ or isinstance(seq2, UnfinishedSequence):
+ return seq1 is seq2
+ return seq1 == seq2
def guarded_contains(ll, a):
- if isinstance(a, Unfinished):
- return True in [(a is e) for e in ll]
- else:
- ll = [e for e in ll if not isinstance(e, Unfinished)]
- return a in ll
+ if isinstance(a, Unfinished):
+ return True in [(a is e) for e in ll]
+ else:
+ ll = [e for e in ll if not isinstance(e, Unfinished)]
+ return a in ll
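+# e.g. guarded_contains([some_seq, 3], 3) is True, while asking about an
+# Unfinished falls back to identity ('is') so the overloaded '==' never runs.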
diff --git a/RASP_support/REPL.py b/RASP_support/REPL.py
index 50cd3a4..3ed347b 100644
--- a/RASP_support/REPL.py
+++ b/RASP_support/REPL.py
@@ -6,7 +6,7 @@
from .Environment import Environment, UndefinedVariable, ReservedName
from .FunctionalSupport import UnfinishedSequence, UnfinishedSelect, Unfinished
from .Evaluator import Evaluator, NamedVal, NamedValList, JustVal, \
- RASPFunction, ArgsError, RASPTypeError, RASPValueError
+ RASPFunction, ArgsError, RASPTypeError, RASPValueError
from .Support import Select, Sequence, lazy_type_check
from termcolor import colored
from .colors import error_color, values_color, general_color
@@ -15,586 +15,593 @@
class ResultToPrint:
- def __init__(self, res, to_print):
- self.res, self.print = res, to_print
+ def __init__(self, res, to_print):
+ self.res, self.print = res, to_print
class LazyPrint:
- def __init__(self, *a, **kw):
- self.a, self.kw = a, kw
+ def __init__(self, *a, **kw):
+ self.a, self.kw = a, kw
- def print(self):
- print(*self.a, **self.kw)
+ def print(self):
+ print(*self.a, **self.kw)
class StopException(Exception):
- def __init__(self):
- super().__init__()
+ def __init__(self):
+ super().__init__()
DEBUG = False
def debprint(*a, **kw):
- if DEBUG:
- coloredprint(*a, **kw)
+ if DEBUG:
+ coloredprint(*a, **kw)
class ReturnExample:
- def __init__(self, subset):
- self.subset = subset
+ def __init__(self, subset):
+ self.subset = subset
class LoadError(Exception):
- def __init__(self, msg):
- super().__init__(msg)
+ def __init__(self, msg):
+ super().__init__(msg)
def is_comment(line):
- if not isinstance(line, str):
- return False
- return line.strip().startswith("#")
+ if not isinstance(line, str):
+ return False
+ return line.strip().startswith("#")
def formatstr(res):
- if isinstance(res, str):
- return "\""+res+"\""
- return str(res)
+ if isinstance(res, str):
+ return "\"" + res + "\""
+ return str(res)
class REPL:
- def __init__(self):
- self.env = Environment(name="console")
- self.sequence_running_example = "hello"
- self.selector_running_example = "hello"
- self.sequence_prints_verbose = False
- self.show_sequence_examples = True
- self.show_selector_examples = True
- self.results_to_print = []
- self.print_welcome()
- self.load_base_libraries_and_make_base_env()
-
- def load_base_libraries_and_make_base_env(self):
- self.silent = True
- # base env: the env from which every load begins
- self.base_env = self.env.snapshot()
- # bootstrap base_env with current (basically empty except indices etc)
- # env, then load the base libraries to build the actual base env
- # make the library-loaded variables and functions not-overwriteable
- self.env.storing_in_constants = True
- for lib in ["RASP_support/rasplib"]:
- self.run_given_line("load \"" + lib + "\";")
- self.base_env = self.env.snapshot()
- self.env.storing_in_constants = False
- self.run_given_line("tokens=tokens_str;")
- self.base_env = self.env.snapshot()
- self.silent = False
-
- def set_running_example(self, example, which="both"):
- if which in ["both", ENCODER_NAME]:
- self.sequence_running_example = example
- if which in ["both", "selector"]:
- self.selector_running_example = example
-
- def print_welcome(self):
- print(colored("RASP 0.1", general_color))
- print(colored("running example is:", general_color),
- colored(self.sequence_running_example, values_color))
-
- def print_just_val(self, justval):
- val = justval.val
- if None is val:
- return
- if isinstance(val, Select):
- print(colored("\t = ", general_color))
- print_select(val.created_from_input, val)
- elif isinstance(val, Sequence) and self.sequence_prints_verbose:
- print(colored("\t = ", general_color), end="")
- print_seq(val.created_from_input, val, still_on_prev_line=True)
- else:
- print(colored("\t = ", general_color),
- colored(str(val).replace("\n", "\n\t\t\t"), values_color))
-
- def print_named_val(self, name, val, ntabs=0, extra_first_pref=""):
- pref = "\t"*ntabs
- if (None is name) and isinstance(val, Unfinished):
- name = val.name
- if isinstance(val, UnfinishedSequence):
- print(pref,
- colored(extra_first_pref, general_color),
- colored(" "+ENCODER_NAME+":", general_color),
- colored(name, general_color))
- if self.show_sequence_examples:
- if self.sequence_prints_verbose:
- print(colored(f"{pref} \t Example:", general_color),
- end="")
- optional_exampledesc =\
- colored(name + "(", general_color) +\
- colored(formatstr(self.sequence_running_example),
- values_color) +\
- colored(") =", general_color)
- print_seq(self.selector_running_example,
- val.call(self.sequence_running_example),
- still_on_prev_line=True,
- extra_pref=pref,
- lastpref_if_shortprint=optional_exampledesc)
- else:
- print(colored(f"{pref} \t Example: {name}(",
- general_color) +
- colored(formatstr(self.sequence_running_example), values_color) +
- colored(") =", general_color),
- val.call(self.sequence_running_example))
- elif isinstance(val, UnfinishedSelect):
- print(colored(pref, general_color),
- colored(extra_first_pref, general_color),
- colored(f" selector: {name}", general_color))
- if self.show_selector_examples:
- print(colored(f"{pref} \t Example:", general_color))
- print_select(self.selector_running_example, val.call(
- self.selector_running_example), extra_pref=pref)
- elif isinstance(val, RASPFunction):
- print(colored(f"{pref} {extra_first_pref} ", general_color) +
- colored(str(val), general_color))
- elif isinstance(val, list):
- named = " list: "+((name+" = ") if name is not None else "")
- print(colored(f"{pref} {extra_first_pref} {named}",
- general_color), end="")
- flat = True not in [isinstance(v, list) or isinstance(
- v, dict) or isinstance(v, Unfinished) for v in val]
- if flat:
- print(colored(val, values_color))
- else:
- print(colored(f"{pref} [", general_color))
- for v in val:
- self.print_named_val(None, v, ntabs=ntabs+2)
- print(colored(str(pref) + " "*(len(named) +2) + "]",
- general_color))
- elif isinstance(val, dict):
- named = " dict: "+((name+" = ") if name is not None else "")
- print(colored(f"{pref} {extra_first_pref} {named}",
- general_color), end="")
- flat = True not in [isinstance(val[v], list) or isinstance(
- val[v], dict) or isinstance(val[v], Unfinished) for v in val]
- if flat:
- print(colored(val, values_color))
- else:
- print(colored(str(pref) + " {", general_color))
- for v in val:
- self.print_named_val(None, val[v], ntabs=ntabs + 3,
- extra_first_pref=formatstr(v) + " : ")
- print(colored(str(pref) + " "*(len(named) + 2) + "}",
- general_color))
-
- else:
- namestr = (name + " = ") if name is not None else ""
- print(colored(f"{pref} value: {namestr}", general_color),
- colored(formatstr(val), values_color))
-
- def print_example(self, nres):
- if nres.subset in ["both", ENCODER_NAME]:
- print(colored("\t"+ENCODER_NAME+" example:", general_color),
- colored(formatstr(self.sequence_running_example), values_color))
- if nres.subset in ["both", "selector"]:
- print(colored("\tselector example:", general_color),
- colored(formatstr(self.selector_running_example), values_color))
-
- def print_result(self, rp):
- if self.silent:
- return
- if isinstance(rp, LazyPrint):
- return rp.print()
- # a list of multiple ResultToPrint s -- probably the result of a
- # multi-assignment
- if isinstance(rp, list):
- for v in rp:
- self.print_result(v)
- return
- if not rp.print:
- return
- res = rp.res
- if isinstance(res, NamedVal):
- self.print_named_val(res.name, res.val)
- elif isinstance(res, ReturnExample):
- self.print_example(res)
- elif isinstance(res, JustVal):
- self.print_just_val(res)
-
- def evaluate_replstatement(self, ast):
- if ast.setExample():
- return ResultToPrint(self.setExample(ast.setExample()), False)
- if ast.showExample():
- return ResultToPrint(self.showExample(ast.showExample()), True)
- if ast.toggleExample():
- return ResultToPrint(self.toggleExample(ast.toggleExample()),
- False)
- if ast.toggleSeqVerbose():
- return ResultToPrint(self.toggleSeqVerbose(ast.toggleSeqVerbose()),
- False)
- if ast.exit():
- raise StopException()
-
- def toggleSeqVerbose(self, ast):
- switch = ast.switch.text
- self.sequence_prints_verbose = switch == "on"
-
- def toggleExample(self, ast):
- subset = ast.subset
- subset = "both" if not subset else subset.text
- switch = ast.switch.text
- examples_on = switch == "on"
- if subset in ["both", ENCODER_NAME]:
- self.show_sequence_examples = examples_on
- if subset in ["both", "selector"]:
- self.show_selector_examples = examples_on
-
- def showExample(self, ast):
- subset = ast.subset
- subset = "both" if not subset else subset.text
- return ReturnExample(subset)
-
- def setExample(self, ast):
- example = Evaluator(self.env, self).evaluateExpr(ast.example)
- if not isinstance(example, Iterable):
- raise RASPTypeError("example not iterable: "+str(example))
- subset = ast.subset
- subset = "both" if not subset else subset.text
- self.set_running_example(example, subset)
- return ReturnExample(subset)
-
- def loadFile(self, ast, calling_env=None):
- if None is calling_env:
- calling_env = self.env
- libname = ast.filename.text[1:-1]
- filename = libname + ".rasp"
- try:
- with open(filename, "r") as f:
- prev_example_settings = self.show_sequence_examples, \
- self.show_selector_examples
- self.show_sequence_examples = False
- self.show_selector_examples = False
- self.run(fromfile=f,
- env=Environment(name=libname,
- parent_env=self.base_env,
- stealing_env=calling_env),
- store_prints=True)
- self.filter_and_dump_prints()
- self.show_sequence_examples, self.show_selector_examples = \
- prev_example_settings
- except FileNotFoundError:
- raise LoadError("could not find file: "+filename)
-
- def get_tree(self, fromfile=None):
- try:
- return LineReader(fromfile=fromfile).get_input_tree()
- except AntlrException as e:
- print(colored(f"\t!! antlr exception: {e.msg} \t-- ignoring input",
- error_color))
- return None
-
- def run_given_line(self, line):
- try:
- tree = LineReader(given_line=line).get_input_tree()
- if isinstance(tree, Stop):
- return None
- rp = self.evaluate_tree(tree)
- if isinstance(rp, LazyPrint):
- # error messages get raised, but ultimately have to be printed
- # somewhere if not caught? idk
- rp.print()
- except AntlrException as e:
- print(colored(f"\t!! REPL failed to run initiating line: {line}",
- error_color))
- print(colored(f"\t --got antlr exception: {e.msg}",
- error_color))
- return None
-
- def assigned_to_top(self, res, env):
- if env is self.env:
- return True
- # we are now definitely inside some file, the question is whether we
- # have taken the result and kept it in the top level too, i.e., whether
- # we have imported a non-private value. checking whether it is also in
- # self.env, even identical, will not tell us much as it may have been
- # here and the same already. so we have to replicate the logic here.
- if not isinstance(res, NamedVal):
- return False # only namedvals get set to begin with
- if res.name.startswith("_") or (res.name == "out"):
- return False
- return True
-
- def evaluate_tree(self, tree, env=None):
- if None is env:
- env = self.env # otherwise, can pass custom env
- # (e.g. when loading from a file, make env for that file,
- # to keep that file's private (i.e. underscore-prefixed) variables
- # to itself)
- if None is tree:
- return ResultToPrint(None, False)
- try:
- if tree.replstatement():
- return self.evaluate_replstatement(tree.replstatement())
- elif tree.raspstatement():
- res = Evaluator(env, self).evaluate(tree.raspstatement())
- if isinstance(res, NamedValList):
- return [ResultToPrint(r, self.assigned_to_top(r, env)) for
- r in res.nvs]
- return ResultToPrint(res, self.assigned_to_top(res, env))
- except (UndefinedVariable, ReservedName) as e:
- return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}", error_color))
- except NotImplementedError:
- return LazyPrint(
- colored(f"not implemented this command yet! ignoring", error_color))
- except (ArgsError, RASPTypeError, LoadError, RASPValueError) as e:
- return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}", error_color))
- # if not replstatement or raspstatement, then comment
- return ResultToPrint(None, False)
-
- def filter_and_dump_prints(self):
- # TODO: some error messages are still rising up and getting printed
- # before reaching this position :(
- def filter_named_val_reps(rps):
- # do the filtering. no namedvallists here - those are converted
- # into a list of ResultToPrint s containing NamedVal s immediately
- # after receiving them in evaluate_tree
- res = []
- names = set()
- # go backwards - want to print the last occurence of each named
- # item, not first, so filter works backwards
- for r in rps[::-1]:
- if isinstance(r.res, NamedVal):
- if r.res.name in names:
- continue
- names.add(r.res.name)
- res.append(r)
- return res[::-1] # flip back forwards
-
- if True not in [isinstance(v, LazyPrint) for
- v in self.results_to_print]:
- self.results_to_print = filter_named_val_reps(
- self.results_to_print)
- # if isinstance(res,NamedVal):
- # self.print_named_val(res.name,res.val)
- #
- # print all that needs to be printed:
- for r in self.results_to_print:
- if isinstance(r, LazyPrint):
- r.print()
- else:
- self.print_result(r)
- # clear the list
- self.results_to_print = []
-
- def run(self, fromfile=None, env=None, store_prints=False):
- def careful_print(*a, **kw):
- if store_prints:
- self.results_to_print.append(LazyPrint(*a, **kw))
- else:
- print(*a, **kw)
- while True:
- try:
- tree = self.get_tree(fromfile)
- if isinstance(tree, Stop):
- break
- rp = self.evaluate_tree(tree, env)
- if store_prints:
- if isinstance(rp, list):
- # multiple results given - a multi-assignment
- self.results_to_print += rp
- else:
- self.results_to_print.append(rp)
- else:
- self.print_result(rp)
- except RASPTypeError as e:
- msg = "\t!!statement executed, but result fails on evaluation:"
- msg += "\n\t\t"
- toprint = colored(f"{msg} {e}", error_color)
- careful_print(toprint)
- except EOFError:
- careful_print("")
- break
- except StopException:
- break
- except KeyboardInterrupt:
- careful_print("") # makes newline
- except Exception as e:
- if DEBUG:
- raise e
- careful_print(colored(f"something went wrong: {e}",
- error_color))
+ def __init__(self):
+ self.env = Environment(name="console")
+ self.sequence_running_example = "hello"
+ self.selector_running_example = "hello"
+ self.sequence_prints_verbose = False
+ self.show_sequence_examples = True
+ self.show_selector_examples = True
+ self.results_to_print = []
+ self.print_welcome()
+ self.load_base_libraries_and_make_base_env()
+
+ def load_base_libraries_and_make_base_env(self):
+ self.silent = True
+ # base env: the env from which every load begins
+ self.base_env = self.env.snapshot()
+ # bootstrap base_env with current (basically empty except indices etc)
+ # env, then load the base libraries to build the actual base env
+ # make the library-loaded variables and functions not-overwriteable
+ self.env.storing_in_constants = True
+ for lib in ["RASP_support/rasplib"]:
+ self.run_given_line("load \"" + lib + "\";")
+ self.base_env = self.env.snapshot()
+ self.env.storing_in_constants = False
+ self.run_given_line("tokens=tokens_str;")
+ self.base_env = self.env.snapshot()
+ self.silent = False
+
+ def set_running_example(self, example, which="both"):
+ if which in ["both", ENCODER_NAME]:
+ self.sequence_running_example = example
+ if which in ["both", "selector"]:
+ self.selector_running_example = example
+
+ def print_welcome(self):
+ print(colored("RASP 0.1", general_color))
+ print(colored("running example is:", general_color),
+ colored(self.sequence_running_example, values_color))
+
+ def print_just_val(self, justval):
+ val = justval.val
+ if None is val:
+ return
+ if isinstance(val, Select):
+ print(colored("\t = ", general_color))
+ print_select(val.created_from_input, val)
+ elif isinstance(val, Sequence) and self.sequence_prints_verbose:
+ print(colored("\t = ", general_color), end="")
+ print_seq(val.created_from_input, val, still_on_prev_line=True)
+ else:
+ print(colored("\t = ", general_color),
+ colored(str(val).replace("\n", "\n\t\t\t"), values_color))
+
+ def print_named_val(self, name, val, ntabs=0, extra_first_pref=""):
+ pref = "\t" * ntabs
+ if (None is name) and isinstance(val, Unfinished):
+ name = val.name
+ if isinstance(val, UnfinishedSequence):
+ print(pref,
+ colored(extra_first_pref, general_color),
+ colored(" " + ENCODER_NAME + ":", general_color),
+ colored(name, general_color))
+ if self.show_sequence_examples:
+ if self.sequence_prints_verbose:
+ print(colored(f"{pref} \t Example:", general_color),
+ end="")
+ optional_exampledesc =\
+ colored(name + "(", general_color) +\
+ colored(formatstr(self.sequence_running_example),
+ values_color) +\
+ colored(") =", general_color)
+ print_seq(self.selector_running_example,
+ val.call(self.sequence_running_example),
+ still_on_prev_line=True,
+ extra_pref=pref,
+ lastpref_if_shortprint=optional_exampledesc)
+ else:
+ print(colored(f"{pref} \t Example: {name}(",
+ general_color) +
+ colored(formatstr(self.sequence_running_example),
+ values_color) +
+ colored(") =", general_color),
+ val.call(self.sequence_running_example))
+ elif isinstance(val, UnfinishedSelect):
+ print(colored(pref, general_color),
+ colored(extra_first_pref, general_color),
+ colored(f" selector: {name}", general_color))
+ if self.show_selector_examples:
+ print(colored(f"{pref} \t Example:", general_color))
+ print_select(self.selector_running_example, val.call(
+ self.selector_running_example), extra_pref=pref)
+ elif isinstance(val, RASPFunction):
+ print(colored(f"{pref} {extra_first_pref} ", general_color) +
+ colored(str(val), general_color))
+ elif isinstance(val, list):
+ named = " list: " + ((name + " = ") if name is not None else "")
+ print(colored(f"{pref} {extra_first_pref} {named}",
+ general_color), end="")
+ flat = True not in [isinstance(v, list) or isinstance(
+ v, dict) or isinstance(v, Unfinished) for v in val]
+ if flat:
+ print(colored(val, values_color))
+ else:
+ print(colored(f"{pref} [", general_color))
+ for v in val:
+ self.print_named_val(None, v, ntabs=ntabs + 2)
+ print(colored(str(pref) + " " * (len(named) + 2) + "]",
+ general_color))
+ elif isinstance(val, dict):
+ named = " dict: " + ((name + " = ") if name is not None else "")
+ print(colored(f"{pref} {extra_first_pref} {named}",
+ general_color), end="")
+ flat = True not in [isinstance(val[v], list) or isinstance(
+ val[v], dict) or isinstance(val[v], Unfinished) for v in val]
+ if flat:
+ print(colored(val, values_color))
+ else:
+ print(colored(str(pref) + " {", general_color))
+ for v in val:
+ self.print_named_val(None, val[v], ntabs=ntabs + 3,
+ extra_first_pref=formatstr(v) + " : ")
+ print(colored(str(pref) + " " * (len(named) + 2) + "}",
+ general_color))
+
+ else:
+ namestr = (name + " = ") if name is not None else ""
+ print(colored(f"{pref} value: {namestr}", general_color),
+ colored(formatstr(val), values_color))
+
+ def print_example(self, nres):
+ if nres.subset in ["both", ENCODER_NAME]:
+ print(colored("\t" + ENCODER_NAME + " example:", general_color),
+ colored(formatstr(self.sequence_running_example),
+ values_color))
+ if nres.subset in ["both", "selector"]:
+ print(colored("\tselector example:", general_color),
+ colored(formatstr(self.selector_running_example),
+ values_color))
+
+ def print_result(self, rp):
+ if self.silent:
+ return
+ if isinstance(rp, LazyPrint):
+ return rp.print()
+ # a list of multiple ResultToPrint s -- probably the result of a
+ # multi-assignment
+ if isinstance(rp, list):
+ for v in rp:
+ self.print_result(v)
+ return
+ if not rp.print:
+ return
+ res = rp.res
+ if isinstance(res, NamedVal):
+ self.print_named_val(res.name, res.val)
+ elif isinstance(res, ReturnExample):
+ self.print_example(res)
+ elif isinstance(res, JustVal):
+ self.print_just_val(res)
+
+ def evaluate_replstatement(self, ast):
+ if ast.setExample():
+ return ResultToPrint(self.setExample(ast.setExample()), False)
+ if ast.showExample():
+ return ResultToPrint(self.showExample(ast.showExample()), True)
+ if ast.toggleExample():
+ return ResultToPrint(self.toggleExample(ast.toggleExample()),
+ False)
+ if ast.toggleSeqVerbose():
+ return ResultToPrint(self.toggleSeqVerbose(ast.toggleSeqVerbose()),
+ False)
+ if ast.exit():
+ raise StopException()
+
+ def toggleSeqVerbose(self, ast):
+ switch = ast.switch.text
+ self.sequence_prints_verbose = switch == "on"
+
+ def toggleExample(self, ast):
+ subset = ast.subset
+ subset = "both" if not subset else subset.text
+ switch = ast.switch.text
+ examples_on = switch == "on"
+ if subset in ["both", ENCODER_NAME]:
+ self.show_sequence_examples = examples_on
+ if subset in ["both", "selector"]:
+ self.show_selector_examples = examples_on
+
+ def showExample(self, ast):
+ subset = ast.subset
+ subset = "both" if not subset else subset.text
+ return ReturnExample(subset)
+
+ def setExample(self, ast):
+ example = Evaluator(self.env, self).evaluateExpr(ast.example)
+ if not isinstance(example, Iterable):
+ raise RASPTypeError("example not iterable: " + str(example))
+ subset = ast.subset
+ subset = "both" if not subset else subset.text
+ self.set_running_example(example, subset)
+ return ReturnExample(subset)
+
+ def loadFile(self, ast, calling_env=None):
+ if None is calling_env:
+ calling_env = self.env
+ libname = ast.filename.text[1:-1]
+ filename = libname + ".rasp"
+ try:
+ with open(filename, "r") as f:
+ prev_example_settings = self.show_sequence_examples, \
+ self.show_selector_examples
+ self.show_sequence_examples = False
+ self.show_selector_examples = False
+ self.run(fromfile=f,
+ env=Environment(name=libname,
+ parent_env=self.base_env,
+ stealing_env=calling_env),
+ store_prints=True)
+ self.filter_and_dump_prints()
+ self.show_sequence_examples, self.show_selector_examples = \
+ prev_example_settings
+ except FileNotFoundError:
+ raise LoadError("could not find file: " + filename)
+
+ def get_tree(self, fromfile=None):
+ try:
+ return LineReader(fromfile=fromfile).get_input_tree()
+ except AntlrException as e:
+ print(colored(f"\t!! antlr exception: {e.msg} \t-- ignoring input",
+ error_color))
+ return None
+
+ def run_given_line(self, line):
+ try:
+ tree = LineReader(given_line=line).get_input_tree()
+ if isinstance(tree, Stop):
+ return None
+ rp = self.evaluate_tree(tree)
+ if isinstance(rp, LazyPrint):
+ # error messages get raised, but ultimately have to be printed
+ # somewhere if not caught? idk
+ rp.print()
+ except AntlrException as e:
+ print(colored(f"\t!! REPL failed to run initiating line: {line}",
+ error_color))
+ print(colored(f"\t --got antlr exception: {e.msg}",
+ error_color))
+ return None
+
+ def assigned_to_top(self, res, env):
+ if env is self.env:
+ return True
+ # we are now definitely inside some file, the question is whether we
+ # have taken the result and kept it in the top level too, i.e., whether
+ # we have imported a non-private value. checking whether it is also in
+ # self.env, even identical, will not tell us much as it may have been
+ # here and the same already. so we have to replicate the logic here.
+ if not isinstance(res, NamedVal):
+ return False # only namedvals get set to begin with
+ if res.name.startswith("_") or (res.name == "out"):
+ return False
+ return True
+
+ def evaluate_tree(self, tree, env=None):
+ if None is env:
+ env = self.env # otherwise, can pass custom env
+ # (e.g. when loading from a file, make env for that file,
+ # to keep that file's private (i.e. underscore-prefixed) variables
+ # to itself)
+ if None is tree:
+ return ResultToPrint(None, False)
+ try:
+ if tree.replstatement():
+ return self.evaluate_replstatement(tree.replstatement())
+ elif tree.raspstatement():
+ res = Evaluator(env, self).evaluate(tree.raspstatement())
+ if isinstance(res, NamedValList):
+ return [ResultToPrint(r, self.assigned_to_top(r, env)) for
+ r in res.nvs]
+ return ResultToPrint(res, self.assigned_to_top(res, env))
+ except (UndefinedVariable, ReservedName) as e:
+ return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}",
+ error_color))
+ except NotImplementedError:
+ return LazyPrint(
+ colored(f"not implemented this command yet! ignoring",
+ error_color))
+ except (ArgsError, RASPTypeError, LoadError, RASPValueError) as e:
+ return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}",
+ error_color))
+ # if not replstatement or raspstatement, then comment
+ return ResultToPrint(None, False)
+
+ def filter_and_dump_prints(self):
+ # TODO: some error messages are still rising up and getting printed
+ # before reaching this position :(
+ def filter_named_val_reps(rps):
+ # do the filtering. no namedvallists here - those are converted
+ # into a list of ResultToPrint s containing NamedVal s immediately
+ # after receiving them in evaluate_tree
+ res = []
+ names = set()
+			# go backwards - want to print the last occurrence of each named
+ # item, not first, so filter works backwards
+ for r in rps[::-1]:
+ if isinstance(r.res, NamedVal):
+ if r.res.name in names:
+ continue
+ names.add(r.res.name)
+ res.append(r)
+ return res[::-1] # flip back forwards
+
+ if True not in [isinstance(v, LazyPrint) for
+ v in self.results_to_print]:
+ self.results_to_print = filter_named_val_reps(
+ self.results_to_print)
+ # if isinstance(res,NamedVal):
+ # self.print_named_val(res.name,res.val)
+ #
+ # print all that needs to be printed:
+ for r in self.results_to_print:
+ if isinstance(r, LazyPrint):
+ r.print()
+ else:
+ self.print_result(r)
+ # clear the list
+ self.results_to_print = []
+
+ def run(self, fromfile=None, env=None, store_prints=False):
+ def careful_print(*a, **kw):
+ if store_prints:
+ self.results_to_print.append(LazyPrint(*a, **kw))
+ else:
+ print(*a, **kw)
+ while True:
+ try:
+ tree = self.get_tree(fromfile)
+ if isinstance(tree, Stop):
+ break
+ rp = self.evaluate_tree(tree, env)
+ if store_prints:
+ if isinstance(rp, list):
+ # multiple results given - a multi-assignment
+ self.results_to_print += rp
+ else:
+ self.results_to_print.append(rp)
+ else:
+ self.print_result(rp)
+ except RASPTypeError as e:
+ msg = "\t!!statement executed, but result fails on evaluation:"
+ msg += "\n\t\t"
+ toprint = colored(f"{msg} {e}", error_color)
+ careful_print(toprint)
+ except EOFError:
+ careful_print("")
+ break
+ except StopException:
+ break
+ except KeyboardInterrupt:
+ careful_print("") # makes newline
+ except Exception as e:
+ if DEBUG:
+ raise e
+ careful_print(colored(f"something went wrong: {e}",
+ error_color))
class AntlrException(Exception):
- def __init__(self, msg):
- self.msg = msg
+ def __init__(self, msg):
+ self.msg = msg
class InputNotFinished(Exception):
- def __init__(self):
- pass
+ def __init__(self):
+ pass
class MyErrorListener(ErrorListener):
- def __init__(self):
- super(MyErrorListener, self).__init__()
-
- def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
- if offendingSymbol and offendingSymbol.text == "":
- raise InputNotFinished()
- if msg.startswith("missing ';' at"):
- raise InputNotFinished()
- # TODO: why did this do nothing?
- # if "mismatched input" in msg:
- # a = str(offendingSymbol)
- # b = a[a.find("=")+2:]
- # c = b[:b.find(",<")-1]
- ae = AntlrException(msg)
- ae.recognizer = recognizer
- ae.offendingSymbol = offendingSymbol
- ae.line = line
- ae.column = column
- ae.msg = msg
- ae.e = e
- raise ae
-
- # def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact,
- # ambigAlts, configs):
- # raise AntlrException("ambiguity")
-
- # def reportAttemptingFullContext(self, recognizer, dfa, startIndex,
- # stopIndex, conflictingAlts, configs):
- # we're ok with this: happens with func defs it seems
-
- # def reportContextSensitivity(self, recognizer, dfa, startIndex,
- # stopIndex, prediction, configs):
- # we're ok with this: happens with func defs it seems
+ def __init__(self):
+ super(MyErrorListener, self).__init__()
+
+ def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
+ if offendingSymbol and offendingSymbol.text == "":
+ raise InputNotFinished()
+ if msg.startswith("missing ';' at"):
+ raise InputNotFinished()
+ # TODO: why did this do nothing?
+ # if "mismatched input" in msg:
+ # a = str(offendingSymbol)
+ # b = a[a.find("=")+2:]
+ # c = b[:b.find(",<")-1]
+ ae = AntlrException(msg)
+ ae.recognizer = recognizer
+ ae.offendingSymbol = offendingSymbol
+ ae.line = line
+ ae.column = column
+ ae.msg = msg
+ ae.e = e
+ raise ae
+
+ # def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact,
+ # ambigAlts, configs):
+ # raise AntlrException("ambiguity")
+
+ # def reportAttemptingFullContext(self, recognizer, dfa, startIndex,
+ # stopIndex, conflictingAlts, configs):
+ # we're ok with this: happens with func defs it seems
+
+ # def reportContextSensitivity(self, recognizer, dfa, startIndex,
+ # stopIndex, prediction, configs):
+ # we're ok with this: happens with func defs it seems
class Stop:
- def __init__(self):
- pass
+ def __init__(self):
+ pass
class LineReader:
- def __init__(self, prompt=">>", fromfile=None, given_line=None):
- self.fromfile = fromfile
- self.given_line = given_line
- self.prompt = prompt + " "
- self.cont_prompt = "."*len(prompt)+" "
-
- def str_to_antlr_parser(self, s):
- antlrinput = InputStream(s)
- lexer = RASPLexer(antlrinput)
- lexer.removeErrorListeners()
- lexer.addErrorListener(MyErrorListener())
- stream = CommonTokenStream(lexer)
- parser = RASPParser(stream)
- parser.removeErrorListeners()
- parser.addErrorListener(MyErrorListener())
- return parser
-
- def read_line(self, continuing=False, nest_depth=0):
- prompt = self.cont_prompt if continuing else self.prompt
- if self.fromfile is not None:
- res = self.fromfile.readline()
- # python files return "" on last line (as opposed to "\n" on empty
- # lines)
- if not res:
- return Stop()
- return res
- if self.given_line is not None:
- res = self.given_line
- self.given_line = Stop()
- return res
- else:
- return input(prompt+(" "*nest_depth))
-
- def get_input_tree(self):
- pythoninput = ""
- multiline = False
- while True:
- nest_depth = pythoninput.split().count("def")
- newinput = self.read_line(continuing=multiline,
- nest_depth=nest_depth)
- if isinstance(newinput, Stop): # input stream ended
- return Stop()
- if is_comment(newinput):
- # don't let comments get in and ruin things somehow
- newinput = ""
- # don't replace newlines here! this is how in-function comments get
- # broken
- pythoninput += newinput
- parser = self.str_to_antlr_parser(pythoninput)
- try:
- res = parser.r().statement()
- if isinstance(res, list):
- # TODO: this seems to happen when there's ambiguity. figure
- # out what is going on!!
- assert len(res) == 1
- res = res[0]
- return res
- except InputNotFinished:
- multiline = True
- pythoninput += " "
+ def __init__(self, prompt=">>", fromfile=None, given_line=None):
+ self.fromfile = fromfile
+ self.given_line = given_line
+ self.prompt = prompt + " "
+ self.cont_prompt = "." * len(prompt) + " "
+
+ def str_to_antlr_parser(self, s):
+ antlrinput = InputStream(s)
+ lexer = RASPLexer(antlrinput)
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(MyErrorListener())
+ stream = CommonTokenStream(lexer)
+ parser = RASPParser(stream)
+ parser.removeErrorListeners()
+ parser.addErrorListener(MyErrorListener())
+ return parser
+
+ def read_line(self, continuing=False, nest_depth=0):
+ prompt = self.cont_prompt if continuing else self.prompt
+ if self.fromfile is not None:
+ res = self.fromfile.readline()
+ # python files return "" on last line (as opposed to "\n" on empty
+ # lines)
+ if not res:
+ return Stop()
+ return res
+ if self.given_line is not None:
+ res = self.given_line
+ self.given_line = Stop()
+ return res
+ else:
+ return input(prompt + (" " * nest_depth))
+
+ def get_input_tree(self):
+ pythoninput = ""
+ multiline = False
+ while True:
+ nest_depth = pythoninput.split().count("def")
+ newinput = self.read_line(continuing=multiline,
+ nest_depth=nest_depth)
+ if isinstance(newinput, Stop): # input stream ended
+ return Stop()
+ if is_comment(newinput):
+ # don't let comments get in and ruin things somehow
+ newinput = ""
+ # don't replace newlines here! this is how in-function comments get
+ # broken
+ pythoninput += newinput
+ parser = self.str_to_antlr_parser(pythoninput)
+ try:
+ res = parser.r().statement()
+ if isinstance(res, list):
+ # TODO: this seems to happen when there's ambiguity. figure
+ # out what is going on!!
+ assert len(res) == 1
+ res = res[0]
+ return res
+ except InputNotFinished:
+ multiline = True
+ pythoninput += " "
def print_seq(example, seq, still_on_prev_line=False, extra_pref="",
- lastpref_if_shortprint=""):
- if len(set(seq.get_vals())) == 1:
- print(extra_pref if not still_on_prev_line else "",
- lastpref_if_shortprint,
- colored(str(seq), values_color), end=" ")
- # when there is only one value, it's nicer to just print that than the
- # full list, verbosity be damned
- print(colored("[skipped full display: identical values]", general_color))
- return
- if still_on_prev_line:
- print("")
-
- seq = seq.get_vals()
-
- def cleanboolslist(seq):
- if isinstance(seq[0], bool):
- tstr = "T" if seq.count(True) <= seq.count(False) else ""
- fstr = "F" if seq.count(False) <= seq.count(True) else ""
- return [tstr if v else fstr for v in seq]
- else:
- return seq
-
- example = cleanboolslist(example)
- seqtype = lazy_type_check(seq)
- seq = cleanboolslist(seq)
- example = [str(v) for v in example]
- seq = [str(v) for v in seq]
- maxlen = max(len(v) for v in example+seq)
-
- def neatline(seq):
- def padded(s):
- return " "*(maxlen-len(s))+s
- return " ".join(padded(v) for v in seq)
- print(extra_pref, colored("\t\tinput: ", general_color),
- colored(neatline(example), values_color), "\t",
- colored("("+lazy_type_check(example)+"s)", general_color))
- print(extra_pref, colored("\t\toutput: ", general_color),
- colored(neatline(seq), values_color), "\t",
- colored("("+seqtype+"s)", general_color))
+ lastpref_if_shortprint=""):
+ if len(set(seq.get_vals())) == 1:
+ print(extra_pref if not still_on_prev_line else "",
+ lastpref_if_shortprint,
+ colored(str(seq), values_color), end=" ")
+ # when there is only one value, it's nicer to just print that than the
+ # full list, verbosity be damned
+ print(colored("[skipped full display: identical values]",
+ general_color))
+ return
+ if still_on_prev_line:
+ print("")
+
+ seq = seq.get_vals()
+
+ def cleanboolslist(seq):
+ if isinstance(seq[0], bool):
+ tstr = "T" if seq.count(True) <= seq.count(False) else ""
+ fstr = "F" if seq.count(False) <= seq.count(True) else ""
+ return [tstr if v else fstr for v in seq]
+ else:
+ return seq
+
+ example = cleanboolslist(example)
+ seqtype = lazy_type_check(seq)
+ seq = cleanboolslist(seq)
+ example = [str(v) for v in example]
+ seq = [str(v) for v in seq]
+ maxlen = max(len(v) for v in example + seq)
+
+ def neatline(seq):
+ def padded(s):
+ return " " * (maxlen - len(s)) + s
+ return " ".join(padded(v) for v in seq)
+ print(extra_pref, colored("\t\tinput: ", general_color),
+ colored(neatline(example), values_color), "\t",
+ colored("(" + lazy_type_check(example) + "s)", general_color))
+ print(extra_pref, colored("\t\toutput: ", general_color),
+ colored(neatline(seq), values_color), "\t",
+ colored("(" + seqtype + "s)", general_color))
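+# note on the boolean display above: the less frequent truth value keeps its
+# letter and the more frequent one is blanked, e.g. [True, False, False]
+# becomes ["T", "", ""] before padding (ties keep both letters).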
def print_select(example, select, extra_pref=""):
- # .replace("\n","\n\t\t\t")
- def nice_matrix_line(m):
- return " ".join("1" if v else " " for v in m)
- print(colored(extra_pref, general_color), "\t\t\t ",
- colored(" ".join(str(v) for v in example), values_color))
- matrix = select.get_vals()
- [print(colored(extra_pref, general_color), "\t\t\t",
- colored(v, values_color),
- colored("|", general_color),
- colored(nice_matrix_line(matrix[m]), values_color))
- for v, m in zip(example, matrix)]
+ # .replace("\n","\n\t\t\t")
+ def nice_matrix_line(m):
+ return " ".join("1" if v else " " for v in m)
+ print(colored(extra_pref, general_color), "\t\t\t ",
+ colored(" ".join(str(v) for v in example), values_color))
+ matrix = select.get_vals()
+ [print(colored(extra_pref, general_color), "\t\t\t",
+ colored(v, values_color),
+ colored("|", general_color),
+ colored(nice_matrix_line(matrix[m]), values_color))
+ for v, m in zip(example, matrix)]
if __name__ == "__main__":
- REPL().run()
+ REPL().run()
# (set debug in this file to True)
@@ -605,10 +612,10 @@ def nice_matrix_line(m):
# import REPL
# REPL.runner()
def runner():
- a = REPL()
- try:
- a.run()
- except Exception as e:
- print(e)
- return a, e
- return a, None
+ a = REPL()
+ try:
+ a.run()
+ except Exception as e:
+ print(e)
+ return a, e
+ return a, None
diff --git a/RASP_support/Sugar.py b/RASP_support/Sugar.py
index da80d51..248c22e 100644
--- a/RASP_support/Sugar.py
+++ b/RASP_support/Sugar.py
@@ -11,44 +11,44 @@
def _apply_unary_op(self, f):
- return zipmap(self, f)
+ return zipmap(self, f)
def _apply_binary_op(self, other, f):
- def seq_and_other_op(self, other, f):
- return zipmap(self, lambda a: f(a, other))
+ def seq_and_other_op(self, other, f):
+ return zipmap(self, lambda a: f(a, other))
- def seq_and_seq_op(self, other_seq, f):
- return zipmap((self, other_seq), f)
- if isinstance(other, _UnfinishedSequence):
- return seq_and_seq_op(self, other, f)
- else:
- return seq_and_other_op(self, other, f)
+ def seq_and_seq_op(self, other_seq, f):
+ return zipmap((self, other_seq), f)
+ if isinstance(other, _UnfinishedSequence):
+ return seq_and_seq_op(self, other, f)
+ else:
+ return seq_and_other_op(self, other, f)
add_ops(_UnfinishedSequence, _apply_unary_op, _apply_binary_op)
def _addname(seq, name, default_name, always_display_when_named=True):
- if name is None:
- res = seq.setname(default_name,
- always_display_when_named=always_display_when_named)
- res = res.allow_suppressing_display()
- else:
- res = seq.setname(name,
- always_display_when_named=always_display_when_named)
- return res
+ if name is None:
+ res = seq.setname(default_name,
+ always_display_when_named=always_display_when_named)
+ res = res.allow_suppressing_display()
+ else:
+ res = seq.setname(name,
+ always_display_when_named=always_display_when_named)
+ return res
full_s = select((), (), lambda: True, name="full average",
- compare_string="full average")
+ compare_string="full average")
def tplconst(v, name=None):
- return _addname(zipmap((), lambda: v), name, "constant: " + str(v),
- always_display_when_named=False).mark_as_constant()
- # always_display_when_named = False : constants aren't worth displaying,
- # but still going to name them in background, in case I change my mind
+ return _addname(zipmap((), lambda: v), name, "constant: " + str(v),
+ always_display_when_named=False).mark_as_constant()
+ # always_display_when_named = False : constants aren't worth displaying,
+ # but still going to name them in background, in case I change my mind
# allow suppressing display for bool, not, and, or : all of these would have
# been boring operators if only python let me overload them
@@ -58,45 +58,45 @@ def tplconst(v, name=None):
def toseq(seq):
- if not isinstance(seq, _UnfinishedSequence):
- seq = tplconst(seq, str(seq))
- return seq
+ if not isinstance(seq, _UnfinishedSequence):
+ seq = tplconst(seq, str(seq))
+ return seq
def asbool(seq):
- res = zipmap(seq, lambda a: bool(a))
- return _addname(res, None, "bool(" + seq.name + ")")
- # would do res = seq==True but it seems this has different behaviour to
- # bool eg 'bool(2)' is True but '2==True' returns False
+ res = zipmap(seq, lambda a: bool(a))
+ return _addname(res, None, "bool(" + seq.name + ")")
+ # would do res = seq == True, but that seems to behave differently from
+ # bool: e.g. 'bool(2)' is True but '2 == True' is False
def tplnot(seq, name=None):
- # this one does correct conversion using asbool and then we really can just
- # do == False
- pep8hack = False # this avoids violating E712 of PEP8
- res = asbool(seq) == pep8hack
- return _addname(res, name, "( not " + str(seq.name) + " )")
+ # this one does the conversion correctly using asbool, and then we really
+ # can just do == False
+ pep8hack = False # this avoids violating E712 of PEP8
+ res = asbool(seq) == pep8hack
+ return _addname(res, name, "( not " + str(seq.name) + " )")
def _num_trues(left, right):
- l, r = toseq(left), toseq(right)
- return (1 * asbool(l)) + (1 * asbool(r))
+ l, r = toseq(left), toseq(right)
+ return (1 * asbool(l)) + (1 * asbool(r))
def quickname(v):
- if isinstance(v, _Unfinished):
- return v.name
- else:
- return str(v)
+ if isinstance(v, _Unfinished):
+ return v.name
+ else:
+ return str(v)
def tpland(left, right):
- res = _num_trues(left, right) == 2
- return _addname(res, None, "( " + quickname(left) + " and "
- + quickname(right) + ")")
+ res = _num_trues(left, right) == 2
+ return _addname(res, None, "( " + quickname(left) + " and " +
+ quickname(right) + ")")
def tplor(left, right):
- res = _num_trues(left, right) >= 1
- return _addname(res, None, "( " + quickname(left) + " or "
- + quickname(right) + ")")
+ res = _num_trues(left, right) >= 1
+ return _addname(res, None, "( " + quickname(left) + " or " +
+ quickname(right) + ")")
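Sugar.py cannot overload Python's `and`/`or`/`not` keywords directly, so it builds them from operators that are overloadable: operands are cast with asbool, and tpland/tplor then count how many of the two operands are true. The counting trick on its own, sketched on plain Python values (illustrative names only, none of the sequence machinery):

def num_trues(left, right):
    # 1*True == 1 and 1*False == 0, so the sum counts how many operands are truthy
    return 1 * bool(left) + 1 * bool(right)

def tpland(left, right):
    return num_trues(left, right) == 2   # both must be true

def tplor(left, right):
    return num_trues(left, right) >= 1   # at least one must be true

assert tpland(3, "x") and not tpland(3, 0)
assert tplor(0, 5) and not tplor(0, "")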
diff --git a/RASP_support/Support.py b/RASP_support/Support.py
index a6de676..2285438 100644
--- a/RASP_support/Support.py
+++ b/RASP_support/Support.py
@@ -6,25 +6,25 @@
class RASPError(Exception):
- def __init__(self, *a):
- super().__init__(" ".join([str(b) for b in a]))
+ def __init__(self, *a):
+ super().__init__(" ".join([str(b) for b in a]))
class RASPTypeError(RASPError):
- def __init__(self, *a):
- super().__init__(*a)
+ def __init__(self, *a):
+ super().__init__(*a)
def clean_val(num, digits=3): # taken from my helper functions
- res = round(num, digits)
- if digits == 0:
- res = int(res)
- return res
+ res = round(num, digits)
+ if digits == 0:
+ res = int(res)
+ return res
class SupportException(Exception):
- def __init__(self, m):
- Exception.__init__(self, m)
+ def __init__(self, m):
+ Exception.__init__(self, m)
TBANNED = "banned"
@@ -33,248 +33,251 @@ def __init__(self, m):
NUMTYPES = [TNAME[int], TNAME[float]]
sorted_typenames_list = sorted(list(TNAME.values()))
legal_types_list_string = ", ".join(
- sorted_typenames_list[:-1])+" or "+sorted_typenames_list[-1]
+ sorted_typenames_list[:-1]) + " or " + sorted_typenames_list[-1]
def is_in_types(v, tlist):
- for t in tlist:
- if isinstance(v, t):
- return True
- return False
+ for t in tlist:
+ if isinstance(v, t):
+ return True
+ return False
def lazy_type_check(vals):
- legal_val_types = [str, bool, int, float]
- number_types = [int, float]
+ legal_val_types = [str, bool, int, float]
+ number_types = [int, float]
- # all vals are same, legal, type:
- for t in legal_val_types:
- b = [isinstance(v, t) for v in vals]
- if False not in b:
- return TNAME[t]
+ # all vals are same, legal, type:
+ for t in legal_val_types:
+ b = [isinstance(v, t) for v in vals]
+ if False not in b:
+ return TNAME[t]
- # allow vals to also be mixed integers and ints, treat those as floats
- # (but don't actually change the ints to floats, want neat printouts)
- b = [is_in_types(v, number_types) for v in vals]
- if False not in b:
- return TNAME[float]
+ # allow vals to also be mixed ints and floats, and treat those as floats
+ # (but don't actually change the ints to floats, want neat printouts)
+ b = [is_in_types(v, number_types) for v in vals]
+ if False not in b:
+ return TNAME[float]
- # from here it's all bad, but lets have some clear error messages
- b = [is_in_types(v, legal_val_types) for v in vals]
- if False not in b:
- return TMISMATCHED # all legal types, but mismatched
- else:
- return TBANNED
+ # from here it's all bad, but let's have some clear error messages
+ b = [is_in_types(v, legal_val_types) for v in vals]
+ if False not in b:
+ return TMISMATCHED # all legal types, but mismatched
+ else:
+ return TBANNED
class Sequence:
- def __init__(self, vals):
- self.type = lazy_type_check(vals)
- if self.type == TMISMATCHED:
- raise RASPTypeError(
- "attempted to create sequence with vals of different types:"
- + f"\n\t\t {vals}")
- if self.type == TBANNED:
- raise RASPTypeError(
- "attempted to create sequence with illegal val types "
- + f"(vals must be {legal_types_list_string}):\n\t\t {vals}")
- self._vals = vals
-
- def __str__(self):
- # return "Sequence"+str([small_str(v) for v in self._vals])
- if (len(set(self._vals)) == 1) and (len(self._vals) > 1):
- res = "["+small_str(self._vals[0])+"]*"+str(len(self._vals))
- else:
- res = "["+", ".join(small_str(v) for v in self._vals)+"]"
- return colored(res, values_color) + \
- colored(" ("+self.type+"s)", general_color)
-
- def __repr__(self):
- return str(self)
-
- def __len__(self):
- return len(self._vals)
-
- def get_vals(self):
- return deepcopy(self._vals)
+ def __init__(self, vals):
+ self.type = lazy_type_check(vals)
+ if self.type == TMISMATCHED:
+ raise RASPTypeError(
+ "attempted to create sequence with vals of different types:" +
+ f"\n\t\t {vals}")
+ if self.type == TBANNED:
+ raise RASPTypeError(
+ "attempted to create sequence with illegal val types " +
+ f"(vals must be {legal_types_list_string}):\n\t\t {vals}")
+ self._vals = vals
+
+ def __str__(self):
+ # return "Sequence"+str([small_str(v) for v in self._vals])
+ if (len(set(self._vals)) == 1) and (len(self._vals) > 1):
+ res = "[" + small_str(self._vals[0]) + "]*" + str(len(self._vals))
+ else:
+ res = "[" + ", ".join(small_str(v) for v in self._vals) + "]"
+ return colored(res, values_color) + \
+ colored(" (" + self.type + "s)", general_color)
+
+ def __repr__(self):
+ return str(self)
+
+ def __len__(self):
+ return len(self._vals)
+
+ def get_vals(self):
+ return deepcopy(self._vals)
def dims_match(seqs, expected_dim):
- return False not in [expected_dim == len(seq) for seq in seqs]
+ return False not in [expected_dim == len(seq) for seq in seqs]
class Select:
- def __init__(self, n, q_vars, k_vars, f):
- self.n = n
- self.makeselect(q_vars, k_vars, f)
- self.niceprint = None
-
- def get_vals(self):
- if self.select is None:
- self.makeselect()
- return deepcopy(self.select)
-
- def makeselect(self, q_vars=None, k_vars=None, f=None):
- if None is q_vars:
- assert (None is k_vars) and (None is f)
- q_vars = (Sequence(self.target_index),)
- k_vars = (Sequence(list(range(self.n))),)
- def f(t, i): return t == i
- self.select = {i: [f(*get(q_vars, i), *get(k_vars, j))
- for j in range(self.n)]
- for i in range(self.n)} # outputs of f should be
- # True or False. j goes along input dim, i along output
-
- def __str__(self):
- self.get_vals()
- if None is self.niceprint:
- d = {i: list(map(int, self.select[i])) for i in self.select}
- self.niceprint = str(self.niceprint)
- if len(str(d)) > 40:
- starter = "\n"
- self.niceprint = pprint.pformat(d)
- else:
- starter = ""
- self.niceprint = str(d)
- self.niceprint = starter + self.niceprint
- return self.niceprint
-
- def __repr__(self):
- return str(self)
+ def __init__(self, n, q_vars, k_vars, f):
+ self.n = n
+ self.makeselect(q_vars, k_vars, f)
+ self.niceprint = None
+
+ def get_vals(self):
+ if self.select is None:
+ self.makeselect()
+ return deepcopy(self.select)
+
+ def makeselect(self, q_vars=None, k_vars=None, f=None):
+ if None is q_vars:
+ assert (None is k_vars) and (None is f)
+ q_vars = (Sequence(self.target_index),)
+ k_vars = (Sequence(list(range(self.n))),)
+
+ def f(t, i):
+ return t == i
+
+ self.select = {i: [f(*get(q_vars, i), *get(k_vars, j))
+ for j in range(self.n)]
+ for i in range(self.n)} # outputs of f should be
+ # True or False. j goes along input dim, i along output
+
+ def __str__(self):
+ self.get_vals()
+ if None is self.niceprint:
+ d = {i: list(map(int, self.select[i])) for i in self.select}
+ self.niceprint = str(self.niceprint)
+ if len(str(d)) > 40:
+ starter = "\n"
+ self.niceprint = pprint.pformat(d)
+ else:
+ starter = ""
+ self.niceprint = str(d)
+ self.niceprint = starter + self.niceprint
+ return self.niceprint
+
+ def __repr__(self):
+ return str(self)
def select(n, q_vars, k_vars, f):
- return Select(n, q_vars, k_vars, f)
+ return Select(n, q_vars, k_vars, f)
# applying selects or feedforward (map)
def aggregate(select, k_vars, func, default=None):
- return to_sequences(apply_average_select(select, k_vars, func, default))
+ return to_sequences(apply_average_select(select, k_vars, func, default))
def to_sequences(results_by_index):
- def totup(r):
- if not isinstance(r, tuple):
- return (r,)
- return r
- # convert scalar results to tuples of length 1
- results_by_index = list(map(totup, results_by_index))
- # one list (sequence) per output value
- results_by_output_val = list(zip(*results_by_index))
- res = tuple(map(Sequence, results_by_output_val))
- if len(res) == 1:
- return res[0]
- else:
- return res
+ def totup(r):
+ if not isinstance(r, tuple):
+ return (r,)
+ return r
+ # convert scalar results to tuples of length 1
+ results_by_index = list(map(totup, results_by_index))
+ # one list (sequence) per output value
+ results_by_output_val = list(zip(*results_by_index))
+ res = tuple(map(Sequence, results_by_output_val))
+ if len(res) == 1:
+ return res[0]
+ else:
+ return res
def zipmap(n, k_vars, func):
- # assert len(k_vars) >= 1, "dont make a whole sequence for a plain constant
- # you already know the value of.."
- results_by_index = [func(*get(k_vars, i)) for i in range(n)]
- return to_sequences(results_by_index)
+ # assert len(k_vars) >= 1, "don't make a whole sequence for a plain constant
+ # you already know the value of.."
+ results_by_index = [func(*get(k_vars, i)) for i in range(n)]
+ return to_sequences(results_by_index)
def verify_default_size(default, num_output_vars):
- assert num_output_vars > 0
- if num_output_vars == 1:
- errnote = "aggregates on functions with single output should have" \
- + " scalar default"
- assert not isinstance(default, tuple), errnote
- elif num_output_vars > 1:
- errnote = "for function with >1 output values, default should be" \
- + " tuple of default values, of equal length to passed" \
- + " function's output values (for function with single output" \
- + " value, default should be single value too)"
- check = isinstance(default, tuple) and len(default) == num_output_vars
- assert check, errnote
+ assert num_output_vars > 0
+ if num_output_vars == 1:
+ errnote = "aggregates on functions with single output should have" \
+ + " scalar default"
+ assert not isinstance(default, tuple), errnote
+ elif num_output_vars > 1:
+ errnote = "for function with >1 output values, default should be" \
+ + " tuple of default values, of equal length to passed" \
+ + " function's output values (for function with single output" \
+ + " value, default should be single value too)"
+ check = isinstance(default, tuple) and len(default) == num_output_vars
+ assert check, errnote
def apply_average_select(select, k_vars, func, default=0):
- def apply_func_to_each_index():
- # kvs is list [by index] of lists [by varname] of values
- kvs = [get(k_vars, i) for i in list(range(select.n))]
- candidate_i = [func(*kvi) for kvi in kvs] # candidate output per index
- if num_output_vars > 1:
- candidates_by_varname = list(zip(*candidate_i))
- else:
- # expect tuples of values for conversions in return_sequences
- candidates_by_varname = (candidate_i,)
- return candidates_by_varname
-
- def prep_default(default, num_output_vars):
- if None is default:
- default = 0
- # output of average is always floats, so will be converting all
- # to floats here else we'll fail the lazy type check in the
- # Sequences. (and float(None) doesn't 'compile' )
- # TODO: maybe just lose the lazy type check?
- if not isinstance(default, tuple) and (num_output_vars > 1):
- default = tuple([default]*num_output_vars)
- # *specifically* in apply_average, where values have to be floats,
- # allow default to be single val,
- # that will be repeated for all wanted outputs
- verify_default_size(default, num_output_vars)
- if not isinstance(default, tuple):
- # specifically with how we're going to do things here in the
- # average aggregate, will help to actually have the outputs get
- # passed around as tuples, even if they're scalars really.
- # but do this after the size check for the scalar one so it doesn't
- # get filled with weird ifs... this tupled scalar thing is only a
- # convenience in this implementation in this here function
- default = (default,)
- return default
-
- def apply_and_average_single_index(outputs_by_varname, index,
- index_scores, num_output_vars, default):
- def mean(scores, vals):
- n = scores.count(True) # already >0 by earlier
- if n == 1:
- return vals[scores.index(True)]
- # else # n>1
- if not (lazy_type_check(vals) in NUMTYPES):
- raise Exception(
- "asked to average multiple values, but they are "
- + "non-numbers: " + str(vals))
- return sum([v for s, v in zip(scores, vals) if s])*1.0/n
-
- num_influencers = index_scores.count(True)
- if num_influencers == 0:
- return default
- else:
- # return_sequences expects multiple outputs to be in tuple form
- return tuple(mean(index_scores, o_by_i)
- for o_by_i in outputs_by_varname)
- num_output_vars = get_num_outputs(func(*get(k_vars, 0)))
- candidates_by_varname = apply_func_to_each_index()
- default = prep_default(default, num_output_vars)
- means_per_index = [apply_and_average_single_index(candidates_by_varname,
- i, select.select[i],
- num_output_vars, default)
- for i in range(select.n)]
- # list (per index) of all the new variable values (per varname)
- return means_per_index
+ def apply_func_to_each_index():
+ # kvs is list [by index] of lists [by varname] of values
+ kvs = [get(k_vars, i) for i in list(range(select.n))]
+ candidate_i = [func(*kvi) for kvi in kvs] # candidate output per index
+ if num_output_vars > 1:
+ candidates_by_varname = list(zip(*candidate_i))
+ else:
+ # expect tuples of values for conversions in return_sequences
+ candidates_by_varname = (candidate_i,)
+ return candidates_by_varname
+
+ def prep_default(default, num_output_vars):
+ if None is default:
+ default = 0
+ # the output of an average is always floats, so we convert everything
+ # to floats here, else we'd fail the lazy type check in the
+ # Sequences (and float(None) doesn't 'compile')
+ # TODO: maybe just lose the lazy type check?
+ if not isinstance(default, tuple) and (num_output_vars > 1):
+ default = tuple([default] * num_output_vars)
+ # *specifically* in apply_average, where values have to be floats,
+ # allow default to be single val,
+ # that will be repeated for all wanted outputs
+ verify_default_size(default, num_output_vars)
+ if not isinstance(default, tuple):
+ # with how things are done here in the average aggregate, it helps
+ # to actually pass the outputs around as tuples, even when they're
+ # really scalars.
+ # do this after the size check for the scalar case so that check
+ # doesn't get filled with weird ifs... the tupled-scalar thing is only
+ # a convenience internal to this implementation of this function
+ default = (default,)
+ return default
+
+ def apply_and_average_single_index(outputs_by_varname, index,
+ index_scores, num_output_vars, default):
+ def mean(scores, vals):
+ n = scores.count(True) # already > 0, checked by the caller
+ if n == 1:
+ return vals[scores.index(True)]
+ # else # n>1
+ if not (lazy_type_check(vals) in NUMTYPES):
+ raise Exception(
+ "asked to average multiple values, but they are " +
+ "non-numbers: " + str(vals))
+ return sum([v for s, v in zip(scores, vals) if s]) * 1.0 / n
+
+ num_influencers = index_scores.count(True)
+ if num_influencers == 0:
+ return default
+ else:
+ # return_sequences expects multiple outputs to be in tuple form
+ return tuple(mean(index_scores, o_by_i)
+ for o_by_i in outputs_by_varname)
+ num_output_vars = get_num_outputs(func(*get(k_vars, 0)))
+ candidates_by_varname = apply_func_to_each_index()
+ default = prep_default(default, num_output_vars)
+ means_per_index = [apply_and_average_single_index(candidates_by_varname,
+ i, select.select[i],
+ num_output_vars, default)
+ for i in range(select.n)]
+ # list (per index) of all the new variable values (per varname)
+ return means_per_index
# user's responsibility to give functions that always have same number of
# outputs
def get_num_outputs(dummy_out):
- if isinstance(dummy_out, tuple):
- return len(dummy_out)
- return 1
+ if isinstance(dummy_out, tuple):
+ return len(dummy_out)
+ return 1
def small_str(v):
- if isinstance(v, float):
- return str(clean_val(v, 3))
- if isinstance(v, bool):
- return "T" if v else "F"
- return str(v)
+ if isinstance(v, float):
+ return str(clean_val(v, 3))
+ if isinstance(v, bool):
+ return "T" if v else "F"
+ return str(v)
def get(vars_list, index): # index should be within range to access
- # v._vals and if not absolutely should raise an error, as it will here
- # by the attempted access
- res = deepcopy([v._vals[index] for v in vars_list])
- return res
+ # v._vals and if not absolutely should raise an error, as it will here
+ # by the attempted access
+ res = deepcopy([v._vals[index] for v in vars_list])
+ return res
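The heart of Support.py is the average-based aggregate: for each output position i, the select row select[i] marks which input positions contribute, and apply_average_select averages their values, falling back to a default when nothing is selected. A stripped-down sketch of that per-position averaging for a single scalar-valued sequence (made-up data, no Sequence or type-check machinery):

def average_aggregate(select_rows, vals, default=0):
    # select_rows[i][j] is True when output position i attends to input position j
    out = []
    for row in select_rows:
        picked = [v for keep, v in zip(row, vals) if keep]
        out.append(sum(picked) / len(picked) if picked else default)
    return out

vals = [10, 20, 30]
select_rows = [[True, False, False],   # position 0 sees only itself
               [True, True, False],    # position 1 averages positions 0 and 1
               [False, False, False]]  # position 2 sees nothing -> default
print(average_aggregate(select_rows, vals))   # [10.0, 15.0, 0]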
diff --git a/RASP_support/analyse.py b/RASP_support/analyse.py
index 1a023fe..be9b187 100644
--- a/RASP_support/analyse.py
+++ b/RASP_support/analyse.py
@@ -1,22 +1,22 @@
from .FunctionalSupport import Unfinished, UnfinishedSequence, \
- UnfinishedSelect, guarded_contains, guarded_compare, zipmap
+ UnfinishedSelect, guarded_contains, guarded_compare, zipmap
from collections import defaultdict, Counter
from copy import copy
def UnfinishedFunc(f):
- setattr(Unfinished, f.__name__, f)
+ setattr(Unfinished, f.__name__, f)
@UnfinishedFunc
def get_parent_sequences(self):
- # for UnfinishedSequences, this should get just the tuple of sequences the
- # aggregate is applied to, and I think in order (as the parents will only
- # be a select and a sequencestuple, and the seqs in the sequencestuple will
- # be added in order and the select will be removed in this function)
+ # for UnfinishedSequences, this should get just the tuple of sequences the
+ # aggregate is applied to, and I think in order (as the parents will only
+ # be a select and a sequencestuple, and the seqs in the sequencestuple will
+ # be added in order and the select will be removed in this function)
- # i.e. drop the selects
- return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)]
+ # i.e. drop the selects
+ return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)]
Unfinished._full_seq_parents = None
@@ -24,95 +24,95 @@ def get_parent_sequences(self):
@UnfinishedFunc
def get_full_seq_parents(self):
- if self._full_seq_parents is None:
- self._full_seq_parents = [u for u in self.get_full_parents()
- if isinstance(u, UnfinishedSequence)]
- return copy(self._full_seq_parents)
+ if self._full_seq_parents is None:
+ self._full_seq_parents = [u for u in self.get_full_parents()
+ if isinstance(u, UnfinishedSequence)]
+ return copy(self._full_seq_parents)
@UnfinishedFunc
def get_parent_select(self):
- if not hasattr(self, "parent_select"):
- real_parents = self.get_parents()
- self.parent_select = next((s for s in real_parents if
- isinstance(s, UnfinishedSelect)), None)
- return self.parent_select
+ if not hasattr(self, "parent_select"):
+ real_parents = self.get_parents()
+ self.parent_select = next((s for s in real_parents if
+ isinstance(s, UnfinishedSelect)), None)
+ return self.parent_select
@UnfinishedFunc
def set_analysis_parent_select(self, options):
- # doesn't really need to be a function but feels clearer visually to have
- # it out here so i can see this variable is being registered to the
- # unfinisheds
- if None is self.parent_select:
- self.analysis_parent_select = self.parent_select
- else:
- getps = (ps for ps in options if
- ps.compare_string == self.get_parent_select().compare_string)
- self.analysis_parent_select = next(getps, None)
- errnote = "parent options given to seq: " + self.name + " did not " \
- + "include anything equivalent to actual seq's parent select (" \
- + self.get_parent_select().compare_string + ")"
- assert self.analysis_parent_select is not None, errnote
+ # doesn't really need to be a function, but it feels visually clearer to
+ # have it out here so I can see this variable is being registered on the
+ # unfinisheds
+ if None is self.parent_select:
+ self.analysis_parent_select = self.parent_select
+ else:
+ getps = (ps for ps in options if
+ ps.compare_string == self.get_parent_select().compare_string)
+ self.analysis_parent_select = next(getps, None)
+ errnote = "parent options given to seq: " + self.name + " did not " \
+ + "include anything equivalent to actual seq's parent select (" \
+ + self.get_parent_select().compare_string + ")"
+ assert self.analysis_parent_select is not None, errnote
def squeeze_selects(selects):
- compstrs = set([s.compare_string for s in selects])
- if len(compstrs) == len(selects):
- return selects
- return [next(s for s in selects if s.compare_string == cs)
- for cs in compstrs]
+ compstrs = set([s.compare_string for s in selects])
+ if len(compstrs) == len(selects):
+ return selects
+ return [next(s for s in selects if s.compare_string == cs)
+ for cs in compstrs]
@UnfinishedFunc
def schedule(self, scheduler='best', remove_minors=False):
- # recall attentions can be created on level 1 but still generate seqs on
- # level 3 etc hence width is number of *seqs* with different attentions per
- # level.
- def choose_scheduler(scheduler):
- if scheduler == 'best':
- return 'greedy'
- # TODO: implement lastminute, maybe others, and choose narrowest
- # result of all options
- return scheduler
- scheduler = choose_scheduler(scheduler)
- seq_layers = self.greedy_seq_scheduler() if scheduler == 'greedy' \
- else self.lastminute_seq_scheduler()
-
- if remove_minors:
- for i in seq_layers:
- seq_layers[i] = [seq for seq in seq_layers[i] if not seq.is_minor]
-
- def get_seqs_selects(seqs):
- # all the selects needed to compute a set of seqs
- all_selects = set(seq.get_parent_select() for seq in seqs)
- # some of the seqs may not have parent matches,
- # eg, indices. these will return None, which we don't want to count
- all_selects -= set([None])
- return squeeze_selects(all_selects) # squeeze identical parents
-
- layer_selects = {i: get_seqs_selects(seq_layers[i]) for i in seq_layers}
-
- # mark remaining parent select after squeeze
- for i in seq_layers:
- for seq in seq_layers[i]:
- seq.set_analysis_parent_select(layer_selects[i])
-
- return seq_layers, layer_selects
+ # recall attentions can be created on level 1 but still generate seqs on
+ # level 3 etc., hence width is the number of *seqs* with different
+ # attentions per level.
+ def choose_scheduler(scheduler):
+ if scheduler == 'best':
+ return 'greedy'
+ # TODO: implement lastminute, maybe others, and choose narrowest
+ # result of all options
+ return scheduler
+ scheduler = choose_scheduler(scheduler)
+ seq_layers = self.greedy_seq_scheduler() if scheduler == 'greedy' \
+ else self.lastminute_seq_scheduler()
+
+ if remove_minors:
+ for i in seq_layers:
+ seq_layers[i] = [seq for seq in seq_layers[i] if not seq.is_minor]
+
+ def get_seqs_selects(seqs):
+ # all the selects needed to compute a set of seqs
+ all_selects = set(seq.get_parent_select() for seq in seqs)
+ # some of the seqs may not have parent selects,
+ # e.g., indices. These will return None, which we don't want to count
+ all_selects -= set([None])
+ return squeeze_selects(all_selects) # squeeze identical parents
+
+ layer_selects = {i: get_seqs_selects(seq_layers[i]) for i in seq_layers}
+
+ # mark remaining parent select after squeeze
+ for i in seq_layers:
+ for seq in seq_layers[i]:
+ seq.set_analysis_parent_select(layer_selects[i])
+
+ return seq_layers, layer_selects
@UnfinishedFunc
def greedy_seq_scheduler(self):
- all_seqs = sorted(self.get_full_seq_parents(),
- key=lambda seq: seq.creation_order_id)
- # sorting in order of creation automatically sorts by order of in-layer
- # dependencies (i.e. things got through feedforwards), makes prints clearer
- # and eventually is helpful for drawcompflow
- levels = defaultdict(lambda: [])
- for seq in all_seqs:
- # schedule all seqs as early as possible
- levels[seq.min_poss_depth].append(seq)
- return levels
+ all_seqs = sorted(self.get_full_seq_parents(),
+ key=lambda seq: seq.creation_order_id)
+ # sorting in order of creation automatically sorts by order of in-layer
+ # dependencies (i.e. things obtained through feedforwards); it makes
+ # prints clearer and eventually helps drawcompflow
+ levels = defaultdict(lambda: [])
+ for seq in all_seqs:
+ # schedule all seqs as early as possible
+ levels[seq.min_poss_depth].append(seq)
+ return levels
Unfinished.max_poss_depth_for_seq = (None, None)
@@ -120,138 +120,138 @@ def greedy_seq_scheduler(self):
@UnfinishedFunc
def lastminute_for_seq(self, seq):
- raise NotImplementedError
+ raise NotImplementedError
@UnfinishedFunc
def lastminute_seq_scheduler(self):
- all_seqs = self.get_full_seq_parents()
+ all_seqs = self.get_full_seq_parents()
@UnfinishedFunc
def typestr(self):
- if isinstance(self, UnfinishedSelect):
- return "select"
- elif isinstance(self, UnfinishedSequence):
- return "seq"
- else:
- return "internal"
+ if isinstance(self, UnfinishedSelect):
+ return "select"
+ elif isinstance(self, UnfinishedSequence):
+ return "seq"
+ else:
+ return "internal"
@UnfinishedFunc
def width_and_depth(self, scheduler='greedy', loud=True, print_tree_too=False,
- remove_minors=False):
- seq_layers, layer_selects = self.schedule(
- scheduler=scheduler, remove_minors=remove_minors)
- widths = {i: len(layer_selects[i]) for i in layer_selects}
- n_layers = max(seq_layers.keys())
- max_width = max(widths[i] for i in widths)
- if loud:
- print("analysing unfinished", self.typestr()+":", self.name)
- print("using scheduler:", scheduler)
- print("num layers:", n_layers, "max width:", max_width)
- print("width per layer:")
- print("\n".join(str(i)+"\t: "+str(widths[i])
- for i in range(1, n_layers+1)))
- # start from 1 to skip layer 0, which has width 0
- # and is just the inputs (tokens and indices)
- if print_tree_too:
- def print_layer(i, d):
- print(i, "\t:", ", ".join(seq.name for seq in d[i]))
- print("==== seqs at each layer: ====")
- [print_layer(i, seq_layers) for i in range(1, n_layers+1)]
- print("==== selects at each layer: ====")
- [print_layer(i, layer_selects) for i in range(1, n_layers+1)]
- return n_layers, max_width, widths
+ remove_minors=False):
+ seq_layers, layer_selects = self.schedule(
+ scheduler=scheduler, remove_minors=remove_minors)
+ widths = {i: len(layer_selects[i]) for i in layer_selects}
+ n_layers = max(seq_layers.keys())
+ max_width = max(widths[i] for i in widths)
+ if loud:
+ print("analysing unfinished", self.typestr() + ":", self.name)
+ print("using scheduler:", scheduler)
+ print("num layers:", n_layers, "max width:", max_width)
+ print("width per layer:")
+ print("\n".join(str(i) + "\t: " + str(widths[i])
+ for i in range(1, n_layers + 1)))
+ # start from 1 to skip layer 0, which has width 0
+ # and is just the inputs (tokens and indices)
+ if print_tree_too:
+ def print_layer(i, d):
+ print(i, "\t:", ", ".join(seq.name for seq in d[i]))
+ print("==== seqs at each layer: ====")
+ [print_layer(i, seq_layers) for i in range(1, n_layers + 1)]
+ print("==== selects at each layer: ====")
+ [print_layer(i, layer_selects) for i in range(1, n_layers + 1)]
+ return n_layers, max_width, widths
@UnfinishedFunc
def schedule_comp_depth(self, d):
- self.scheduled_comp_depth = d
+ self.scheduled_comp_depth = d
@UnfinishedFunc
def get_all_ancestor_heads_and_ffs(self, remove_minors=False):
- class Head:
- def __init__(self, select, sequences, comp_depth):
- self.comp_depth = comp_depth
- self.name = str([m.name for m in sequences])
- self.sequences = sequences
- self.select = select
- seq_layers, layer_selects = self.schedule(
- 'best', remove_minors=remove_minors)
-
- all_ffs = self.get_full_seq_parents()
- if len(all_ffs) > 1:
- # filter out non-ffs in the non-trivial case
- all_ffs = [m for m in all_ffs if m.from_zipmap]
- if remove_minors:
- all_ffs = [ff for ff in all_ffs if not ff.is_minor]
-
- for i in seq_layers:
- for m in seq_layers[i]:
- if guarded_contains(all_ffs, m):
- # mark comp depths of the ffs... drawcompflow wants to know
- m.schedule_comp_depth(i)
-
- heads = []
- for i in layer_selects:
- for s in layer_selects[i]:
- seqs = [m for m in seq_layers[i] if m.analysis_parent_select == s]
- heads.append(Head(s, seqs, i))
-
- return heads, all_ffs
+ class Head:
+ def __init__(self, select, sequences, comp_depth):
+ self.comp_depth = comp_depth
+ self.name = str([m.name for m in sequences])
+ self.sequences = sequences
+ self.select = select
+ seq_layers, layer_selects = self.schedule(
+ 'best', remove_minors=remove_minors)
+
+ all_ffs = self.get_full_seq_parents()
+ if len(all_ffs) > 1:
+ # filter out non-ffs in the non-trivial case
+ all_ffs = [m for m in all_ffs if m.from_zipmap]
+ if remove_minors:
+ all_ffs = [ff for ff in all_ffs if not ff.is_minor]
+
+ for i in seq_layers:
+ for m in seq_layers[i]:
+ if guarded_contains(all_ffs, m):
+ # mark comp depths of the ffs... drawcompflow wants to know
+ m.schedule_comp_depth(i)
+
+ heads = []
+ for i in layer_selects:
+ for s in layer_selects[i]:
+ seqs = [m for m in seq_layers[i] if m.analysis_parent_select == s]
+ heads.append(Head(s, seqs, i))
+
+ return heads, all_ffs
@UnfinishedFunc
def set_display_name(self, display_name):
- self.display_name = display_name
- # again just making it more visible??? that there's an attribute being set
- # somewhere
+ self.display_name = display_name
+ # again, just making it more visible that there's an attribute being set
+ # somewhere
@UnfinishedFunc
def make_display_names_for_all_parents(self, skip_minors=False):
- all_unfs = self.get_full_parents()
- all_seqs = [u for u in set(all_unfs) if isinstance(u, UnfinishedSequence)]
- all_selects = [u for u in set(all_unfs) if isinstance(u, UnfinishedSelect)]
- if skip_minors:
- num_orig = len(all_seqs)
- all_seqs = [seq for seq in all_seqs if not seq.is_minor]
- name_counts = Counter([m.name for m in all_seqs])
- name_suff = Counter()
- for m in sorted(all_seqs+all_selects, key=lambda u: u.creation_order_id):
- # yes, even the non-seqs need display names, albeit for now only worry
- # about repeats in the seqs and sort by creation order to get name
- # suffixes with chronological (and so non-confusing) order
- if name_counts[m.name] > 1:
- m.set_display_name(m.name+"_"+str(name_suff[m.name]))
- name_suff[m.name] += 1
-
- else:
- m.set_display_name(m.name)
+ all_unfs = self.get_full_parents()
+ all_seqs = [u for u in set(all_unfs) if isinstance(u, UnfinishedSequence)]
+ all_selects = [u for u in set(all_unfs) if isinstance(u, UnfinishedSelect)]
+ if skip_minors:
+ num_orig = len(all_seqs)
+ all_seqs = [seq for seq in all_seqs if not seq.is_minor]
+ name_counts = Counter([m.name for m in all_seqs])
+ name_suff = Counter()
+ for m in sorted(all_seqs + all_selects, key=lambda u: u.creation_order_id):
+ # yes, even the non-seqs need display names, though for now we only worry
+ # about repeats in the seqs; sort by creation order so the name suffixes
+ # come out in chronological (and so non-confusing) order
+ if name_counts[m.name] > 1:
+ m.set_display_name(m.name + "_" + str(name_suff[m.name]))
+ name_suff[m.name] += 1
+
+ else:
+ m.set_display_name(m.name)
@UnfinishedFunc
def note_if_seeker(self):
- if not isinstance(self, UnfinishedSequence):
- return
+ if not isinstance(self, UnfinishedSequence):
+ return
- if (not self.get_parent_sequences()) \
- and (self.get_parent_select() is not None):
- # no parent sequences, but yes parent select: this value is a function
- # of only its parent select, i.e., a seeker (marks whether select found
- # something or not)
- self.is_seeker = True
- self.seeker_flag = self.elementwise_function()
- self.seeker_default = self._default
- else:
- self.is_seeker = False
+ if (not self.get_parent_sequences()) and \
+ (self.get_parent_select() is not None):
+ # no parent sequences, but a parent select: this value is a function
+ # of only its parent select, i.e., a seeker (it marks whether the select
+ # found something or not)
+ self.is_seeker = True
+ self.seeker_flag = self.elementwise_function()
+ self.seeker_default = self._default
+ else:
+ self.is_seeker = False
@UnfinishedFunc
def mark_all_ancestor_seekers(self):
- [u.note_if_seeker() for u in self.get_full_parents()]
+ [u.note_if_seeker() for u in self.get_full_parents()]
Unfinished._full_descendants_for_seq = (None, None)
@@ -259,47 +259,47 @@ def mark_all_ancestor_seekers(self):
@UnfinishedFunc
def descendants_towards_seq(self, seq):
- if not guarded_compare(self._full_descendants_for_seq[0], seq):
+ if not guarded_compare(self._full_descendants_for_seq[0], seq):
- relevant = seq.get_full_parents()
- res = [r for r in relevant if guarded_contains(r.get_parents(), self)]
+ relevant = seq.get_full_parents()
+ res = [r for r in relevant if guarded_contains(r.get_parents(), self)]
- self._full_descendants_for_seq = (seq, res)
- return self._full_descendants_for_seq[1]
+ self._full_descendants_for_seq = (seq, res)
+ return self._full_descendants_for_seq[1]
@UnfinishedFunc
def is_minor_comp_towards_seq(self, seq):
- if not isinstance(self, UnfinishedSequence):
- return False # selects are always important
- if self.never_display: # priority: never over always
- return True
- if self.always_display:
- if self.is_constant():
- print("displaying constant:", self.name)
- return False
- if self.is_constant(): # e.g. 1 or "a" etc, just stuff created around
- # constants by REPL behind the scenes
- return True
- children = self.descendants_towards_seq(seq)
- if len(children) > 1:
- return False # this sequence was used twice -> must have been actually
- # named as a real variable in the code (and not part of some bunch of
- # operators) -> make it visible in the comp flow too
- if len(children) == 0:
- # if it's the seq itself then clearly we're very interested in it. if
- # it has no children and isnt the seq then we're checking out a weird
- # dangly unused leaf, we shouldn't reach such a scenario through any of
- # functions we'll be using to call this one, but might as well make
- # this function complete just in case we forget
- return not guarded_compare(self, seq)
- child = children[0]
- if isinstance(child, UnfinishedSelect):
- return False # this thing feeds directly into a select, lets make it
- # visible
- # obtained through zipmap and feeds directly into another zipmap: minor
- # operation as part of something more complicated
- return (child.from_zipmap and self.from_zipmap)
+ if not isinstance(self, UnfinishedSequence):
+ return False # selects are always important
+ if self.never_display: # priority: never over always
+ return True
+ if self.always_display:
+ if self.is_constant():
+ print("displaying constant:", self.name)
+ return False
+ if self.is_constant(): # e.g. 1 or "a" etc, just stuff created around
+ # constants by REPL behind the scenes
+ return True
+ children = self.descendants_towards_seq(seq)
+ if len(children) > 1:
+ return False # this sequence was used twice -> must have been actually
+ # named as a real variable in the code (and not part of some bunch of
+ # operators) -> make it visible in the comp flow too
+ if len(children) == 0:
+ # if it's the seq itself then clearly we're very interested in it. if
+ # it has no children and isn't the seq then we're looking at a weird
+ # dangly unused leaf; we shouldn't reach such a scenario through any of
+ # the functions we'll be using to call this one, but might as well make
+ # this function complete just in case we forget
+ return not guarded_compare(self, seq)
+ child = children[0]
+ if isinstance(child, UnfinishedSelect):
+ return False # this thing feeds directly into a select, let's make it
+ # visible
+ # obtained through zipmap and feeds directly into another zipmap: minor
+ # operation as part of something more complicated
+ return (child.from_zipmap and self.from_zipmap)
Unfinished.is_minor = False
@@ -308,51 +308,51 @@ def is_minor_comp_towards_seq(self, seq):
@UnfinishedFunc
# another func just to be very explicit about an attribute that's getting set
def set_minor_for_seq(self, seq):
- self.is_minor = self.is_minor_comp_towards_seq(seq)
+ self.is_minor = self.is_minor_comp_towards_seq(seq)
@UnfinishedFunc
def mark_all_minor_ancestors(self):
- all_ancestors = self.get_full_parents()
- for a in all_ancestors:
- a.set_minor_for_seq(self)
+ all_ancestors = self.get_full_parents()
+ for a in all_ancestors:
+ a.set_minor_for_seq(self)
@UnfinishedFunc
def get_nonminor_parents(self): # assumes have already marked the minor
- # parents according to current interests.
- # otherwise, may remain marked according to a different seq, or possibly
- # all on default value (none are minor, all are important)
- potentials = self.get_parents()
- nonminors = []
- while potentials:
- p = potentials.pop()
- if not p.is_minor:
- nonminors.append(p)
- else:
- potentials.update(p.get_parents())
- return set(nonminors)
+ # parents according to the current interests.
+ # otherwise, they may remain marked according to a different seq, or
+ # possibly all be on the default value (none are minor, all are important)
+ potentials = self.get_parents()
+ nonminors = []
+ while potentials:
+ p = potentials.pop()
+ if not p.is_minor:
+ nonminors.append(p)
+ else:
+ potentials.update(p.get_parents())
+ return set(nonminors)
@UnfinishedFunc
def get_nonminor_parent_sequences(self):
- return [p for p in self.get_nonminor_parents()
- if isinstance(p, UnfinishedSequence)]
+ return [p for p in self.get_nonminor_parents()
+ if isinstance(p, UnfinishedSequence)]
@UnfinishedFunc
# gets both minor and nonminor sequences
def get_immediate_parent_sequences(self):
- return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)]
+ return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)]
@UnfinishedFunc
def pre_aggregate_comp(seq):
- vvars = seq.get_parent_sequences()
- vreal = zipmap(vvars, seq.elementwise_function)
- if isinstance(vreal, tuple): # equivalently, if seq.output_index >= 0:
- vreal = vreal[seq.output_index]
- return vreal
+ vvars = seq.get_parent_sequences()
+ vreal = zipmap(vvars, seq.elementwise_function)
+ if isinstance(vreal, tuple): # equivalently, if seq.output_index >= 0:
+ vreal = vreal[seq.output_index]
+ return vreal
dummyimport = None
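The greedy scheduler above buckets every ancestor sequence by its minimum possible depth (sorted by creation order so in-layer dependencies print sensibly), and the width analysis then counts distinct selects per bucket. The bucketing step in isolation, over made-up stand-in objects (a namedtuple instead of UnfinishedSequence):

from collections import defaultdict, namedtuple

# stand-in for UnfinishedSequence; only the two attributes the scheduler reads
Seq = namedtuple("Seq", ["name", "creation_order_id", "min_poss_depth"])

seqs = [Seq("c", 3, 2), Seq("a", 1, 1), Seq("b", 2, 1)]   # made-up data

levels = defaultdict(list)
for seq in sorted(seqs, key=lambda s: s.creation_order_id):
    levels[seq.min_poss_depth].append(seq)   # schedule each seq as early as possible

for depth in sorted(levels):
    print(depth, [s.name for s in levels[depth]])   # 1 ['a', 'b']  /  2 ['c']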
diff --git a/RASP_support/make_operators.py b/RASP_support/make_operators.py
index a015308..f517d3f 100644
--- a/RASP_support/make_operators.py
+++ b/RASP_support/make_operators.py
@@ -3,149 +3,151 @@
# make them fully named functions instead of lambdas, even though
# it's more lines, because the debug prints are so much clearer
# this way
+
+
def add_ops(Class, apply_unary_op, apply_binary_op):
- def addsetname(f, opname, rev):
- def f_with_setname(*a):
-
- assert len(a) in [1, 2]
- if len(a) == 2:
- a0, a1 = a if not rev else (a[1], a[0])
- name0 = a0.name if hasattr(a0, "name") else str(a0)
- name1 = a1.name if hasattr(a1, "name") else str(a1)
- # a0/a1 might not be a seq, just having an op on it with a seq.
- name = name0 + " " + opname + " " + name1
- else: # len(a)==1
- name = opname + " " + a[0].name
- # probably going to be composed with more ops, so...
- name = "( " + name + " )"
- return f(*a).setname(name).allow_suppressing_display()
- # seqs created as parts of long sequences of operators may be
- # suppressed in display, the final name of the whole composition
- # will be sufficiently informative. Have to set always_display to
- # false *after* the setname, because setname marks always_display
- # as True (under assumption it is normally being called by the
- # user, who must clearly be naming some variable they care about)
- return f_with_setname
-
- def listop(f, listing_name):
- setattr(Class, listing_name, f)
-
- def addop(opname, rev=False):
- return lambda f: listop(addsetname(f, opname, rev), f.__name__)
-
- @addop("==")
- def __eq__(self, other):
- return apply_binary_op(self, other, lambda a, b: a == b)
-
- @addop("!=")
- def __ne__(self, other):
- return apply_binary_op(self, other, lambda a, b: a != b)
-
- @addop("<")
- def __lt__(self, other):
- return apply_binary_op(self, other, lambda a, b: a < b)
-
- @addop(">")
- def __gt__(self, other):
- return apply_binary_op(self, other, lambda a, b: a > b)
-
- @addop("<=")
- def __le__(self, other):
- return apply_binary_op(self, other, lambda a, b: a <= b)
-
- @addop(">=")
- def __ge__(self, other):
- return apply_binary_op(self, other, lambda a, b: a >= b)
-
- @addop("+")
- def __add__(self, other):
- return apply_binary_op(self, other, lambda a, b: a+b)
-
- @addop("+", True)
- def __radd__(self, other):
- return apply_binary_op(self, other, lambda a, b: b+a)
-
- @addop("-")
- def __sub__(self, other):
- return apply_binary_op(self, other, lambda a, b: a-b)
-
- @addop("-", True)
- def __rsub__(self, other):
- return apply_binary_op(self, other, lambda a, b: b-a)
-
- @addop("*")
- def __mul__(self, other):
- return apply_binary_op(self, other, lambda a, b: a*b)
-
- @addop("*", True)
- def __rmul__(self, other):
- return apply_binary_op(self, other, lambda a, b: b*a)
-
- @addop("//")
- def __floordiv__(self, other):
- return apply_binary_op(self, other, lambda a, b: a//b)
-
- @addop("//", True)
- def __rfloordiv__(self, other):
- return apply_binary_op(self, other, lambda a, b: b//a)
-
- @addop("/")
- def __truediv__(self, other):
- return apply_binary_op(self, other, lambda a, b: a/b)
-
- @addop("/", True)
- def __rtruediv__(self, other):
- return apply_binary_op(self, other, lambda a, b: b/a)
-
- @addop("%")
- def __mod__(self, other):
- return apply_binary_op(self, other, lambda a, b: a % b)
-
- @addop("%", True)
- def __rmod__(self, other):
- return apply_binary_op(self, other, lambda a, b: b % a)
-
- @addop("divmod")
- def __divmod__(self, other):
- return apply_binary_op(self, other, lambda a, b: divmod(a, b))
-
- @addop("divmod", True)
- def __rdivmod__(self, other):
- return apply_binary_op(self, other, lambda a, b: divmod(b, a))
-
- @addop("pow")
- def __pow__(self, other):
- return apply_binary_op(self, other, lambda a, b: pow(a, b))
-
- @addop("pow", True)
- def __rpow__(self, other):
- return apply_binary_op(self, other, lambda a, b: pow(b, a))
-
- # skipping and, or, xor, which are bitwise and dont implement 'and' and
- # 'or' but rather & and |.
- # similarly skipping lshift, rshift cause who wants them.
- # wish i had not, and, or primitives, but can accept that dont.
- # if people really want to do 'not' they can do '==False' instead, can do a
- # little macro for it in the other sugar file or whatever
-
- @addop("+")
- def __pos__(self):
- return apply_unary_op(self, lambda a: +a)
-
- @addop("-")
- def __neg__(self):
- return apply_unary_op(self, lambda a: -a)
-
- @addop("abs")
- def __abs__(self):
- return apply_unary_op(self, abs)
-
- @addop("round")
- # not sure if python will get upset if round doesnt return an actual int
- # tbh... will have to check.
- def __round__(self):
- return apply_unary_op(self, round)
-
- # defining floor, ceil, trunc showed up funny (green instead of blue),
- # gonna go ahead and avoid
+ def addsetname(f, opname, rev):
+ def f_with_setname(*a):
+
+ assert len(a) in [1, 2]
+ if len(a) == 2:
+ a0, a1 = a if not rev else (a[1], a[0])
+ name0 = a0.name if hasattr(a0, "name") else str(a0)
+ name1 = a1.name if hasattr(a1, "name") else str(a1)
+ # a0/a1 might not be a seq; it might just be a plain value in an op with a seq.
+ name = name0 + " " + opname + " " + name1
+ else: # len(a)==1
+ name = opname + " " + a[0].name
+ # probably going to be composed with more ops, so...
+ name = "( " + name + " )"
+ return f(*a).setname(name).allow_suppressing_display()
+ # seqs created as parts of long chains of operators may be
+ # suppressed in display; the final name of the whole composition
+ # will be sufficiently informative. Have to set always_display to
+ # False *after* the setname, because setname marks always_display
+ # as True (under the assumption it is normally being called by the
+ # user, who must clearly be naming some variable they care about)
+ return f_with_setname
+
+ def listop(f, listing_name):
+ setattr(Class, listing_name, f)
+
+ def addop(opname, rev=False):
+ return lambda f: listop(addsetname(f, opname, rev), f.__name__)
+
+ @addop("==")
+ def __eq__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a == b)
+
+ @addop("!=")
+ def __ne__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a != b)
+
+ @addop("<")
+ def __lt__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a < b)
+
+ @addop(">")
+ def __gt__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a > b)
+
+ @addop("<=")
+ def __le__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a <= b)
+
+ @addop(">=")
+ def __ge__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a >= b)
+
+ @addop("+")
+ def __add__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a + b)
+
+ @addop("+", True)
+ def __radd__(self, other):
+ return apply_binary_op(self, other, lambda a, b: b + a)
+
+ @addop("-")
+ def __sub__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a - b)
+
+ @addop("-", True)
+ def __rsub__(self, other):
+ return apply_binary_op(self, other, lambda a, b: b - a)
+
+ @addop("*")
+ def __mul__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a * b)
+
+ @addop("*", True)
+ def __rmul__(self, other):
+ return apply_binary_op(self, other, lambda a, b: b * a)
+
+ @addop("//")
+ def __floordiv__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a // b)
+
+ @addop("//", True)
+ def __rfloordiv__(self, other):
+ return apply_binary_op(self, other, lambda a, b: b // a)
+
+ @addop("/")
+ def __truediv__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a / b)
+
+ @addop("/", True)
+ def __rtruediv__(self, other):
+ return apply_binary_op(self, other, lambda a, b: b / a)
+
+ @addop("%")
+ def __mod__(self, other):
+ return apply_binary_op(self, other, lambda a, b: a % b)
+
+ @addop("%", True)
+ def __rmod__(self, other):
+ return apply_binary_op(self, other, lambda a, b: b % a)
+
+ @addop("divmod")
+ def __divmod__(self, other):
+ return apply_binary_op(self, other, lambda a, b: divmod(a, b))
+
+ @addop("divmod", True)
+ def __rdivmod__(self, other):
+ return apply_binary_op(self, other, lambda a, b: divmod(b, a))
+
+ @addop("pow")
+ def __pow__(self, other):
+ return apply_binary_op(self, other, lambda a, b: pow(a, b))
+
+ @addop("pow", True)
+ def __rpow__(self, other):
+ return apply_binary_op(self, other, lambda a, b: pow(b, a))
+
+ # skipping and, or, xor, which are bitwise and don't implement 'and' and
+ # 'or' but rather & and |.
+ # similarly skipping lshift and rshift, because who wants them.
+ # wish I had not/and/or primitives, but I can accept that I don't.
+ # if people really want 'not' they can do '== False' instead, or add a
+ # little macro for it in the other sugar file or whatever
+
+ @addop("+")
+ def __pos__(self):
+ return apply_unary_op(self, lambda a: +a)
+
+ @addop("-")
+ def __neg__(self):
+ return apply_unary_op(self, lambda a: -a)
+
+ @addop("abs")
+ def __abs__(self):
+ return apply_unary_op(self, abs)
+
+ @addop("round")
+ # not sure if python will get upset if round doesn't return an actual int
+ # tbh... will have to check.
+ def __round__(self):
+ return apply_unary_op(self, round)
+
+ # defining floor, ceil, trunc showed up funny (green instead of blue),
+ # so going to go ahead and avoid them
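make_operators.py attaches every comparison and arithmetic dunder to the sequence class with setattr, wrapping each one so the result gets an auto-generated, human-readable name. The attach-by-setattr pattern itself, sketched on a toy class (not the real add_ops; names here are illustrative):

import operator

def add_ops(cls):
    # attach a named dunder method for each (dunder, symbol, function) triple
    for dunder, symbol, fn in [("__add__", "+", operator.add),
                               ("__mul__", "*", operator.mul)]:
        def make(symbol, fn):
            def op(self, other):
                res = cls(fn(self.val, other))
                res.name = "( " + self.name + " " + symbol + " " + str(other) + " )"
                return res
            return op
        setattr(cls, dunder, make(symbol, fn))
    return cls

@add_ops
class Named:
    def __init__(self, val, name=None):
        self.val, self.name = val, name or str(val)

x = Named(2, "x")
print((x + 3).name, (x * 4).val)   # ( x + 3 ) 8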
diff --git a/tests/make_tgts.py b/tests/make_tgts.py
index c460797..519b9cd 100644
--- a/tests/make_tgts.py
+++ b/tests/make_tgts.py
@@ -4,13 +4,13 @@
testpath = "tests"
-inpath = testpath+"/in"
-outpath = testpath+"/out"
-tgtpath = testpath+"/tgt"
-libtestspath = testpath+"/broken_libs"
-libspath = libtestspath+"/lib"
-libtgtspath = libtestspath+"/tgt"
-liboutspath = libtestspath+"/out"
+inpath = testpath + "/in"
+outpath = testpath + "/out"
+tgtpath = testpath + "/tgt"
+libtestspath = testpath + "/broken_libs"
+libspath = libtestspath + "/lib"
+libtgtspath = libtestspath + "/tgt"
+liboutspath = libtestspath + "/out"
curr_path_marker = "[current]"
@@ -20,76 +20,76 @@
def things_in_path(path):
- if not os.path.exists(path):
- return []
- return [p for p in os.listdir(path) if not p == ".DS_Store"]
+ if not os.path.exists(path):
+ return []
+ return [p for p in os.listdir(path) if not p == ".DS_Store"]
def joinpath(*a):
- return "/".join(a)
+ return "/".join(a)
for p in [tgtpath, libtgtspath]:
- if not os.path.exists(p):
- os.makedirs(p)
+ if not os.path.exists(p):
+ os.makedirs(p)
all_names = things_in_path(inpath)
def fix_file_paths(filename, curr_path_marker):
- mypath = os.path.abspath(".")
+ mypath = os.path.abspath(".")
- with open(filename, "r") as f:
- filecontents = "".join(f)
+ with open(filename, "r") as f:
+ filecontents = "".join(f)
- filecontents = filecontents.replace(mypath, curr_path_marker)
+ filecontents = filecontents.replace(mypath, curr_path_marker)
- with open(filename, "w") as f:
- print(filecontents, file=f)
+ with open(filename, "w") as f:
+ print(filecontents, file=f)
def run_input(name):
- os.system("python3 "+REPL_PATH+" <"+inpath+"/"+name+" >"+tgtpath+"/"+name)
- fix_file_paths(tgtpath+"/"+name, curr_path_marker)
+ os.system(f"python3 {REPL_PATH} <{inpath}/{name} >{tgtpath}/{name}")
+ fix_file_paths(tgtpath + "/" + name, curr_path_marker)
def run_inputs():
- print("making the target outputs!")
- for n in all_names:
- run_input(n)
+ print("making the target outputs!")
+ for n in all_names:
+ run_input(n)
def run_broken_lib(lib):
- os.system("cp "+joinpath(libspath, lib)+" "+RASPLIB_PATH)
- os.system("python3 "+REPL_PATH+" <"+joinpath(libtestspath,
- "empty.txt") + " >"+joinpath(libtgtspath, lib))
+ os.system("cp " + joinpath(libspath, lib) + " " + RASPLIB_PATH)
+ readpath = joinpath(libtestspath, "empty.txt")
+ writepath = joinpath(libtgtspath, lib)
+ os.system("python3 " + REPL_PATH + " <" + readpath + " >" + writepath)
real_rasplib_safe_place = "make_tgts_helper/temp"
safe_rasplib_name = "safe_rasplib.rasp"
+rasplib_save_loc = joinpath(real_rasplib_safe_place, safe_rasplib_name)
def save_rasplib():
- if not os.path.exists(real_rasplib_safe_place):
- os.makedirs(real_rasplib_safe_place)
- os.system("mv "+RASPLIB_PATH+" " +
- joinpath(real_rasplib_safe_place, safe_rasplib_name))
+ if not os.path.exists(real_rasplib_safe_place):
+ os.makedirs(real_rasplib_safe_place)
+ os.system("mv " + RASPLIB_PATH + " " + rasplib_save_loc)
def restore_rasplib():
- os.system("mv "+joinpath(real_rasplib_safe_place,
- safe_rasplib_name)+" "+RASPLIB_PATH)
+ os.system("mv " + rasplib_save_loc + " " + RASPLIB_PATH)
def run_broken_libs():
- print("making the broken lib targets!")
- save_rasplib()
- all_libs = things_in_path(libspath)
- for lib in all_libs:
- run_broken_lib(lib)
- restore_rasplib()
+ print("making the broken lib targets!")
+ save_rasplib()
+ all_libs = things_in_path(libspath)
+ for lib in all_libs:
+ run_broken_lib(lib)
+ restore_rasplib()
if __name__ == "__main__":
- run_inputs()
- run_broken_libs()
+ run_inputs()
+ run_broken_libs()
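make_tgts.py shells out with os.system and builds its command lines by string concatenation; the diff above only tidies that concatenation into f-strings and named path variables. If the paths ever contained spaces this would break, so a subprocess-based variant with explicit redirection is a sturdier alternative. A hedged sketch of the same "feed an input file to the REPL, capture stdout into a target file" step (an alternative, not what the repo does):

import subprocess

def run_input(repl_path, in_file, tgt_file):
    # equivalent of: python3 <repl_path> < in_file > tgt_file, without a shell
    with open(in_file, "r") as fin, open(tgt_file, "w") as fout:
        subprocess.run(["python3", repl_path], stdin=fin, stdout=fout, check=False)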
diff --git a/tests/test_all.py b/tests/test_all.py
index a9d9f0d..aad5649 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -1,70 +1,72 @@
import os
from make_tgts import fix_file_paths, curr_path_marker, joinpath, \
- things_in_path, inpath, outpath, tgtpath, libtestspath, libspath, \
- libtgtspath, liboutspath, save_rasplib, restore_rasplib
+ things_in_path, inpath, outpath, tgtpath, libtestspath, libspath, \
+ libtgtspath, liboutspath, save_rasplib, restore_rasplib
def check_equal(f1, f2):
- res = os.system("diff "+f1+" "+f2)
- return res == 0 # 0 = diff found no differences
+ res = os.system("diff " + f1 + " " + f2)
+ return res == 0 # 0 = diff found no differences
for p in [outpath, liboutspath]:
- if not os.path.exists(p):
- os.makedirs(p)
+ if not os.path.exists(p):
+ os.makedirs(p)
def run_input(name):
- os.system("python3 -m RASP_support <" +
- joinpath(inpath, name)+" >"+joinpath(outpath, name))
- fix_file_paths(joinpath(outpath, name), curr_path_marker)
- return check_equal(joinpath(outpath, name), joinpath(tgtpath, name))
+ readpath = joinpath(inpath, name)
+ writepath = joinpath(outpath, name)
+ os.system("python3 -m RASP_support <" + readpath + " >" + writepath)
+ fix_file_paths(writepath, curr_path_marker)
+ return check_equal(writepath, joinpath(tgtpath, name))
def run_inputs():
- all_names = things_in_path(inpath)
- passed = True
- for n in all_names:
- success = run_input(n)
- print("input", n, "passed:", success)
- if not success:
- passed = False
- return passed
+ all_names = things_in_path(inpath)
+ passed = True
+ for n in all_names:
+ success = run_input(n)
+ print("input", n, "passed:", success)
+ if not success:
+ passed = False
+ return passed
def test_broken_lib(lib):
- inlib, outlib = lib, lib.replace(".rasp", ".txt")
- os.system("cp "+joinpath(libspath, inlib)+" RASP_support/rasplib.rasp")
- os.system("python3 -m RASP_support <" + joinpath(libtestspath,
- "empty.txt") + " >" + joinpath(liboutspath, outlib))
- return check_equal(joinpath(liboutspath, outlib),
- joinpath(libtgtspath, outlib))
+ inlib, outlib = lib, lib.replace(".rasp", ".txt")
+ os.system("cp " + joinpath(libspath, inlib) + " RASP_support/rasplib.rasp")
+ readpath = joinpath(libtestspath, "empty.txt")
+ writepath = joinpath(liboutspath, outlib)
+ os.system("python3 -m RASP_support <" + readpath + " >" + writepath)
+ return check_equal(
+ joinpath(liboutspath, outlib), joinpath(libtgtspath, outlib))
def run_broken_libs():
- save_rasplib()
- all_libs = things_in_path(libspath)
- passed = True
- for lib in all_libs:
- success = test_broken_lib(lib)
- print("lib", lib, "passed (i.e., properly errored):", success)
- if not success:
- passed = False
- restore_rasplib()
- return passed
+ save_rasplib()
+ all_libs = things_in_path(libspath)
+ passed = True
+ for lib in all_libs:
+ success = test_broken_lib(lib)
+ print("lib", lib, "passed (i.e., properly errored):", success)
+ if not success:
+ passed = False
+ restore_rasplib()
+ return passed
if __name__ == "__main__":
- passed_inputs = run_inputs()
- print("passed all inputs:", passed_inputs)
- print("=====\n\n=====")
- passed_broken_libs = run_broken_libs()
- print("properly reports broken libs:", passed_broken_libs)
- print("=====\n\n=====")
+ passed_inputs = run_inputs()
+ print("passed all inputs:", passed_inputs)
+ print("=====\n\n=====")
+ passed_broken_libs = run_broken_libs()
+ print("properly reports broken libs:", passed_broken_libs)
+ print("=====\n\n=====")
- passed_everything = False not in [passed_inputs, passed_broken_libs]
- print("=====\npassed everything:", passed_everything)
- if passed_everything:
- exit(0)
- else:
- exit(1)
+ passed_everything = False not in [passed_inputs, passed_broken_libs]
+ print("=====\npassed everything:", passed_everything)
+ if passed_everything:
+ exit(0)
+ else:
+ exit(1)
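test_all.py is a golden-file suite: run each input through the interpreter, normalise absolute paths, diff the output against the stored target, and exit non-zero if anything differs. The comparison can also be done without shelling out to diff; a minimal pure-Python sketch of check_equal using difflib (an alternative, not what the repo itself does):

import difflib

def check_equal(out_file, tgt_file):
    # True when the generated output matches the stored target exactly
    with open(out_file) as f1, open(tgt_file) as f2:
        diff = list(difflib.unified_diff(f1.readlines(), f2.readlines(),
                                         fromfile=out_file, tofile=tgt_file))
    for line in diff:
        print(line, end="")   # show mismatches, much like `diff` would
    return not diff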