From a8c73fca0c069ff8c4f4c0cf7e118d95aa6259fe Mon Sep 17 00:00:00 2001 From: Emile Ferreira <32413750+emileferreira@users.noreply.github.com> Date: Thu, 18 Jan 2024 15:15:58 +0200 Subject: [PATCH] Format Python according to PEP 8 (#10) * Format Python according to PEP 8 * Convert spaces to tabs * Ignore formatting commits * no changes to zzantlr files - these were auto generated and dont want to touch them * Disambiguate some variable names --------- Co-authored-by: Gail Weiss --- .gitignore-revs | 17 + RASP_support/DrawCompFlow.py | 524 +++++++++++++--------- RASP_support/Environment.py | 95 ++-- RASP_support/Evaluator.py | 708 +++++++++++++++++------------- RASP_support/FunctionalSupport.py | 615 ++++++++++++++++---------- RASP_support/REPL.py | 482 +++++++++++--------- RASP_support/Sugar.py | 122 ++--- RASP_support/Support.py | 239 +++++----- RASP_support/analyse.py | 290 +++++++----- RASP_support/make_operators.py | 219 ++++----- tests/make_tgts.py | 52 ++- tests/test_all.py | 54 +-- 12 files changed, 1974 insertions(+), 1443 deletions(-) create mode 100644 .gitignore-revs diff --git a/.gitignore-revs b/.gitignore-revs new file mode 100644 index 0000000..6e29267 --- /dev/null +++ b/.gitignore-revs @@ -0,0 +1,17 @@ +# Format Python according to PEP 8 +b0d40ad2fe368b2c2213c091854e2c7044fd861d +d8da41ac4a0afa9179a1971ed51a2c7fea92d155 +2c7d11e71db8eea1457755db04ddd9fe8c50e77d +16ae502f34eb4dd1d5553750c20130ee042e7475 +cff338eb239ef7dc0fe8a385e7b5f21d17068159 +ccb31632bb297572b8591b2714a5b2af460cc51f +bd3a448c727fb7fdd3ac0c9e30f3aaa5a55e539a +ad18b2ae0bcfd02bf3d4103cef478ce38e4ccd09 +d08cedcd8fad28b18b3bc199742a833eaf228bd7 +8701e153d1e6298cf8fe6d29252054caa670a4fb +2a3f43e96ccd8ce7eca7629504b637131974e3cb +4e2383ced5f0a6ab6f7b99110ac7673e9ddd2ef0 +316872e18e6101e42bf16d1134171e5fef4b7749 +31ca5b2ba5e3cc3212d7be9f870d726e972165b1 +cb664d995b0b8ff15037f4e8114313635c67d184 +309fb81e873c4de09cae62f3f611c4bec1bc6dfe diff --git a/RASP_support/DrawCompFlow.py b/RASP_support/DrawCompFlow.py index bc370ea..ff2431b 100644 --- a/RASP_support/DrawCompFlow.py +++ b/RASP_support/DrawCompFlow.py @@ -1,59 +1,68 @@ -from FunctionalSupport import UnfinishedSelect, Unfinished, UnfinishedSequence, \ - guarded_contains, guarded_compare, indices, base_tokens, tokens_asis +from FunctionalSupport import Unfinished, guarded_contains, base_tokens, \ + tokens_asis from Support import clean_val -import analyse # adds useful functions to all the unfinisheds -from analyse import UnfinishedFunc import os -from copy import copy import string +import analyse # adds useful functions to all the Unfinisheds -# fix: in ordering, we always connect bottom FF to top select. but sometimes, there is no FF (if go straight into next select), or there is no rendered select (in special case of full-select) +# fix: in ordering, we always connect bottom FF to top select. but sometimes, +# there is no FF (if go straight into next select), or there is no rendered +# select (in special case of full-select) layer_color = 'lemonchiffon' -head_color = 'bisque' #'yellow' +head_color = 'bisque' # 'yellow' indices_colour = 'bisque3' comment_colour = 'cornsilk' select_on_colour = 'plum' select_off_colour = head_color + def windows_path_cleaner(s): - if os.name == "nt": # is windows - validchars = "-_.() "+string.ascii_letters+string.digits - def fix(c): - return c if c in validchars else "." 
- return "".join([fix(c) for c in s]) + if os.name == "nt": # is windows + validchars = "-_.() "+string.ascii_letters+string.digits + + def fix(c): + return c if c in validchars else "." + return "".join([fix(c) for c in s]) else: return s + def colour_scheme(row_type): if row_type == INPUT: return 'gray', 'gray', 'gray' if row_type == QVAR: - return 'palegreen4','mediumseagreen', 'palegreen1' + return 'palegreen4', 'mediumseagreen', 'palegreen1' elif row_type == KVAR: - return 'deepskyblue3','darkturquoise','darkslategray1' + return 'deepskyblue3', 'darkturquoise', 'darkslategray1' elif row_type == VVAR: - return 'palevioletred3','palevioletred2','lightpink' + return 'palevioletred3', 'palevioletred2', 'lightpink' elif row_type == VREAL: - return 'plum4','plum3','thistle2' + return 'plum4', 'plum3', 'thistle2' elif row_type == RES: - return 'lightsalmon3','burlywood','burlywood1' + return 'lightsalmon3', 'burlywood', 'burlywood1' else: raise Exception("unknown row type: "+str(row_type)) -QVAR, KVAR, VVAR, VREAL, RES, INPUT = ["QVAR","KVAR","VVAR","VREAL","RES","INPUT"] -POSS_ROWS = [QVAR,KVAR,VVAR,VREAL,RES,INPUT] -ROW_NAMES = {QVAR:"Me",KVAR:"Other",VVAR:"X",VREAL:"f(X)",RES:"FF",INPUT:""} + +QVAR, KVAR, VVAR, VREAL, RES, INPUT = [ + "QVAR", "KVAR", "VVAR", "VREAL", "RES", "INPUT"] +POSS_ROWS = [QVAR, KVAR, VVAR, VREAL, RES, INPUT] +ROW_NAMES = {QVAR: "Me", KVAR: "Other", VVAR: "X", + VREAL: "f(X)", RES: "FF", INPUT: ""} + def UnfinishedFunc(f): - setattr(Unfinished,f.__name__,f) + setattr(Unfinished, f.__name__, f) + @UnfinishedFunc -def last_val(self): +def last_val(self): return self.last_res.get_vals() -def makeQKStable(qvars,kvars,select,ref_in_g): + +def makeQKStable(qvars, kvars, select, ref_in_g): qvars = [q.last_val() for q in qvars] kvars = [k.last_val() for k in kvars] select = select.last_val() @@ -63,9 +72,11 @@ def makeQKStable(qvars,kvars,select,ref_in_g): kvars_skip = len(qvars) _, _, qvars_colour = colour_scheme(QVAR) _, _, kvars_colour = colour_scheme(KVAR) - # select has qvars along the rows and kvars along the columns, so we'll do the same. - # i.e. top rows will just be the kvars and first columns will just be the qvars - # if (not qvars) and (not kvars): # no qvars or kvars -> full select -> dont waste space drawing + # select has qvars along the rows and kvars along the columns, so we'll do + # the same. i.e. top rows will just be the kvars and first columns will + # just be the qvars. + # if (not qvars) and (not kvars): + # # no qvars or kvars -> full select -> dont waste space drawing. 
# num_rows, num_columns = 0, 0 # pass # else: @@ -74,403 +85,478 @@ def makeQKStable(qvars,kvars,select,ref_in_g): num_rows = qvars_skip+q_val_len num_columns = kvars_skip+k_val_len - select_cells = {i:[CellVals('',head_color,j,i) for j in range(num_columns)] \ - for i in range(num_rows)} - - - for i,seq in enumerate(kvars): - for j,v in enumerate(seq): - select_cells[i][j+kvars_skip] = CellVals(v,kvars_colour,i,j+kvars_skip) - for j,seq in enumerate(qvars): - for i,v in enumerate(seq): - select_cells[i+qvars_skip][j] = CellVals(v,qvars_colour,i+qvars_skip,j) - - for i in range(num_rows-qvars_skip): # i goes over the q_var values - for j in range(num_columns-kvars_skip): # j goes over the k_var values + select_cells = {i: [CellVals('', head_color, j, i) + for j in range(num_columns)] + for i in range(num_rows)} + + for i, seq in enumerate(kvars): + for j, v in enumerate(seq): + vals = CellVals(v, kvars_colour, i, j+kvars_skip) + select_cells[i][j + kvars_skip] = vals + for j, seq in enumerate(qvars): + for i, v in enumerate(seq): + vals = CellVals(v, qvars_colour, i+qvars_skip, j) + select_cells[i + qvars_skip][j] = vals + + for i in range(num_rows-qvars_skip): # i goes over the q_var values + for j in range(num_columns-kvars_skip): # j goes over the k_var values v = select[i][j] colour = select_on_colour if v else select_off_colour - select_cells[i+qvars_skip][j+kvars_skip] = CellVals(v,colour,i+qvars_skip,j+kvars_skip,select_internal=True) + select_cells[i+qvars_skip][j+kvars_skip] = CellVals( + v, colour, i+qvars_skip, j+kvars_skip, select_internal=True) # TODO: make an ugly little q\k triangle thingy in the top corner - return GridTable(select_cells,ref_in_g) + return GridTable(select_cells, ref_in_g) + class CellVals: - def __init__(self,val,colour,i_row,i_col,select_internal=False,known_portstr=None): + def __init__(self, val, colour, i_row, i_col, select_internal=False, + known_portstr=None): def mystr(v): - if isinstance(v,bool): + if isinstance(v, bool): if select_internal: - return ' ' if v else ' ' # color gives it all! + return ' ' if v else ' ' # color gives it all! else: return 'T' if v else 'F' - if isinstance(v,float): - v = clean_val(v,3) - if isinstance(v,int) and len(str(v))==1: - v = " "+str(v) # for pretty square selectors - return str(v).replace("<","<").replace(">",">") + if isinstance(v, float): + v = clean_val(v, 3) + if isinstance(v, int) and len(str(v)) == 1: + v = " "+str(v) # for pretty square selectors + return str(v).replace("<", "<").replace(">", ">") self.val = mystr(val) self.colour = colour if None is known_portstr: self.portstr = "_col"+str(i_col)+"_row"+str(i_row) else: self.portstr = known_portstr + def __str__(self): - return ''+self.val+'' + return '' + self.val+'' class GridTable: - def __init__(self,cellvals,ref_in_g): + def __init__(self, cellvals, ref_in_g): self.ref_in_g = ref_in_g self.cellvals = cellvals - self.numcols = len(cellvals.get(0,[])) + self.numcols = len(cellvals.get(0, [])) self.numrows = len(cellvals) - self.empty = 0 in [self.numcols,self.numrows] - def to_str(self,transposed=False): + self.empty = 0 in [self.numcols, self.numrows] + + def to_str(self, transposed=False): ii = sorted(list(self.cellvals.keys())) rows = [self.cellvals[i] for i in ii] + def cells2row(cells): - return ''+''.join(map(str,cells))+'' - return '<'+''.join(map(cells2row,rows))+'
>' + return ''+''.join(map(str, cells))+'' + return '<' + ''.join(map(cells2row, rows)) \ + + '
>' + def bottom_left_portstr(self): - return self.access_portstr(0,-1) + return self.access_portstr(0, -1) + def bottom_right_portstr(self): - return self.access_portstr(-1,-1) + return self.access_portstr(-1, -1) + def top_left_portstr(self): - return self.access_portstr(0,0) + return self.access_portstr(0, 0) + def top_right_portstr(self): - return self.access_portstr(-1,0) - def top_access_portstr(self,i_col): - return self.access_portstr(i_col,0) - def bottom_access_portstr(self,i_col): - return self.access_portstr(i_col,-1) - def access_portstr(self,i_col,i_row): - return self.ref_in_g + ":" + self.internal_portstr(i_col,i_row) - def internal_portstr(self,i_col,i_row): + return self.access_portstr(-1, 0) + + def top_access_portstr(self, i_col): + return self.access_portstr(i_col, 0) + + def bottom_access_portstr(self, i_col): + return self.access_portstr(i_col, -1) + + def access_portstr(self, i_col, i_row): + return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row) + + def internal_portstr(self, i_col, i_row): if i_col < 0: i_col = self.numcols + i_col if i_row < 0: - i_row = self.numrows + i_row + i_row = self.numrows + i_row return "_col"+str(i_col)+"_row"+str(i_row) - def add_to_graph(self,g): + + def add_to_graph(self, g): if self.empty: pass else: - g.node(name=self.ref_in_g,shape='none',margin='0',label=self.to_str()) + g.node(name=self.ref_in_g, shape='none', + margin='0', label=self.to_str()) + class Table: - def __init__(self,seqs_by_rowtype,ref_in_g,rowtype_order=[]): + def __init__(self, seqs_by_rowtype, ref_in_g, rowtype_order=[]): self.ref_in_g = ref_in_g - # consistent presentation, and v useful for feedforward clarity + # consistent presentation, and v useful for feedforward clarity self.rows = [] self.seq_index = {} - if len(rowtype_order)>1: + if len(rowtype_order) > 1: self.add_rowtype_cell = True else: - assert len(seqs_by_rowtype.keys()) == 1, "table got multiple row types but no order for them" + errnote = "table got multiple row types but no order for them" + assert len(seqs_by_rowtype.keys()) == 1, errnote rowtype_order = list(seqs_by_rowtype.keys()) self.add_rowtype_cell = not (rowtype_order[0] == RES) - self.note_res_dependencies = len(seqs_by_rowtype.get(RES,[]))>1 + self.note_res_dependencies = len(seqs_by_rowtype.get(RES, [])) > 1 self.leading_metadata_offset = 1 + self.add_rowtype_cell for rt in rowtype_order: - seqs = sorted(seqs_by_rowtype[rt],key=lambda seq:seq.creation_order_id) - for i,seq in enumerate(seqs): - self.n = self.add_row(seq,rt) # each one appends to self.rows. - # self.n stores length of a single row, they will all be the same, - # just easiest to get like this - # add_row has to happen one at a time b/c they care about length of - # self.rows at time of addition (to get ports right) - self.empty = len(self.rows)==0 + seqs = sorted(seqs_by_rowtype[rt], + key=lambda seq: seq.creation_order_id) + for i, seq in enumerate(seqs): + # each one appends to self.rows. 
+ self.n = self.add_row(seq, rt) + # self.n stores length of a single row, they will all be the + # same, just easiest to get like this + # add_row has to happen one at a time b/c they care about + # length of self.rows at time of addition (to get ports right) + self.empty = len(self.rows) == 0 if self.empty: self.n = 0 - self.transpose = False # (len(rowtype_order)==1 and rowtype_order[0]==QVAR) + # (len(rowtype_order)==1 and rowtype_order[0]==QVAR) + self.transpose = False # no need to twist Q, just making the table under anyway # transpose affects the port accesses, but think about that later + def to_str(self): rows = self.rows if not self.transpose else list(zip(*self.rows)) + def cells2row(cells): return ''+''.join(cells)+'' - return '<'+''.join(map(cells2row,rows))+'
>' + return '<' + ''.join(map(cells2row, rows)) \ + + '
>' + def bottom_left_portstr(self): - return self.access_portstr(0,-1) + return self.access_portstr(0, -1) + def bottom_right_portstr(self): - return self.access_portstr(-1,-1) + return self.access_portstr(-1, -1) + def top_left_portstr(self): - return self.access_portstr(0,0) + return self.access_portstr(0, 0) + def top_right_portstr(self): - return self.access_portstr(-1,0) - def top_access_portstr(self,i_col,skip_meta=False): - return self.access_portstr(i_col,0,skip_meta=skip_meta) - def bottom_access_portstr(self,i_col,skip_meta=False): - return self.access_portstr(i_col,-1,skip_meta=skip_meta) - def access_portstr(self,i_col,i_row,skip_meta=False): - return self.ref_in_g + ":" + self.internal_portstr(i_col,i_row,skip_meta=skip_meta) - def internal_portstr(self,i_col,i_row,skip_meta=False): - if skip_meta and (i_col >= 0): # before flip things for reverse column access + return self.access_portstr(-1, 0) + + def top_access_portstr(self, i_col, skip_meta=False): + return self.access_portstr(i_col, 0, skip_meta=skip_meta) + + def bottom_access_portstr(self, i_col, skip_meta=False): + return self.access_portstr(i_col, -1, skip_meta=skip_meta) + + def access_portstr(self, i_col, i_row, skip_meta=False): + return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row, + skip_meta=skip_meta) + + def internal_portstr(self, i_col, i_row, skip_meta=False): + if skip_meta and (i_col >= 0): # before flip things for reverse column + # access i_col += self.leading_metadata_offset if i_col < 0: i_col = (self.n) + i_col if i_row < 0: - i_row = len(self.rows) + i_row + i_row = len(self.rows) + i_row return "_col"+str(i_col)+"_row"+str(i_row) - def add_row(self,seq,row_type): - def add_cell(val,colour): - res = CellVals(val,colour,-1,-1, - known_portstr=self.internal_portstr(len(cells),len(self.rows))) + + def add_row(self, seq, row_type): + def add_cell(val, colour): + res = CellVals(val, colour, -1, -1, + known_portstr=self.internal_portstr(len(cells), + len(self.rows))) cells.append(str(res)) def add_strong_line(): # after failing to inject css styles in graphviz, - # seeing that their suggestion only creates lines (if at all? unclear) of - # width 1 (same as the border already there) and it wont make multiple VRs, + # seeing that their suggestion only creates lines + # (if at all? 
unclear) of width 1 + # (same as the border already there) and it wont make multiple VRs, # and realising their suggestion also does nothing, # refer to hack at the top of this priceless page: # http://jkorpela.fi/html/cellborder.html cells.append('') qkvr_colour, name_colour, data_colour = colour_scheme(row_type) - cells = [] # has to be created in advance, and not just be all the results of add_cell, - # because add_cell cares about current length of 'cells' + cells = [] # has to be created in advance, and not just be all the + # results of add_cell, because add_cell cares about current length of + # 'cells' if self.add_rowtype_cell: - add_cell(ROW_NAMES[row_type],qkvr_colour) - add_cell(seq.name,name_colour) + add_cell(ROW_NAMES[row_type], qkvr_colour) + add_cell(seq.name, name_colour) for v in seq.last_val(): - add_cell(v,data_colour) + add_cell(v, data_colour) if self.note_res_dependencies: self.seq_index[seq] = len(self.rows) add_strong_line() - add_cell("("+str(self.seq_index[seq])+")",indices_colour) - add_cell(self.dependencies_str(seq,row_type),comment_colour) + add_cell("("+str(self.seq_index[seq])+")", indices_colour) + add_cell(self.dependencies_str(seq, row_type), comment_colour) self.rows.append(cells) return len(cells) - def dependencies_str(self,seq,row_type): + def dependencies_str(self, seq, row_type): if not row_type == RES: return "" - return "from ("+", ".join(str(self.seq_index[m]) for m in seq.get_nonminor_parent_sequences()) +")" + return "from ("+", ".join(str(self.seq_index[m]) for m in + seq.get_nonminor_parent_sequences()) + ")" - def add_to_graph(self,g): + def add_to_graph(self, g): if self.empty: # g.node(name=self.ref_in_g,label="empty table") pass else: - g.node(name=self.ref_in_g,shape='none',margin='0',label=self.to_str()) + g.node(name=self.ref_in_g, shape='none', + margin='0', label=self.to_str()) -def place_above(g,node1,node2): - g.edge(node1.bottom_left_portstr(),node2.top_left_portstr(),style="invis") - g.edge(node1.bottom_right_portstr(),node2.top_right_portstr(),style="invis") +def place_above(g, node1, node2): -def connect(g,top_table,bottom_table,select_vals): + g.edge(node1.bottom_left_portstr(), node2.top_left_portstr(), + style="invis") + g.edge(node1.bottom_right_portstr(), + node2.top_right_portstr(), style="invis") + + +def connect(g, top_table, bottom_table, select_vals): # connects top_table as k and bottom_table as q if top_table.empty or bottom_table.empty: - return # not doing this for now - place_above(g,top_table,bottom_table) - # just so it positions them one on top of the other, even if select is empty + return # not doing this for now + place_above(g, top_table, bottom_table) + # just to position them one on top of the other, even if select is empty for q_i in select_vals: - for k_i,b in enumerate(select_vals[q_i]): + for k_i, b in enumerate(select_vals[q_i]): if b: # have to add 2 cause first 2 are data type and row name - g.edge(top_table.bottom_access_portstr(k_i,skip_meta=True), - bottom_table.top_access_portstr(q_i,skip_meta=True), + g.edge(top_table.bottom_access_portstr(k_i, skip_meta=True), + bottom_table.top_access_portstr(q_i, skip_meta=True), arrowhead='none') + class SubHead: - def __init__(self,name,seq): + def __init__(self, name, seq): vvars = seq.get_immediate_parent_sequences() if not seq.definitely_uses_identity_function: vreal = seq.pre_aggregate_comp() - vreal(seq.last_w) # run it on same w to fill with right results + vreal(seq.last_w) # run it on same w to fill with right results vreals = [vreal] else: 
vreals = [] self.name = name - self.vvars_table = Table({VVAR:vvars,VREAL:vreals},self.name+"_vvars",rowtype_order=[VVAR,VREAL]) - self.res_table = Table({RES:[seq]},self.name+"_res") - self.default = "default: "+str(seq.default) if not None is seq.default else "" - # self.vreals_table = ## ? add partly processed vals, useful for eg conditioned_contains? - - def add_to_graph(self,g): + self.vvars_table = Table( + {VVAR: vvars, VREAL: vreals}, self.name+"_vvars", + rowtype_order=[VVAR, VREAL]) + self.res_table = Table({RES: [seq]}, self.name+"_res") + self.default = "default: " + \ + str(seq.default) if seq.default is not None else "" + # self.vreals_table = ## ? add partly processed vals, useful for eg + # conditioned_contains? + + def add_to_graph(self, g): self.vvars_table.add_to_graph(g) self.res_table.add_to_graph(g) if self.default: - g.node(self.name+"_default",shape='rectangle',label=self.default) - g.edge(self.name+"_default",self.res_table.top_left_portstr(), - arrowhead='none') + g.node(self.name+"_default", shape='rectangle', label=self.default) + g.edge(self.name+"_default", self.res_table.top_left_portstr(), + arrowhead='none') - def add_edges(self,g,select_vals): - connect(g,self.vvars_table,self.res_table,select_vals) + def add_edges(self, g, select_vals): + connect(g, self.vvars_table, self.res_table, select_vals) def bottom_left_portstr(self): return self.res_table.bottom_left_portstr() + def bottom_right_portstr(self): return self.res_table.bottom_right_portstr() + def top_left_portstr(self): return self.vvars_table.top_left_portstr() + def top_right_portstr(self): return self.vvars_table.top_right_portstr() + class Head: - def __init__(self,name,head_primitives,i): + def __init__(self, name, head_primitives, i): self.name = name self.i = i self.head_primitives = head_primitives select = self.head_primitives.select q_vars, k_vars = select.q_vars, select.k_vars - q_vars = sorted(list(set(q_vars)),key=lambda a:a.creation_order_id) - k_vars = sorted(list(set(k_vars)),key=lambda a:a.creation_order_id) - self.kq_table = Table({QVAR:q_vars,KVAR:k_vars},self.name+"_qvars",rowtype_order=[KVAR,QVAR]) + q_vars = sorted(list(set(q_vars)), key=lambda a: a.creation_order_id) + k_vars = sorted(list(set(k_vars)), key=lambda a: a.creation_order_id) + self.kq_table = Table({QVAR: q_vars, KVAR: k_vars}, + self.name+"_qvars", rowtype_order=[KVAR, QVAR]) # self.k_table = Table({KVAR:k_vars},self.name+"_kvars") - self.select_result_table = makeQKStable(q_vars,k_vars,select,self.name+"_select") - # self.select_table = SelectTable(self.head_primitives.select,self.name+"_select") - self.subheads = [SubHead(self.name+"_subcomp_"+str(i),seq) for i,seq in \ - enumerate(self.head_primitives.sequences)] - - def add_to_graph(self,g): + self.select_result_table = makeQKStable( + q_vars, k_vars, select, self.name+"_select") + # self.select_table = SelectTable(self.head_primitives.select, + # self.name+"_select") + self.subheads = [SubHead(self.name+"_subcomp_"+str(i), seq) + for i, seq in + enumerate(self.head_primitives.sequences)] + + def add_to_graph(self, g): with g.subgraph(name=self.name) as head: def headlabel(): # return self.head_primitives.select.name - return 'head '+str(self.i)+\ - "\n("+self.head_primitives.select.name+")" - head.attr(fillcolor=head_color, label=headlabel(), - fontcolor='black', style='filled') + return 'head '+str(self.i) +\ + "\n("+self.head_primitives.select.name+")" + head.attr(fillcolor=head_color, label=headlabel(), + fontcolor='black', style='filled') with 
head.subgraph(name=self.name+"_select_parts") as sel: - sel.attr(rankdir="LR",label="",style="invis",rank="same") - if True: # not (self.kq_table.empty): + sel.attr(rankdir="LR", label="", style="invis", rank="same") + if True: # not (self.kq_table.empty): self.select_result_table.add_to_graph(sel) self.kq_table.add_to_graph(sel) # sel.edge(self.kq_table.bottom_right_portstr(), - # self.select_result_table.bottom_left_portstr(),style="invis") - + # self.select_result_table.bottom_left_portstr(),style="invis") + [s.add_to_graph(head) for s in self.subheads] - def add_organising_edges(self,g): + def add_organising_edges(self, g): if self.kq_table.empty: return for s in self.subheads: - place_above(g,self.select_result_table,s) + place_above(g, self.select_result_table, s) def bottom_left_portstr(self): return self.subheads[0].bottom_left_portstr() + def bottom_right_portstr(self): return self.subheads[-1].bottom_right_portstr() + def top_left_portstr(self): if not (self.kq_table.empty): return self.kq_table.top_left_portstr() - else: # no kq (and so no select either) table. go into subheads + else: # no kq (and so no select either) table. go into subheads return self.subheads[0].top_left_portstr() + def top_right_portstr(self): if not (self.kq_table.empty): return self.kq_table.top_right_portstr() else: return self.subheads[-1].top_right_portstr() - - def add_edges(self,g): + def add_edges(self, g): select_vals = self.head_primitives.select.last_val() # connect(g,self.k_table,self.q_table,select_vals) for s in self.subheads: - s.add_edges(g,select_vals) + s.add_edges(g, select_vals) self.add_organising_edges(g) - + + def contains_tokens(mvs): - return next((True for mv in mvs if guarded_contains(base_tokens,mv)),False) + return next((True for mv in mvs if guarded_contains(base_tokens, mv)), + False) + class Layer: - def __init__(self,depth,d_heads,d_ffs,add_tokens_on_ff=False): + def __init__(self, depth, d_heads, d_ffs, add_tokens_on_ff=False): self.heads = [] self.depth = depth self.name = self.layer_cluster_name(depth) - for i,h in enumerate(d_heads): - self.heads.append(Head(self.name+"_head"+str(i),h,i)) + for i, h in enumerate(d_heads): + self.heads.append(Head(self.name+"_head"+str(i), h, i)) ff_parents = [] for ff in d_ffs: ff_parents += ff.get_nonminor_parent_sequences() ff_parents = list(set(ff_parents)) - ff_parents = [p for p in ff_parents if not guarded_contains(d_ffs,p)] - rows_by_type = {RES:d_ffs,VVAR:ff_parents} - rowtype_order = [VVAR,RES] + ff_parents = [p for p in ff_parents if not guarded_contains(d_ffs, p)] + rows_by_type = {RES: d_ffs, VVAR: ff_parents} + rowtype_order = [VVAR, RES] if add_tokens_on_ff and not contains_tokens(ff_parents): rows_by_type[INPUT] = [tokens_asis] - rowtype_order = [INPUT] + rowtype_order - self.ff_table = Table(rows_by_type,self.name+"_ffs",rowtype_order) + rowtype_order = [INPUT] + rowtype_order + self.ff_table = Table(rows_by_type, self.name+"_ffs", rowtype_order) def bottom_object(self): if not self.ff_table.empty: return self.ff_table else: return self.heads[-1] + def top_object(self): if self.heads: return self.heads[0] else: return self.ff_table + def bottom_left_portstr(self): return self.bottom_object().bottom_left_portstr() + def bottom_right_portstr(self): return self.bottom_object().bottom_right_portstr() + def top_left_portstr(self): return self.top_object().top_left_portstr() + def top_right_portstr(self): return self.top_object().top_right_portstr() - def add_to_graph(self,g): - with g.subgraph(name=self.name) as l: - 
l.attr(fillcolor=layer_color, label='layer '+str(self.depth), - fontcolor='black', style='filled') + def add_to_graph(self, g): + with g.subgraph(name=self.name) as lg: + lg.attr(fillcolor=layer_color, label='layer '+str(self.depth), + fontcolor='black', style='filled') for h in self.heads: - h.add_to_graph(l) - self.ff_table.add_to_graph(l) + h.add_to_graph(lg) + self.ff_table.add_to_graph(lg) - def add_organising_edges(self,g): + def add_organising_edges(self, g): if self.ff_table.empty: return for h in self.heads: - place_above(g,h,self.ff_table) + place_above(g, h, self.ff_table) - def add_edges(self,g): + def add_edges(self, g): for h in self.heads: h.add_edges(g) self.add_organising_edges(g) - def layer_cluster_name(self,depth): - return 'cluster_l'+str(depth) # graphviz needs - # cluster names to start with 'cluster' + def layer_cluster_name(self, depth): + return 'cluster_l'+str(depth) # graphviz needs + # cluster names to start with 'cluster' + class CompFlow: - def __init__(self,all_heads,all_ffs,force_vertical_layers,add_tokens_on_ff=False): + def __init__(self, all_heads, all_ffs, force_vertical_layers, + add_tokens_on_ff=False): self.force_vertical_layers = force_vertical_layers self.add_tokens_on_ff = add_tokens_on_ff - self.make_all_layers(all_heads,all_ffs) - def make_all_layers(self,all_heads,all_ffs): + self.make_all_layers(all_heads, all_ffs) + + def make_all_layers(self, all_heads, all_ffs): self.layers = [] ff_depths = [seq.scheduled_comp_depth for seq in all_ffs] head_depths = [h.comp_depth for h in all_heads] depths = sorted(list(set(ff_depths+head_depths))) for d in depths: - d_heads = [h for h in all_heads if h.comp_depth==d] - d_heads = sorted(d_heads,key=lambda h:h.select.creation_order_id) - # only important for determinism to help debug + d_heads = [h for h in all_heads if h.comp_depth == d] + d_heads = sorted(d_heads, key=lambda h: h.select.creation_order_id) + # only important for determinism to help debug d_ffs = [f for f in all_ffs if f.scheduled_comp_depth == d] - self.layers.append(Layer(d,d_heads,d_ffs,self.add_tokens_on_ff)) + self.layers.append(Layer(d, d_heads, d_ffs, self.add_tokens_on_ff)) - def add_all_layers(self,g): - [l.add_to_graph(g) for l in self.layers] + def add_all_layers(self, g): + [layer.add_to_graph(g) for layer in self.layers] - def add_organising_edges(self,g): + def add_organising_edges(self, g): if self.force_vertical_layers: - for l1,l2 in zip(self.layers,self.layers[1:]): - place_above(g,l1,l2) + for l1, l2 in zip(self.layers, self.layers[1:]): + place_above(g, l1, l2) - def add_edges(self,g): + def add_edges(self, g): self.add_organising_edges(g) - [l.add_edges(g) for l in self.layers] + [layer.add_edges(g) for layer in self.layers] + @UnfinishedFunc -def draw_comp_flow(self,w,filename=None, - keep_dot=False,show=True, - force_vertical_layers=True, add_tokens_on_ff=False): - if not None is w: - self(w) # execute seq (and all its ancestors) on the given input w. +def draw_comp_flow(self, w, filename=None, + keep_dot=False, show=True, + force_vertical_layers=True, add_tokens_on_ff=False): + if w is not None: + self(w) # execute seq (and all its ancestors) on the given input w. # if w==None, assume seq has already been executed on some input. 
if not self.last_w == w: print("evaluating input failed") @@ -479,27 +565,31 @@ def draw_comp_flow(self,w,filename=None, w = self.last_w if None is filename: name = self.name - filename=os.path.join("comp_flows",windows_path_cleaner(name+"("+(str(w) if not isinstance(w,str) else "\""+w+"\"")+")")) + filename = os.path.join("comp_flows", windows_path_cleaner( + name+"("+(str(w) if not isinstance(w, str) else "\""+w+"\"")+")")) self.mark_all_minor_ancestors() self.make_display_names_for_all_parents(skip_minors=True) - - all_heads,all_ffs = self.get_all_ancestor_heads_and_ffs(remove_minors=True) - # this scheduling also marks the analysis parent selects - compflow = CompFlow(all_heads,all_ffs, - force_vertical_layers=force_vertical_layers, - add_tokens_on_ff = add_tokens_on_ff) - - # only import graphviz *inside* this function - + + all_heads, all_ffs = self.get_all_ancestor_heads_and_ffs( + remove_minors=True) + # this scheduling also marks the analysis parent selects + compflow = CompFlow(all_heads, all_ffs, + force_vertical_layers=force_vertical_layers, + add_tokens_on_ff=add_tokens_on_ff) + + # only import graphviz *inside* this function - # that way RASP can run even if graphviz setup fails # (though it will not be able to draw computation flows without it) - from graphviz import Digraph + from graphviz import Digraph g = Digraph('g') - g.attr(splines='polyline') # with curved lines it fusses over separating score edges - # and makes weirdly curved ones that start overlapping with the sequences :( + # with curved lines it fusses over separating score edges + # and makes weirdly curved ones that start overlapping with the sequences + # :( + g.attr(splines='polyline') compflow.add_all_layers(g) compflow.add_edges(g) - img_filename = g.render(filename=filename) # img_filename will end with png or something, filename is an intermediate + g.render(filename=filename) if show: g.view() if not keep_dot: - os.remove(filename) \ No newline at end of file + os.remove(filename) diff --git a/RASP_support/Environment.py b/RASP_support/Environment.py index e2c5ce3..19a1a4b 100644 --- a/RASP_support/Environment.py +++ b/RASP_support/Environment.py @@ -1,86 +1,97 @@ -from Sugar import tokens_asis, tokens_str, tokens_int, tokens_bool, tokens_float, indices -from FunctionalSupport import Unfinished, RASPTypeError +from FunctionalSupport import Unfinished, RASPTypeError, tokens_asis, \ + tokens_str, tokens_int, tokens_bool, tokens_float, indices from Evaluator import RASPFunction -from copy import deepcopy + class UndefinedVariable(Exception): - def __init__(self,varname): + def __init__(self, varname): super().__init__("Error: Undefined variable: "+varname) + class ReservedName(Exception): - def __init__(self,varname): + def __init__(self, varname): super().__init__("Error: Cannot set reserved name: "+varname) + class Environment: - def __init__(self,parent_env=None,name=None,stealing_env=None): + def __init__(self, parent_env=None, name=None, stealing_env=None): self.variables = {} self.name = name self.parent_env = parent_env self.stealing_env = stealing_env - self.base_setup() # nested envs can have them too. makes life simpler, - # instead of checking if they have the constant_variables etc in get. bit heavier on memory - # but no one's going to use this language for big nested stuff anyway + self.base_setup() # nested envs can have them too. makes life simpler, + # instead of checking if they have the constant_variables etc in get. 
+ # bit heavier on memory but no one's going to use this language for big + # nested stuff anyway self.storing_in_constants = False def base_setup(self): - self.constant_variables = {"tokens_asis":tokens_asis, - "tokens_str":tokens_str, - "tokens_int":tokens_int, - "tokens_bool":tokens_bool, - "tokens_float":tokens_float, - "indices":indices, - "True":True, - "False":False} - self.reserved_words=["if","else","not","and","or","out","def","return","range","for","in","zip","len","get"] +\ - list(self.constant_variables.keys()) + self.constant_variables = {"tokens_asis": tokens_asis, + "tokens_str": tokens_str, + "tokens_int": tokens_int, + "tokens_bool": tokens_bool, + "tokens_float": tokens_float, + "indices": indices, + "True": True, + "False": False} + self.reserved_words = ["if", "else", "not", "and", "or", "out", "def", + "return", "range", "for", "in", "zip", "len", + "get"] + list(self.constant_variables.keys()) def snapshot(self): - res = Environment(parent_env=self.parent_env,name=self.name,stealing_env=self.stealing_env) + res = Environment(parent_env=self.parent_env, + name=self.name, stealing_env=self.stealing_env) + def carefulcopy(val): - if isinstance(val,Unfinished) or isinstance(val,RASPFunction): - return val # non mutable, at least not through rasp commands - elif isinstance(val,float) or isinstance(val,int) or isinstance(val,str) or isinstance(val,bool): - return val # non mutable - elif isinstance(val,list): + if isinstance(val, Unfinished) or isinstance(val, RASPFunction): + return val # non mutable, at least not through rasp commands + elif isinstance(val, float) or isinstance(val, int) \ + or isinstance(val, str) or isinstance(val, bool): + return val # non mutable + elif isinstance(val, list): return [carefulcopy(v) for v in val] else: - raise RASPTypeError("environment contains element that is not unfinished,", - "rasp function, float, int, string, bool, or list? :",val) - res.constant_variables = {d:carefulcopy(self.constant_variables[d]) for d in self.constant_variables} - res.variables = {d:carefulcopy(self.variables[d]) for d in self.variables} + raise RASPTypeError("environment contains element that is not " + + "unfinished, rasp function, float, int," + + "string, bool, or list? 
:", val) + res.constant_variables = {d: carefulcopy( + self.constant_variables[d]) for d in self.constant_variables} + res.variables = {d: carefulcopy( + self.variables[d]) for d in self.variables} return res - def make_nested(self,names_vars=[]): - res = Environment(self,name=str(self.name)+"'") - for n,v in names_vars: - res.set_variable(n,v) + def make_nested(self, names_vars=[]): + res = Environment(self, name=str(self.name)+"'") + for n, v in names_vars: + res.set_variable(n, v) return res - def get_variable(self,name): + def get_variable(self, name): if name in self.constant_variables: return self.constant_variables[name] if name in self.variables: return self.variables[name] - if not None is self.parent_env: + if self.parent_env is not None: return self.parent_env.get_variable(name) raise UndefinedVariable(name) - def _set_checked_variable(self,name,val): + def _set_checked_variable(self, name, val): if self.storing_in_constants: self.constant_variables[name] = val self.reserved_words.append(name) else: self.variables[name] = val - def set_variable(self,name,val): + def set_variable(self, name, val): if name in self.reserved_words: raise ReservedName(name) - self._set_checked_variable(name,val) - if not None is self.stealing_env: - if name.startswith("_") or name=="out": # things we don't want to steal + self._set_checked_variable(name, val) + if self.stealing_env is not None: + if name.startswith("_") or name == "out": # things we don't want + # to steal return - self.stealing_env.set_variable(name,val) + self.stealing_env.set_variable(name, val) - def set_out(self,val): + def set_out(self, val): self.variables["out"] = val diff --git a/RASP_support/Evaluator.py b/RASP_support/Evaluator.py index 8b502be..ea5cd2f 100644 --- a/RASP_support/Evaluator.py +++ b/RASP_support/Evaluator.py @@ -1,125 +1,142 @@ -from Sugar import select, zipmap, aggregate, \ - tplor, tpland, tplnot, toseq, \ - or_selects, and_selects, not_select, full_s, indices -from FunctionalSupport import Unfinished, UnfinishedSequence, UnfinishedSelect -from Support import RASPTypeError, RASPError, Select, Sequence +from FunctionalSupport import select, zipmap, aggregate, \ + or_selects, and_selects, not_select, indices, \ + Unfinished, UnfinishedSequence, UnfinishedSelect +from Sugar import tplor, tpland, tplnot, toseq, full_s +from Support import RASPTypeError, RASPError from collections.abc import Iterable -import sys from zzantlr.RASPParser import RASPParser -encoder_name = "s-op" +ENCODER_NAME = "s-op" -def strdesc(o,desc_cap=None): - if isinstance(o,Unfinished): + +def strdesc(o, desc_cap=None): + if isinstance(o, Unfinished): return o.name - if isinstance(o,list): + if isinstance(o, list): res = "["+", ".join([strdesc(v) for v in o])+"]" - if not None is desc_cap and len(res)>desc_cap: + if desc_cap is not None and len(res) > desc_cap: return "(list)" else: return res - if isinstance(o,dict): + if isinstance(o, dict): res = "{"+", ".join((strdesc(k)+": "+strdesc(o[k])) for k in o)+"}" - if not None is desc_cap and len(res)>desc_cap: + if desc_cap is not None and len(res) > desc_cap: return "(dict)" else: return res else: - if isinstance(o,str): + if isinstance(o, str): return "\""+o+"\"" else: return str(o) class RASPValueError(RASPError): - def __init__(self,*a): + def __init__(self, *a): super().__init__(*a) -debug = False -def debprint(*a,**kw): - if debug: - print(*a,**kw) +DEBUG = False + + +def debprint(*a, **kw): + if DEBUG: + print(*a, **kw) -def ast_text(ast): # just so don't have to go remembering this 
somewhere - # consider seeing if can make it add spaces between the tokens when doing this tho + +def ast_text(ast): # just so don't have to go remembering this somewhere + # consider seeing if can make it add spaces between the tokens when doing + # this tho return ast.getText() + def isatom(v): # the legal atoms - return True in [isinstance(v,t) for t in [int,float,str,bool]] + return True in [isinstance(v, t) for t in [int, float, str, bool]] + def name_general_type(v): - if isinstance(v,list): + if isinstance(v, list): return "list" - if isinstance(v,dict): + if isinstance(v, dict): return "dict" - if isinstance(v,UnfinishedSequence): - return encoder_name - if isinstance(v,UnfinishedSelect): + if isinstance(v, UnfinishedSequence): + return ENCODER_NAME + if isinstance(v, UnfinishedSelect): return "selector" - if isinstance(v,RASPFunction): + if isinstance(v, RASPFunction): return "function" if isatom(v): return "atom" return "??" + class ArgsError(Exception): - def __init__(self,name,expected,got): - super().__init__("wrong number of args for "+name+\ - "- expected: "+str(expected)+", got: "+str(got)+".") + def __init__(self, name, expected, got): + super().__init__("wrong number of args for "+name + + "- expected: "+str(expected)+", got: "+str(got)+".") + class NamedVal: - def __init__(self,name,val): + def __init__(self, name, val): self.name = name self.val = val + class NamedValList: - def __init__(self,namedvals): + def __init__(self, namedvals): self.nvs = namedvals + class JustVal: - def __init__(self,val): + def __init__(self, val): self.val = val + class RASPFunction: - def __init__(self,name,enclosing_env,argnames,statement_trees,returnexpr,creator_name): - self.name = name # just for debug purposes + def __init__(self, name, enclosing_env, argnames, statement_trees, + returnexpr, creator_name): + self.name = name # just for debug purposes self.enclosing_env = enclosing_env self.argnames = argnames self.statement_trees = statement_trees self.returnexpr = returnexpr self.creator = creator_name + def __str__(self): - return self.creator + " function: "+self.name+"("+", ".join(self.argnames)+")" + return self.creator + " function: " + self.name \ + + "(" + ", ".join(self.argnames) + ")" - def __call__(self,*args): + def __call__(self, *args): top_eval = args[-1] args = args[:-1] - env = self.enclosing_env.make_nested([]) # nesting, because function shouldn't affect the enclosing environment - if not len(args)==len(self.argnames): - raise ArgsError(self.name,len(self.argnames),len(args)) - for n,v in zip(self.argnames,args): - env.set_variable(n,v) - evaluator = Evaluator(env,top_eval.repl) + # nesting, because function shouldn't affect the enclosing environment + env = self.enclosing_env.make_nested([]) + if not len(args) == len(self.argnames): + raise ArgsError(self.name, len(self.argnames), len(args)) + for n, v in zip(self.argnames, args): + env.set_variable(n, v) + evaluator = Evaluator(env, top_eval.repl) for at in self.statement_trees: evaluator.evaluate(at) res = evaluator.evaluateExprsList(self.returnexpr) - return res[0] if len(res)==1 else res + return res[0] if len(res) == 1 else res + class Evaluator: - def __init__(self,env,repl): + def __init__(self, env, repl): self.env = env self.sequence_running_example = repl.sequence_running_example self.backup_example = None - # allows evaluating something that maybe doesn't necessarily work with the main running example, - # but we just want to see what happens on it - e.g. 
so we can do draw(tokens_int+1,[1,2]) - # without error even while the main example is still "hello" + # allows evaluating something that maybe doesn't necessarily work with + # the main running example, but we just want to see what happens on + # it - e.g. so we can do draw(tokens_int+1,[1,2]) without error even + # while the main example is still "hello" self.repl = repl - def evaluate(self,ast): + def evaluate(self, ast): if ast.expr(): - return self.evaluateExpr(ast.expr(),from_top=True) + return self.evaluateExpr(ast.expr(), from_top=True) if ast.assign(): return self.assign(ast.assign()) if ast.funcDef(): @@ -129,268 +146,304 @@ def evaluate(self,ast): if ast.forLoop(): return self.forLoop(ast.forLoop()) if ast.loadFile(): - return self.repl.loadFile(ast.loadFile(),self.env) + return self.repl.loadFile(ast.loadFile(), self.env) # more to come raise NotImplementedError - def draw(self,ast): - # TODO: make at least some rudimentary comparisons of selectors somehow to merge heads idk?????? - # maybe keep trace of operations used to create them and those with exact same parent s-ops and operations - # can get in? would still find eg select(0,0,==) and select(1,1,==) different, but its better than nothing at all - example = self.evaluateExpr(ast.inputseq) if ast.inputseq else self.sequence_running_example + def draw(self, ast): + # TODO: make at least some rudimentary comparisons of selectors somehow + # to merge heads idk?????? maybe keep trace of operations used to + # create them and those with exact same parent s-ops and operations + # can get in? would still find eg select(0,0,==) and select(1,1,==) + # different, but its better than nothing at all + example = self.evaluateExpr( + ast.inputseq) if ast.inputseq else self.sequence_running_example prev_backup = self.backup_example self.backup_example = example unf = self.evaluateExpr(ast.unf) - if not isinstance(unf,UnfinishedSequence): - raise RASPTypeError("draw expects unfinished sequence, got:",unf) + if not isinstance(unf, UnfinishedSequence): + raise RASPTypeError("draw expects unfinished sequence, got:", unf) unf.draw_comp_flow(example) res = unf(example) res.created_from_input = example self.backup_example = prev_backup return JustVal(res) - def assign(self,ast): - def set_val_and_name(val,name): - self.env.set_variable(name,val) - if isinstance(val,Unfinished): - val.setname(name) # completely irrelevant really for the REPL, + def assign(self, ast): + def set_val_and_name(val, name): + self.env.set_variable(name, val) + if isinstance(val, Unfinished): + val.setname(name) # completely irrelevant really for the REPL, # but will help maintain sanity when printing computation flows - return NamedVal(name,val) + return NamedVal(name, val) varnames = self._names_list(ast.var) values = self.evaluateExprsList(ast.val) - if len(values)==1: + if len(values) == 1: values = values[0] - if len(varnames)==1: - return set_val_and_name(values,varnames[0]) + if len(varnames) == 1: + return set_val_and_name(values, varnames[0]) else: if not len(varnames) == len(values): - raise RASPTypeError("expected",len(varnames),"values, but got:",len(values)) + raise RASPTypeError("expected", len( + varnames), "values, but got:", len(values)) reslist = [] - for v,name in zip(values,varnames): - reslist.append(set_val_and_name(v,name)) + for v, name in zip(values, varnames): + reslist.append(set_val_and_name(v, name)) return NamedValList(reslist) - def _names_list(self,ast): + def _names_list(self, ast): idsList = self._get_first_cont_list(ast) return 
[i.text for i in idsList] - def _set_iterator_and_vals(self,iterator_names,iterator_vals): - if len(iterator_names)==1: - self.env.set_variable(iterator_names[0],iterator_vals) - elif isinstance(iterator_vals,Iterable) and (len(iterator_vals)==len(iterator_names)): - for n,v in zip(iterator_names,iterator_vals): - self.env.set_variable(n,v) + def _set_iterator_and_vals(self, iterator_names, iterator_vals): + if len(iterator_names) == 1: + self.env.set_variable(iterator_names[0], iterator_vals) + elif isinstance(iterator_vals, Iterable) \ + and (len(iterator_vals) == len(iterator_names)): + for n, v in zip(iterator_names, iterator_vals): + self.env.set_variable(n, v) else: - if not isinstance(iterator_vals,Iterable): - raise RASPTypeError("iterating with multiple iterator names, but got single iterator value:",iterator_vals) + if not isinstance(iterator_vals, Iterable): + raise RASPTypeError( + "iterating with multiple iterator names, but got single" + + " iterator value:", iterator_vals) else: - assert not (len(iterator_vals)==len(iterator_names)), "something wrong with Evaluator logic" # should work out by logic of last failed elif - raise RASPTypeError("iterating with",len(iterator_names),"names but got",len(iterator_vals),"values (",iterator_vals,")") - - def _evaluateDictComp(self,ast): + # should work out by logic of last failed elif + errnote = "something wrong with Evaluator logic" + assert not (len(iterator_vals) == len(iterator_names)), errnote + raise RASPTypeError("iterating with", len(iterator_names), + "names but got", len(iterator_vals), + "values (", iterator_vals, ")") + + def _evaluateDictComp(self, ast): ast = ast.dictcomp d = self.evaluateExpr(ast.iterable) - if not (isinstance(d,list) or isinstance(d,dict)): - raise RASPTypeError("dict comprehension should have got a list or dict to loop over, but got:",l) + if not (isinstance(d, list) or isinstance(d, dict)): + raise RASPTypeError( + "dict comprehension should have got a list or dict to loop " + + "over, but got:", l) res = {} - iterator_names = self._names_list(ast.iterator) + iterator_names = self._names_list(ast.iterator) for vals in d: orig_env = self.env self.env = self.env.make_nested() - self._set_iterator_and_vals(iterator_names,vals) + self._set_iterator_and_vals(iterator_names, vals) key = self.make_dict_key(ast.key) res[key] = self.evaluateExpr(ast.val) self.env = orig_env return res - - def _evaluateListComp(self,ast): + def _evaluateListComp(self, ast): ast = ast.listcomp - l = self.evaluateExpr(ast.iterable) - if not (isinstance(l,list) or isinstance(l,dict)): - raise RASPTypeError("list comprehension should have got a list or dict to loop over, but got:",l) + ll = self.evaluateExpr(ast.iterable) + if not (isinstance(ll, list) or isinstance(ll, dict)): + raise RASPTypeError( + "list comprehension should have got a list or dict to loop " + + "over, but got:", ll) res = [] - iterator_names = self._names_list(ast.iterator) - for vals in l: + iterator_names = self._names_list(ast.iterator) + for vals in ll: orig_env = self.env self.env = self.env.make_nested() - self._set_iterator_and_vals(iterator_names,vals) # sets inside the now-nested env - - # don't want to keep the internal iterators after finishing this list comp + # sets inside the now-nested env -don't want to keep + # the internal iterators after finishing this list comp + self._set_iterator_and_vals(iterator_names, vals) res.append(self.evaluateExpr(ast.val)) self.env = orig_env return res - def forLoop(self,ast): + def forLoop(self, ast): 
iterator_names = self._names_list(ast.iterator) iterable = self.evaluateExpr(ast.iterable) - if not (isinstance(iterable,list) or isinstance(iterable,dict)): - raise RASPTypeError("for loop needs to iterate over a list or dict, but got:",iterable) + if not (isinstance(iterable, list) or isinstance(iterable, dict)): + raise RASPTypeError( + "for loop needs to iterate over a list or dict, but got:", + iterable) statements = self._get_first_cont_list(ast.mainbody) for vals in iterable: - self._set_iterator_and_vals(iterator_names,vals) + self._set_iterator_and_vals(iterator_names, vals) for s in statements: self.evaluate(s) return JustVal(None) - - def _get_first_cont_list(self,ast): + def _get_first_cont_list(self, ast): res = [] while ast: if ast.first: res.append(ast.first) # sometimes there's no first cause it's just eating a comment - ast = ast.cont + ast = ast.cont return res - def funcDef(self,ast): + def funcDef(self, ast): funcname = ast.name.text argname_trees = self._get_first_cont_list(ast.arguments) argnames = [a.text for a in argname_trees] statement_trees = self._get_first_cont_list(ast.mainbody) returnexpr = ast.retstatement.res - res = RASPFunction(funcname,self.env,argnames,statement_trees,returnexpr,self.env.name) - self.env.set_variable(funcname,res) - return NamedVal(funcname,res) + res = RASPFunction(funcname, self.env, argnames, + statement_trees, returnexpr, self.env.name) + self.env.set_variable(funcname, res) + return NamedVal(funcname, res) - def _evaluateUnaryExpr(self,ast): + def _evaluateUnaryExpr(self, ast): uexpr = self.evaluateExpr(ast.uexpr) uop = ast.uop.text - if uop =="not": - if isinstance(uexpr,UnfinishedSequence): + if uop == "not": + if isinstance(uexpr, UnfinishedSequence): return tplnot(uexpr) - elif isinstance(uexpr,UnfinishedSelect): + elif isinstance(uexpr, UnfinishedSelect): return not_select(uexpr) else: return not uexpr - if uop =="-": + if uop == "-": return -uexpr if uop == "+": return +uexpr - if uop =="round": + if uop == "round": return round(uexpr) if uop == "indicator": - if isinstance(uexpr,UnfinishedSequence): + if isinstance(uexpr, UnfinishedSequence): name = "I("+uexpr.name+")" - return zipmap(uexpr,lambda a:1 if a else 0,name=name).allow_suppressing_display() - # naming res makes RASP think it is important, i.e., - # must always be displayed. but here it has only been named for clarity, so - # correct RASP using .allow_suppressing_display() - - raise RASPTypeError("indicator operator expects "+encoder_name+", got:",uexpr) + zipmapped = zipmap(uexpr, lambda a: 1 if a else 0, name=name) + return zipmapped.allow_suppressing_display() + # naming res makes interpreter think it is important, i.e., + # must always be displayed. 
but here it has only been named for + # clarity, so correct it using .allow_suppressing_display() + + raise RASPTypeError( + "indicator operator expects "+ENCODER_NAME+", got:", uexpr) raise NotImplementedError - def _evaluateRange(self,ast): + def _evaluateRange(self, ast): valsList = self.evaluateExprsList(ast.rangevals) - if not len(valsList) in [1,2,3]: - raise RASPTypeError("wrong number of inputs to range, expected: 1, 2, or 3, got:",len(valsList)) + if not len(valsList) in [1, 2, 3]: + raise RASPTypeError( + "wrong number of inputs to range, expected: 1, 2, or 3, got:", + len(valsList)) for v in valsList: - if not isinstance(v,int): - raise RASPTypeError("range expects all integer inputs, but got:",strdesc(valsList)) + if not isinstance(v, int): + raise RASPTypeError( + "range expects all integer inputs, but got:", + strdesc(valsList)) return list(range(*valsList)) - def _index_into_dict(self,d,index): + def _index_into_dict(self, d, index): if not isatom(index): - raise RASPTypeError("index into dict has to be atom (i.e., string, int, float, bool), got:",strdesc(index)) + raise RASPTypeError( + "index into dict has to be atom" + + " (i.e., string, int, float, bool), got:", strdesc(index)) if index not in d: - raise RASPValueError("index [",strdesc(index),"] not in dict.") + raise RASPValueError("index [", strdesc(index), "] not in dict.") else: return d[index] - def _index_into_list_or_str(self,l,index): - lname = "list" if isinstance(l,list) else "string" - if not isinstance(index,int): - raise RASPTypeError("index into",lname,"has to be integer, got:",strdesc(index)) - if index>=len(l) or (-index)>len(l): - raise RASPValueError("index",index,"out of range for",lname,"of length",len(l)) - return l[index] - - def _index_into_sequence(self,s,index): - if isinstance(index,int): - if index>=0: - sel = select(toseq(index),indices,lambda q,k:q==k,name="load from "+str(index)) + def _index_into_list_or_str(self, ll, index): + lname = "list" if isinstance(ll, list) else "string" + if not isinstance(index, int): + raise RASPTypeError("index into", lname, + "has to be integer, got:", strdesc(index)) + if index >= len(ll) or (-index) > len(ll): + raise RASPValueError( + "index", index, "out of range for", lname, "of length", + len(ll)) + return ll[index] + + def _index_into_sequence(self, s, index): + if isinstance(index, int): + if index >= 0: + sel = select(toseq(index), indices, lambda q, + k: q == k, name="load from "+str(index)) else: length = self.env.get_variable("length") real_index = length + index real_index.setname(length.name+str(index)) - sel = select(real_index,indices,lambda q,k:q==k,name="load from "+str(index)) - return aggregate(sel,s,name=s.name+"["+str(index)+"]").allow_suppressing_display() + sel = select(real_index, indices, lambda q, + k: q == k, name="load from "+str(index)) + agg = aggregate(sel, s, name=s.name+"["+str(index)+"]") + return agg.allow_suppressing_display() else: - raise RASPValueError("index into sequence has to be integer, got:",strdesc(index)) - - - + raise RASPValueError( + "index into sequence has to be integer, got:", strdesc(index)) - def _evaluateIndexing(self,ast): + def _evaluateIndexing(self, ast): indexable = self.evaluateExpr(ast.indexable) index = self.evaluateExpr(ast.index) - - if isinstance(indexable,list) or isinstance(indexable,str): - return self._index_into_list_or_str(indexable,index) - elif isinstance(indexable,dict): - return self._index_into_dict(indexable,index) - elif isinstance(indexable,UnfinishedSequence): - return 
self._index_into_sequence(indexable,index) + + if isinstance(indexable, list) or isinstance(indexable, str): + return self._index_into_list_or_str(indexable, index) + elif isinstance(indexable, dict): + return self._index_into_dict(indexable, index) + elif isinstance(indexable, UnfinishedSequence): + return self._index_into_sequence(indexable, index) else: - raise RASPTypeError("can only index into a list, dict, string, or sequence, "+\ - "but instead got:",strdesc(indexable)) + raise RASPTypeError("can only index into a list, dict, string, or" + + " sequence, but instead got:", + strdesc(indexable)) - def _evaluateSelectExpr(self,ast): + def _evaluateSelectExpr(self, ast): key = self.evaluateExpr(ast.key) query = self.evaluateExpr(ast.query) sop = ast.selop.text - key = toseq(key) # in case got an atom in one of these, - query = toseq(query) # e.g. selecting 0th index: indices @= 0 - if sop=="<": - return select(query,key,lambda q,k:q>k) - if sop==">": - return select(query,key,lambda q,k:q=k) - if sop==">=": - return select(query,key,lambda q,k:q<=k) - - def _evaluateBinaryExpr(self,ast): - def has_sequence(l,r): - return isinstance(l,UnfinishedSequence) or isinstance(r,UnfinishedSequence) - def has_selector(l,r): - return isinstance(l,UnfinishedSelect) or isinstance(r,UnfinishedSelect) - def both_selectors(l,r): - return isinstance(l,UnfinishedSelect) and isinstance(r,UnfinishedSelect) + key = toseq(key) # in case got an atom in one of these, + query = toseq(query) # e.g. selecting 0th index: indices @= 0 + if sop == "<": + return select(query, key, lambda q, k: q > k) + if sop == ">": + return select(query, key, lambda q, k: q < k) + if sop == "==": + return select(query, key, lambda q, k: q == k) + if sop == "!=": + return select(query, key, lambda q, k: not (q == k)) + if sop == "<=": + return select(query, key, lambda q, k: q >= k) + if sop == ">=": + return select(query, key, lambda q, k: q <= k) + + def _evaluateBinaryExpr(self, ast): + def has_sequence(left, right): + return isinstance(left, UnfinishedSequence) \ + or isinstance(right, UnfinishedSequence) + + def has_selector(left, right): + return isinstance(left, UnfinishedSelect) \ + or isinstance(right, UnfinishedSelect) + + def both_selectors(left, right): + return isinstance(left, UnfinishedSelect) \ + and isinstance(right, UnfinishedSelect) left = self.evaluateExpr(ast.left) right = self.evaluateExpr(ast.right) bop = ast.bop.text - bad_pair = RASPTypeError("Cannot apply and/or between selector and non-selector") - if bop=="and": - if has_sequence(left,right): - if has_selector(left,right): + bad_pair = RASPTypeError( + "Cannot apply and/or between selector and non-selector") + if bop == "and": + if has_sequence(left, right): + if has_selector(left, right): raise bad_pair - return tpland(left,right) - elif has_selector(left,right): - if not both_selectors(left,right): + return tpland(left, right) + elif has_selector(left, right): + if not both_selectors(left, right): raise bad_pair - return and_selects(left,right) + return and_selects(left, right) else: return (left and right) - elif bop=="or": - if has_sequence(left,right): - if has_selector(left,right): + elif bop == "or": + if has_sequence(left, right): + if has_selector(left, right): raise bad_pair - return tplor(left,right) - elif has_selector(left,right): - if not both_selectors(left,right): + return tplor(left, right) + elif has_selector(left, right): + if not both_selectors(left, right): raise bad_pair - return or_selects(left,right) + return or_selects(left, right) 
 			else:
 				return (left or right)
-		if has_selector(left,right):
-			raise RASPTypeError("Cannot apply",bop,"to selector(s)")
+		if has_selector(left, right):
+			raise RASPTypeError("Cannot apply", bop, "to selector(s)")
 		elif bop == "+":
 			return left + right
 		elif bop == "-":
@@ -399,24 +452,24 @@ def both_selectors(l,r):
 			return left * right
 		elif bop == "/":
 			return left/right
-		elif bop=="^":
-			return pow(left,right)
-		elif bop=='%':
-			return left%right
-		elif bop=="==":
-			return left==right
-		elif bop=="<=":
-			return left<=right
-		elif bop==">=":
-			return left>=right
-		elif bop=="<":
-			return left<right
-		elif bop==">":
-			return left>right
+		elif bop == "^":
+			return pow(left, right)
+		elif bop == '%':
+			return left % right
+		elif bop == "==":
+			return left == right
+		elif bop == "<=":
+			return left <= right
+		elif bop == ">=":
+			return left >= right
+		elif bop == "<":
+			return left < right
+		elif bop == ">":
+			return left > right
 		# more, like modulo and power and all the other operators, to come
 		raise NotImplementedError
 
-	def _evaluateStandalone(self,ast):
+	def _evaluateStandalone(self, ast):
 		if ast.anint:
 			return int(ast.anint.text)
 		if ast.afloat:
@@ -425,151 +478,181 @@ def _evaluateStandalone(self,ast):
 			return ast.astring.text[1:-1]
 		raise NotImplementedError
 
-	def _evaluateTernaryExpr(self,ast):
+	def _evaluateTernaryExpr(self, ast):
 		cond = self.evaluateExpr(ast.cond)
-		if isinstance(cond,Unfinished):
+		if isinstance(cond, Unfinished):
 			res1 = self.evaluateExpr(ast.res1)
 			res2 = self.evaluateExpr(ast.res2)
-			cond, res1, res2 = tuple(map(toseq,(cond,res1,res2)))
-			return zipmap((cond,res1,res2),lambda c,r1,r2:r1 \
-				if c else r2,name=res1.name+" if "+cond.name+" else "+res2.name).allow_suppressing_display()
+			cond, res1, res2 = tuple(map(toseq, (cond, res1, res2)))
+			return zipmap((cond, res1, res2), lambda c, r1, r2: r1
+						  if c else r2, name=res1.name+" if "+cond.name +
+						  " else " + res2.name).allow_suppressing_display()
 		else:
-			return self.evaluateExpr(ast.res1) if cond else self.evaluateExpr(ast.res2)
-			# lazy eval when cond is non-unfinished allows legal loops over actual atoms
+			return self.evaluateExpr(ast.res1) if cond \
+				else self.evaluateExpr(ast.res2)
+			# lazy eval when cond is non-unfinished allows legal loops over
+			# actual atoms
 
-	def _evaluateAggregateExpr(self,ast):
+	def _evaluateAggregateExpr(self, ast):
 		sel = self.evaluateExpr(ast.sel)
 		seq = self.evaluateExpr(ast.seq)
-		seq = toseq(seq) # just in case its an atom
+		seq = toseq(seq)  # just in case its an atom
 		default = self.evaluateExpr(ast.default) if ast.default else None
 
-		if not isinstance(sel,UnfinishedSelect):
-			raise RASPTypeError("Expected selector, got:",strdesc(selector))
-		if not isinstance(seq,UnfinishedSequence):
-			raise RASPTypeError("Expected sequence, got:",strdesc(seq))
-		if isinstance(default,Unfinished):
-			raise RASPTypeError("Expected atom, got:",strdesc(default))
-		return aggregate(sel,seq,default=default)
-
+		if not isinstance(sel, UnfinishedSelect):
+			raise RASPTypeError("Expected selector, got:", strdesc(selector))
+		if not isinstance(seq, UnfinishedSequence):
+			raise RASPTypeError("Expected sequence, got:", strdesc(seq))
+		if isinstance(default, Unfinished):
+			raise RASPTypeError("Expected atom, got:", strdesc(default))
+		return aggregate(sel, seq, default=default)
 
-
-	def _evaluateZip(self,ast):
+	def _evaluateZip(self, ast):
 		list_exps = self._get_first_cont_list(ast.lists)
 		lists = [self.evaluateExpr(e) for e in list_exps]
 		if not lists:
			raise RASPTypeError("zip needs at least one list")
-		for i,l in enumerate(lists):
-			if 
not isinstance(l,list): - raise RASPTypeError("attempting to zip lists, but",i+1,"-th element is not list:",strdesc(l)) + for i, l in enumerate(lists): + if not isinstance(l, list): + raise RASPTypeError( + "attempting to zip lists, but", i+1, + "-th element is not list:", strdesc(l)) n = len(lists[0]) - for i,l in enumerate(lists): - if not len(l)==n: - raise RASPTypeError("attempting to zip lists of length",n,", but",i+1,"-th list has length",len(l)) - return [list(v) for v in zip(*lists)] # keep everything lists, no tuples/lists mixing here, all the same to rasp (no stuff like append etc) - - def make_dict_key(self,ast): + for i, l in enumerate(lists): + if not len(l) == n: + raise RASPTypeError("attempting to zip lists of length", + n, ", but", i+1, "-th list has length", + len(l)) + # keep everything lists, no tuples/lists mixing here, all the same to + # rasp (no stuff like append etc) + return [list(v) for v in zip(*lists)] + + def make_dict_key(self, ast): res = self.evaluateExpr(ast) if not isatom(res): - raise RASPTypeError("dictionary keys can only be atoms, but instead got:",strdesc(res)) + raise RASPTypeError( + "dictionary keys can only be atoms, but instead got:", + strdesc(res)) return res - def _evaluateDict(self,ast): + def _evaluateDict(self, ast): named_exprs_list = self._get_first_cont_list(ast.dictContents) - return {self.make_dict_key(e.key):self.evaluateExpr(e.val) for e in named_exprs_list} + return {self.make_dict_key(e.key): self.evaluateExpr(e.val) + for e in named_exprs_list} - def _evaluateList(self,ast): + def _evaluateList(self, ast): exprs_list = self._get_first_cont_list(ast.listContents) return [self.evaluateExpr(e) for e in exprs_list] - def _evaluateApplication(self,ast,unf): + def _evaluateApplication(self, ast, unf): input_vals = self._get_first_cont_list(ast.inputexprs) if not len(input_vals) == 1: - raise ArgsError("evaluate unfinished",1,len(input_vals)) + raise ArgsError("evaluate unfinished", 1, len(input_vals)) input_val = self.evaluateExpr(input_vals[0]) - if not isinstance(unf,Unfinished): - raise RASPTypeError("Applying unfinished expects to apply",encoder_name,"or selector, got:",strdesc(sel)) - if not isinstance(input_val,Iterable): - raise RASPTypeError("Applying unfinished expects iterable input, got:",strdesc(input_val)) + if not isinstance(unf, Unfinished): + raise RASPTypeError("Applying unfinished expects to apply", + ENCODER_NAME, "or selector, got:", + strdesc(sel)) + if not isinstance(input_val, Iterable): + raise RASPTypeError( + "Applying unfinished expects iterable input, got:", + strdesc(input_val)) res = unf(input_val) res.created_from_input = input_val return res - def _evaluateRASPFunction(self,ast,raspfun): + def _evaluateRASPFunction(self, ast, raspfun): args_trees = self._get_first_cont_list(ast.inputexprs) args = tuple(self.evaluateExpr(t) for t in args_trees) + (self,) real_args = args[:-1] res = raspfun(*args) - if isinstance(res,Unfinished): - res.setname(raspfun.name+"("+" , ".join(strdesc(a,desc_cap=20) for a in real_args)+")") + if isinstance(res, Unfinished): + res.setname( + raspfun.name+"("+" , ".join(strdesc(a, desc_cap=20) + for a in real_args)+")") return res - - def _evaluateContains(self,ast): + def _evaluateContains(self, ast): contained = self.evaluateExpr(ast.contained) container = self.evaluateExpr(ast.container) container_name = ast.container.var.text if ast.container.var \ - else str(container) - if isinstance(contained,UnfinishedSequence): - if not isinstance(container,list): - raise 
RASPTypeError("\"["+encoder_name+"] in X\" expects X to be "\ - "list of atoms, but got non-list:",strdesc(container)) + else str(container) + if isinstance(contained, UnfinishedSequence): + if not isinstance(container, list): + raise RASPTypeError("\"["+ENCODER_NAME+"] in X\" expects X to" + + "be list of atoms, but got non-list:", + strdesc(container)) for v in container: if not isatom(v): - raise RASPTypeError("\"["+encoder_name+"] in X\" expects X to be "\ - "list of atoms, but got list with values:",strdesc(container)) - return zipmap(contained,lambda c:c in container, - name=contained.name+" in "+container_name).allow_suppressing_display() - elif isatom(contained): # contained is now an atom - if isinstance(container,list): + raise RASPTypeError("\"["+ENCODER_NAME+"] in X\" expects X" + + "to be list of atoms, but got list " + + "with values:", strdesc(container)) + return zipmap(contained, lambda c: c in container, + name=contained.name + " in " + + container_name).allow_suppressing_display() + elif isatom(contained): # contained is now an atom + if isinstance(container, list): return contained in container - elif isinstance(container,UnfinishedSequence): - indicator = zipmap(container,lambda v:int(v==contained)) - return aggregate(full_s,indicator)>0 - else: - raise RASPTypeError("\"[atom] in X\" expects X to be list or "+encoder_name+", but got:",strdesc(container)) - if isinstance(contained,UnfinishedSelect) or isinstance(contained,RASPFunction): - obj_name = "select" if isinstance(contained,UnfinishedSelect) else "function" - raise RASPTypeError("don't check if",obj_name, - "is contained in list/dict: unless exact same instance,", - "unable to check equivalence of",obj_name+"s") + elif isinstance(container, UnfinishedSequence): + indicator = zipmap(container, lambda v: int(v == contained)) + return aggregate(full_s, indicator) > 0 + else: + raise RASPTypeError( + "\"[atom] in X\" expects X to be list or " + ENCODER_NAME + + ", but got:", strdesc(container)) + if isinstance(contained, UnfinishedSelect) or isinstance(contained, + RASPFunction): + obj_name = "select" if isinstance( + contained, UnfinishedSelect) else "function" + raise RASPTypeError("don't check if", obj_name, + "is contained in list/dict: unless exact same " + + "instance, unable to check equivalence of", + obj_name + "s") else: - raise RASPTypeError("\"A in X\" expects A to be",encoder_name,"or atom, but got A:",strdesc(contained)) + raise RASPTypeError("\"A in X\" expects A to be", + ENCODER_NAME, "or atom, but got A:", + strdesc(contained)) - def _evaluateLen(self,ast): + def _evaluateLen(self, ast): singleList = self.evaluateExpr(ast.singleList) - if not isinstance(singleList,list) or isinstance(singleList,dict): - raise RASPTypeError("attempting to compute length of non-list:",strdesc(singleList)) + if not isinstance(singleList, list) or isinstance(singleList, dict): + raise RASPTypeError( + "attempting to compute length of non-list:", + strdesc(singleList)) return len(singleList) - def evaluateExprsList(self,ast): + def evaluateExprsList(self, ast): exprsList = self._get_first_cont_list(ast) return [self.evaluateExpr(v) for v in exprsList] - def _test_res(self,res): - if isinstance(res,Unfinished): + def _test_res(self, res): + if isinstance(res, Unfinished): def succeeds_with(exampe): try: - res(example,just_pass_exception_up=True) + res(example, just_pass_exception_up=True) except: return False else: return True succeeds_with_backup = (self.backup_example is not None) and \ - 
succeeds_with(self.backup_example) + succeeds_with(self.backup_example) if succeeds_with_backup: return succeeds_with_main = succeeds_with(self.sequence_running_example) if succeeds_with_main: - return - example = self.sequence_running_example if None is self.backup_example else self.backup_example - res(example,just_pass_exception_up=True) + return + example = self.sequence_running_example if self.backup_example \ + is None else self.backup_example + res(example, just_pass_exception_up=True) - def evaluateExpr(self,ast,from_top=False): - def format_return(res,resname="out",is_application_of_unfinished=False): + def evaluateExpr(self, ast, from_top=False): + def format_return(res, resname="out", + is_application_of_unfinished=False): ast.evaled_value = res - # run a quick test of the result (by attempting to evaluate it on an example) - # to make sure there hasn't been some weird type problem, so it shouts - # even before someone actively tries to evaluate it + # run a quick test of the result (by attempting to evaluate it on + # an example) to make sure there hasn't been some weird type + # problem, so it shouts even before someone actively tries to + # evaluate it self._test_res(res) if is_application_of_unfinished: @@ -577,14 +660,16 @@ def format_return(res,resname="out",is_application_of_unfinished=False): else: self.env.set_out(res) if from_top: - return NamedVal(resname, res) # this is when an expression has been evaled + # this is when an expression has been evaled + return NamedVal(resname, res) else: return res - if ast.bracketed: # in parentheses - get out of them - return self.evaluateExpr(ast.bracketed,from_top=from_top) - if ast.var: # calling single variable + if ast.bracketed: # in parentheses - get out of them + return self.evaluateExpr(ast.bracketed, from_top=from_top) + if ast.var: # calling single variable varname = ast.var.text - return format_return(self.env.get_variable(varname),resname=varname) + return format_return(self.env.get_variable(varname), + resname=varname) if ast.standalone: return format_return(self._evaluateStandalone(ast.standalone)) if ast.bop: @@ -596,7 +681,7 @@ def format_return(res,resname="out",is_application_of_unfinished=False): if ast.aggregate: return format_return(self._evaluateAggregateExpr(ast.aggregate)) if ast.unfORfun: - + # before evaluating the unfORfun expression, # consider that it may be an unf that would not work # with the current running example, and allow that it may have @@ -605,23 +690,23 @@ def format_return(res,resname="out",is_application_of_unfinished=False): input_vals = self._get_first_cont_list(ast.inputexprs) if len(input_vals) == 1: self.backup_example = self.evaluateExpr(input_vals[0]) - + unfORfun = self.evaluateExpr(ast.unfORfun) - + self.backup_example = prev_backup - - if isinstance(unfORfun,Unfinished): - return format_return(self._evaluateApplication(ast,unfORfun), + + if isinstance(unfORfun, Unfinished): + return format_return(self._evaluateApplication(ast, unfORfun), is_application_of_unfinished=True) - elif isinstance(unfORfun,RASPFunction): - return format_return(self._evaluateRASPFunction(ast,unfORfun)) + elif isinstance(unfORfun, RASPFunction): + return format_return(self._evaluateRASPFunction(ast, unfORfun)) if ast.selop: return format_return(self._evaluateSelectExpr(ast)) if ast.aList(): return format_return(self._evaluateList(ast.aList())) if ast.aDict(): return format_return(self._evaluateDict(ast.aDict())) - if ast.indexable: # indexing into a list, dict, or s-op + if ast.indexable: # indexing into 
a list, dict, or s-op return format_return(self._evaluateIndexing(ast)) if ast.rangevals: return format_return(self._evaluateRange(ast)) @@ -637,15 +722,14 @@ def format_return(res,resname="out",is_application_of_unfinished=False): return format_return(self._evaluateLen(ast)) raise NotImplementedError - - # new ast getText function for expressions -def new_getText(self): # original getText function stored as self._getText - if hasattr(self,"evaled_value") and isatom(self.evaled_value): +def new_getText(self): # original getText function stored as self._getText + if hasattr(self, "evaled_value") and isatom(self.evaled_value): return str(self.evaled_value) else: return self._getText() + RASPParser.ExprContext._getText = RASPParser.ExprContext.getText -RASPParser.ExprContext.getText = new_getText \ No newline at end of file +RASPParser.ExprContext.getText = new_getText diff --git a/RASP_support/FunctionalSupport.py b/RASP_support/FunctionalSupport.py index 245c2af..c287d5e 100644 --- a/RASP_support/FunctionalSupport.py +++ b/RASP_support/FunctionalSupport.py @@ -1,8 +1,9 @@ from Support import aggregate as _aggregate from Support import Sequence, RASPTypeError -from Support import select as _select +from Support import select as _select from Support import zipmap as _zipmap -import traceback, sys # for readable exception handling +import traceback +import sys # for readable exception handling from collections.abc import Iterable from copy import copy @@ -15,32 +16,42 @@ debug = False -# unique ids for all Unfinished objects, numbered by order of creation. ends up very useful sometimes +# unique ids for all Unfinished objects, numbered by order of creation. ends up +# very useful sometimes + + class NextId: def __init__(self): self.i = 0 + def get_next(self): self.i += 1 return self.i + unique_id_maker = NextId() -def creation_order_id(): - return unique_id_maker.get_next() +def creation_order_id(): + return unique_id_maker.get_next() class AlreadyPrintedTheException: def __init__(self): self.b = False + def __bool__(self): return self.b + global_printed = AlreadyPrintedTheException() # various unfinished objects + + class Unfinished: - def __init__(self,parents_tuple,parents2self,name=plain_unfinished_name,is_toplevel_input=False,min_poss_depth=-1): + def __init__(self, parents_tuple, parents2self, name=plain_unfinished_name, + is_toplevel_input=False, min_poss_depth=-1): self.parents_tuple = parents_tuple self.parents2self = parents2self self.last_w = None @@ -53,43 +64,55 @@ def __init__(self,parents_tuple,parents2self,name=plain_unfinished_name,is_tople self._full_parents = None self._sorted_full_parents = None - def setname(self,name,always_display_when_named=True): - if not None is name: - if len(name)>name_maxlen: - if isinstance(self,UnfinishedSequence): + def setname(self, name, always_display_when_named=True): + if name is not None: + if len(name) > name_maxlen: + if isinstance(self, UnfinishedSequence): name = plain_unfinished_sequence_name - elif isinstance(self,UnfinishedSelect): + elif isinstance(self, UnfinishedSelect): name = plain_unfinished_select_name else: name = plain_unfinished_name self.name = name - self.always_display = always_display_when_named # if you set something's name, you probably want to see it - return self # return self to allow chaining with other calls and throwing straight into a return statement etc + # if you set something's name, you probably want to see it + self.always_display = always_display_when_named + # return self to allow chaining with other 
calls and throwing straight + # into a return statement etc + return self def get_parents(self): if None is self._real_parents: - real_parents_part1 = [p for p in self.parents_tuple if is_real_unfinished(p)] - other_parents = [p for p in self.parents_tuple if not is_real_unfinished(p)] + real_parents_part1 = [ + p for p in self.parents_tuple if is_real_unfinished(p)] + other_parents = [ + p for p in self.parents_tuple if not is_real_unfinished(p)] res = real_parents_part1 - for p in other_parents: - res += p.get_parents() # recursion: branch back through all the parents of the unf, - # always stopping wherever hit something 'real' ie a select or a sequence - assert len([p for p in res if isinstance(p,UnfinishedSelect)]) <= 1 # nothing is made from more than one select... + for p in other_parents: + # recursion: branch back through all the parents of the unf, + # always stopping wherever hit something 'real' ie a select or + # a sequence + res += p.get_parents() + # nothing is made from more than one select... + assert len( + [p for p in res if isinstance(p, UnfinishedSelect)]) <= 1 self._real_parents = set(res) - return copy(self._real_parents) # in case someone messes with the list eg popping through it - + # in case someone messes with the list eg popping through it + return copy(self._real_parents) def _flat_compute_full_parents(self): - # TODO: take advantage of anywhere full_parents have already been computed, - # tho, otherwise no point in doing the recursion ever + # TODO: take advantage of anywhere full_parents have already been + # computed, tho, otherwise no point in doing the recursion ever explored = set() not_explored = set([self]) while not_explored: p = not_explored.pop() if p in explored: - continue # this may happen due to also adding things directly to explored sometimes + # this may happen due to also adding things directly to + # explored sometimes + continue if None is not p._full_parents: - explored.update(p._full_parents) # note that _full_parents include self + # note that _full_parents include self + explored.update(p._full_parents) else: new_parents = p.get_parents() explored.add(p) @@ -97,64 +120,71 @@ def _flat_compute_full_parents(self): return explored def _recursive_compute_full_parents(self): - res = self.get_parents() # get_parents returns a copy - res.update([self]) # full parents include self + res = self.get_parents() # get_parents returns a copy + res.update([self]) # full parents include self for p in self.get_parents(): - res.update(p.get_full_parents(recurse=True,trusted=True)) + res.update(p.get_full_parents(recurse=True, trusted=True)) return res def _sort_full_parents(self): if None is self._sorted_full_parents: - self._sorted_full_parents = sorted(self._full_parents,key=lambda unf:unf.creation_order_id) + self._sorted_full_parents = sorted( + self._full_parents, key=lambda unf: unf.creation_order_id) - def get_full_parents(self,recurse=False,just_compute=False,trusted=False): + def get_full_parents(self, recurse=False, just_compute=False, + trusted=False): # Note: full_parents include self if None is self._full_parents: if recurse: self._full_parents = self._recursive_compute_full_parents() else: - self._full_parents = self._flat_compute_full_parents() - # avoids recursion, and so avoids passing the max recursion depth - + self._full_parents = self._flat_compute_full_parents() + # avoids recursion, and so avoids passing the max recursion + # depth - # but now having done that we would like to store the result for all parents - # so we can take 
advantage of it in the future + # but now having done that we would like to store the result + # for all parents so we can take advantage of it in the future for p in self.get_sorted_full_parents(): - p.get_full_parents(recurse=True,just_compute=True) - # have them all compute their full parents so they are ready for the future - # but only do this in sorted order, so recursion is always shallow - # (always gets shorted with self._full_parents, which is being computed here - # for each unfinished starting from the top of the computation graph) + p.get_full_parents(recurse=True, just_compute=True) + # have them all compute their full parents so they are + # ready for the future, but only do this in sorted order, + # so recursion is always shallow. (always gets shorted with + # self._full_parents, which is being computed here for each + # unfinished starting from the top of the computation + # graph) if not just_compute: if trusted: - # functions where you have checked they don't modify the returned result - # can be marked as trusted and get the true _full_parents + # functions where you have checked they don't modify the + # returned result can be marked as trusted and get the true + # _full_parents return self._full_parents else: # otherwise they get a copy return copy(self._full_parents) - def get_sorted_full_parents(self): # could have just made get_full_parents give a sorted result, but - # wanted a function where name is already clear that result will be sorted, - # to avoid weird bugs in future - # (especially that being not sorted will only affect performance, and possibly break recursion depth) - + # wanted a function where name is already clear that result will be + # sorted, to avoid weird bugs in future. (especially that being not + # sorted will only affect performance, and possibly break recursion + # depth) + if None is self._sorted_full_parents: if None is self._full_parents: self.get_full_parents(just_compute=True) self._sort_full_parents() return copy(self._sorted_full_parents) - - def __call__(self,w,print_all_named_sequences=False,print_input=False, - print_all_sequences=False,print_all=False,topcall=True,just_pass_exception_up=False): - if (not isinstance(w,Iterable)) or (not w): - raise RASPTypeError("RASP sequences/selectors expect non-empty iterables, got: "+str(w)) + def __call__(self, w, print_all_named_sequences=False, print_input=False, + print_all_sequences=False, print_all=False, topcall=True, + just_pass_exception_up=False): + if (not isinstance(w, Iterable)) or (not w): + raise RASPTypeError( + "RASP sequences/selectors expect non-empty iterables, got: " + + str(w)) global_printed.b = False if w == self.last_w: - return self.last_res # don't print same calculation multiple times + return self.last_res # don't print same calculation multiple times else: if self.is_toplevel_input: @@ -163,51 +193,59 @@ def __call__(self,w,print_all_named_sequences=False,print_input=False, else: try: if topcall: - # before doing the main call, evaluate all parents - # (in order of dependencies, attainable by using creation_order_id attribute), - # this avoids a deep recursion: - # every element that is evaluated only has to go back as far as its own - # 'real' (i.e., s-op or selector) parents to hit something that has already - # been evaluated, and then those will not recurse further back as they - # use memoization + # before doing the main call, evaluate all parents + # (in order of dependencies, attainable by using + # creation_order_id attribute), this avoids a deep + # 
recursion: every element that is evaluated only has + # to go back as far as its own 'real' (i.e., s-op or + # selector) parents to hit something that has already + # been evaluated, and then those will not recurse + # further back as they use memoization for unf in self.get_sorted_full_parents(): - unf(w,topcall=False,just_pass_exception_up=just_pass_exception_up) # evaluate - - - res = self.parents2self(*tuple(p(w, - print_all_named_sequences=print_all_named_sequences, - print_input=print_input, - print_all_sequences=print_all_sequences, - print_all=print_all, - topcall=False, - just_pass_exception_up=just_pass_exception_up) - for p in self.parents_tuple)) + # evaluate + unf(w, topcall=False, + just_pass_exception_up=just_pass_exception_up) + + p_a_n_s = print_all_named_sequences + j_p_e_u = just_pass_exception_up + args = tuple(p(w, + print_all_named_sequences=p_a_n_s, + print_input=print_input, + print_all_sequences=print_all_sequences, + print_all=print_all, + topcall=False, + just_pass_exception_up=j_p_e_u) + for p in self.parents_tuple) + res = self.parents2self(*args) except Exception as e: if just_pass_exception_up: raise e - if isinstance(e,RASPTypeError): + if isinstance(e, RASPTypeError): raise e if not global_printed.b: - print("===============================================================") - print("===============================================================") - print("evaluation failed in: [",self.name,"] with exception:\n",e) - print("===============================================================") + seperator = "=" * 63 + print(seperator) + print(seperator) + print("evaluation failed in: [", self.name, + "] with exception:\n", e) + print(seperator) print("parent values are:") for p in self.parents_tuple: print("=============") print(p.name) print(p.last_res) - print("===============================================================") - print("===============================================================") - a,b,tb = sys.exc_info() + print(seperator) + print(seperator) + a, b, tb = sys.exc_info() tt = traceback.extract_tb(tb) - last_call = max([i for i,t in enumerate(tt) if "__call__" in str(t)]) + last_call = max([i for i, t in enumerate(tt) + if "__call__" in str(t)]) print(''.join(traceback.format_list(tt[last_call+1:]))) # traceback.print_exception(a,b,tb) global_printed.b = True - + if debug or not topcall: raise else: @@ -216,147 +254,216 @@ def __call__(self,w,print_all_named_sequences=False,print_input=False, self.last_w, self.last_res = w, res def should_print(): - if isinstance(res,Sequence): - if print_all_named_sequences and not (self.name in plain_names): + if isinstance(res, Sequence): + if print_all_named_sequences and self.name not in plain_names: return True if print_all_sequences: return True if self.is_toplevel_input and print_input: return True return print_all - if should_print(): - print("resolved \""+self.name+\ - (("\" from:\" "+str(self.get_own_root_input(w))+" \"") if print_root_inputs_too else ""),\ - ":\n\t",res) + if should_print(): + print("resolved \""+self.name + + (("\" from:\" "+str(self.get_own_root_input(w))+" \"") + if print_root_inputs_too else ""), + ":\n\t", res) return res + class UnfinishedSequence(Unfinished): - def __init__(self,parents_tuple,parents2self,name=plain_unfinished_sequence_name, - elementwise_function=None,default=None,min_poss_depth=0,from_zipmap=False, - output_index=-1,definitely_uses_identity_function=False): - # min_poss_depth=0 starts all of the base sequences (eg indices) off right - if None is name: # 
might have got none from some default value, fix it before continuing because later things eg DrawCompFlow - name = plain_unfinished_sequence_name # will expect name to be a string - super(UnfinishedSequence, self).__init__(parents_tuple,parents2self,name=name,min_poss_depth=min_poss_depth) - self.from_zipmap = from_zipmap # can be inferred (by seeing if there are parent selects), but this is simple enough. helpful for rendering comp flow visualisations - self.elementwise_function = elementwise_function # useful for analysis later + def __init__(self, parents_tuple, parents2self, + name=plain_unfinished_sequence_name, + elementwise_function=None, default=None, min_poss_depth=0, + from_zipmap=False, output_index=-1, + definitely_uses_identity_function=False): + # min_poss_depth=0 starts all of the base sequences (eg indices) off + # right. + + # might have got none from some default value, fix it before continuing + # because later things eg DrawCompFlow will expect name to be str + if name is None: + name = plain_unfinished_sequence_name + super(UnfinishedSequence, self).__init__(parents_tuple, + parents2self, name=name, + min_poss_depth=min_poss_depth) + # can be inferred (by seeing if there are parent selects), but this is + # simple enough. helpful for rendering comp flow visualisations + self.from_zipmap = from_zipmap + # useful for analysis later + self.elementwise_function = elementwise_function self.output_index = output_index - self.default = default # useful for analysis later - self.definitely_uses_identity_function = definitely_uses_identity_function + # useful for analysis later + self.default = default + self.definitely_uses_identity_function = \ + definitely_uses_identity_function self.never_display = False self._constant = False def __str__(self): - return "UnfinishedSequence object, name: "+self.name+" id: "+str(self.creation_order_id) + id = str(self.creation_order_id) + return "UnfinishedSequence object, name: " + self.name + " id: " + id + def mark_as_constant(self): self._constant = True return self + def is_constant(self): return self._constant class UnfinishedSelect(Unfinished): - def __init__(self,parents_tuple,parents2self, - name=plain_unfinished_select_name,compare_string=None,min_poss_depth=-1, - q_vars=None,k_vars=None,orig_selector=None): # selects should be told their depth, -1 will warn of problems properly - if None is name: # as in unfinishedsequence, some other function might have passed in a None somewhere - name = plain_unfinished_select_name # so fix before a print goes wrong - super(UnfinishedSelect, self).__init__(parents_tuple,parents2self,name=name,min_poss_depth=min_poss_depth) - self.compare_string = str(self.creation_order_id) if None is compare_string else compare_string - assert not None in [q_vars,k_vars] # they're not really optional i just dont want to add more mess to the func - self.q_vars = q_vars # don't actually need them, but useful for - self.k_vars = k_vars # drawing comp flow - # use compare string for comparison/uniqueness rather than overloading __eq__ of unfinishedselect, - # to avoid breaking things in unknown locations, and to be able to put selects in dictionaries - # and stuff (overloading __eq__ makes an object unhasheable unless i guess you overload the - # hash too?). 
need these comparisons for optimisations in analysis eg if two selects are identical - # they can be same head - self.orig_selector = orig_selector # for comfortable compositions of selectors + def __init__(self, parents_tuple, parents2self, + name=plain_unfinished_select_name, compare_string=None, + min_poss_depth=-1, q_vars=None, k_vars=None, + orig_selector=None): # selects should be told their depth, + # -1 will warn of problems properly + if name is None: # as in unfinishedsequence, some other function might + # have passed in a None somewhere + name = plain_unfinished_select_name # so fix before a print goes + # wrong + super(UnfinishedSelect, self).__init__(parents_tuple, + parents2self, name=name, + min_poss_depth=min_poss_depth) + self.compare_string = str( + self.creation_order_id) if compare_string is None \ + else compare_string + # they're not really optional i just dont want to add more mess to the + # func + assert None not in [q_vars, k_vars] + self.q_vars = q_vars # don't actually need them, but useful for + self.k_vars = k_vars # drawing comp flow + # use compare string for comparison/uniqueness rather than overloading + # __eq__ of unfinishedselect, to avoid breaking things in unknown + # locations, and to be able to put selects in dictionaries and stuff + # (overloading __eq__ makes an object unhasheable unless i guess you + # overload the hash too?). need these comparisons for optimisations in + # analysis eg if two selects are identical they can be same head + self.orig_selector = orig_selector # for comfortable compositions of + # selectors + def __str__(self): - return "UnfinishedSelect object, name: "+self.name+" id: "+str(self.creation_order_id) + id = str(self.creation_order_id) + return "UnfinishedSelect object, name: " + self.name + " id: " + id -def is_real_unfinished(unf): # as opposed to intermediate unfinisheds like tuples of sequences - return isinstance(unf,UnfinishedSequence) or isinstance(unf,UnfinishedSelect) +# as opposed to intermediate unfinisheds like tuples of sequences +def is_real_unfinished(unf): + return isinstance(unf, UnfinishedSequence) \ + or isinstance(unf, UnfinishedSelect) # some tiny bit of sugar that fits here: + + def is_sequence_of_unfinishedseqs(seqs): - if not isinstance(seqs,Iterable): + if not isinstance(seqs, Iterable): return False - return False not in [isinstance(seq,UnfinishedSequence) for seq in seqs] + return False not in [isinstance(seq, UnfinishedSequence) for seq in seqs] + class BareBonesFunctionalSupportException(Exception): - def __init__(self,m): - Exception.__init__(self,m) + def __init__(self, m): + Exception.__init__(self, m) + def to_tuple_of_unfinishedseqs(seqs): if is_sequence_of_unfinishedseqs(seqs): return tuple(seqs) - if isinstance(seqs,UnfinishedSequence): + if isinstance(seqs, UnfinishedSequence): return (seqs,) - print("seqs:",seqs) + print("seqs:", seqs) raise BareBonesFunctionalSupportException( - "input to select/aggregate not an unfinished sequence or sequence of unfinished sequences") + "input to select/aggregate not an unfinished sequence or sequence of" + + " unfinished sequences") + + +def tup2tup(*x): + return tuple([*x]) + -tup2tup = lambda *x:tuple([*x]) class UnfinishedSequencesTuple(Unfinished): - def __init__(self,parents_tuple,parents2self=None): - # sequence tuples only exist in here, user doesn't 'see' them. 
can have lots of default values - # they're just a convenience for me - if None is parents2self: # just sticking a bunch of unfinished sequences together into one thing for reasons + def __init__(self, parents_tuple, parents2self=None): + # sequence tuples only exist in here, user doesn't 'see' them. can have + # lots of default values they're just a convenience for me + if parents2self is None: # just sticking a bunch of unfinished + # sequences together into one thing for reasons parents2self = tup2tup parents_tuple = to_tuple_of_unfinishedseqs(parents_tuple) - assert is_sequence_of_unfinishedseqs(parents_tuple) and isinstance(parents_tuple,tuple) + assert is_sequence_of_unfinishedseqs( + parents_tuple) and isinstance(parents_tuple, tuple) # else - probably creating several sequences at once from one aggregate - super(UnfinishedSequencesTuple, self).__init__(parents_tuple,parents2self,name="plain unfinished tuple") - def __add__(self,other): - assert isinstance(other,UnfinishedSequencesTuple) + super(UnfinishedSequencesTuple, self).__init__( + parents_tuple, parents2self, name="plain unfinished tuple") + + def __add__(self, other): + assert isinstance(other, UnfinishedSequencesTuple) assert self.parents2self is tup2tup assert other.parents2self is tup2tup return UnfinishedSequencesTuple(self.parents_tuple+other.parents_tuple) -_input = Unfinished((),None,is_toplevel_input=True) -#### and now, the actual exposed functions -indices = UnfinishedSequence((_input,),lambda w:Sequence(list(range(len(w)))),name=plain_indices) -tokens_str = UnfinishedSequence((_input,),lambda w:Sequence(list(map(str,w))),name=plain_tokens+"_str") -tokens_int = UnfinishedSequence((_input,),lambda w:Sequence(list(map(int,w))),name=plain_tokens+"_int") -tokens_float = UnfinishedSequence((_input,),lambda w:Sequence(list(map(float,w))),name=plain_tokens+"_float") -tokens_bool = UnfinishedSequence((_input,),lambda w:Sequence(list(map(bool,w))),name=plain_tokens+"_bool") -tokens_asis = UnfinishedSequence((_input,),lambda w:Sequence(w),name=plain_tokens+"_asis") -base_tokens = [tokens_str,tokens_int,tokens_float,tokens_bool,tokens_asis] +_input = Unfinished((), None, is_toplevel_input=True) +# and now, the actual exposed functions +indices = UnfinishedSequence((_input,), lambda w: Sequence( + list(range(len(w)))), name=plain_indices) +tokens_str = UnfinishedSequence((_input,), lambda w: Sequence( + list(map(str, w))), name=plain_tokens+"_str") +tokens_int = UnfinishedSequence((_input,), lambda w: Sequence( + list(map(int, w))), name=plain_tokens+"_int") +tokens_float = UnfinishedSequence((_input,), lambda w: Sequence( + list(map(float, w))), name=plain_tokens+"_float") +tokens_bool = UnfinishedSequence((_input,), lambda w: Sequence( + list(map(bool, w))), name=plain_tokens+"_bool") +tokens_asis = UnfinishedSequence( + (_input,), lambda w: Sequence(w), name=plain_tokens+"_asis") +base_tokens = [tokens_str, tokens_int, tokens_float, tokens_bool, tokens_asis] + def _min_poss_depth(unfs): - if isinstance(unfs,Unfinished): # got single unfinished and not iterable of them + if isinstance(unfs, Unfinished): # got single unfinished and not iterable + # of them unfs = [unfs] - return max([u.min_poss_depth for u in unfs]+[0]) # max b/c cant go less deep than deepest + # max b/c cant go less deep than deepest + return max([u.min_poss_depth for u in unfs]+[0]) # add that 0 thing so list is never empty and max complains. 
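For orientation (not part of the patch itself): the base sequences defined just above are themselves callable Unfinisheds, so they can be evaluated directly on an input. A minimal sketch, assuming RASP_support is importable and that Support.Sequence wraps a plain list, as the lambdas above suggest:

    from FunctionalSupport import indices, tokens_str

    indices("hey")     # Sequence over [0, 1, 2]
    tokens_str("hey")  # Sequence over ['h', 'e', 'y']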
+ def tupleise(v): - if isinstance(v,tuple) or isinstance(v,list): + if isinstance(v, tuple) or isinstance(v, list): return tuple(v) return (v,) -def select(q_vars,k_vars,selector,name=None,compare_string=None): + +def select(q_vars, k_vars, selector, name=None, compare_string=None): if None is name: name = "plain select" - # potentially here check the qvars all reference the same input sequence as each other and same for the kvars, - # technically dont *have* to but is helpful for the user so consider maybe adding a tiny bit of mess here - # (including markings inside sequences and selectors so they know which index they're gathering to and from) - # to allow it - q_vars = tupleise(q_vars) # we're ok with getting a single q or k var, not in a tuple, - k_vars = tupleise(k_vars) # but important to fix it before '+' on two UnfinishedSequences + # potentially here check the qvars all reference the same input sequence as + # each other and same for the kvars, technically dont *have* to but is + # helpful for the user so consider maybe adding a tiny bit of mess here + # (including markings inside sequences and selectors so they know which + # index they're gathering to and from) to allow it + + # we're ok with getting a single q or k var, not in a tuple, + # but important to fix it before '+' on two UnfinishedSequences # (as opposed to two tuples) sends everything sideways - new_depth = _min_poss_depth(q_vars+k_vars)+1 # attn layer is one after values it needs to be calculated - res = UnfinishedSelect((_input, # need input seq length to create select of correct size - UnfinishedSequencesTuple(q_vars), - UnfinishedSequencesTuple(k_vars) ), - lambda input_seq,qv,kv: _select(len(input_seq),qv,kv,selector), - name=name,compare_string=compare_string,min_poss_depth=new_depth, - q_vars=q_vars,k_vars=k_vars,orig_selector=selector) + q_vars = tupleise(q_vars) + k_vars = tupleise(k_vars) + + # attn layer is one after values it needs to be calculated + new_depth = _min_poss_depth(q_vars+k_vars)+1 + res = UnfinishedSelect((_input, # need input seq length to create select + # of correct size + UnfinishedSequencesTuple(q_vars), + UnfinishedSequencesTuple(k_vars)), + lambda input_seq, qv, kv: _select( + len(input_seq), qv, kv, selector), + name=name, compare_string=compare_string, + min_poss_depth=new_depth, q_vars=q_vars, + k_vars=k_vars, orig_selector=selector) return res -def _compose_selects(select1,select2,compose_op=None,name=None,compare_string=None): + +def _compose_selects(select1, select2, compose_op=None, name=None, + compare_string=None): nq1 = len(select1.q_vars) nq2 = len(select2.q_vars)+nq1 nk1 = len(select1.k_vars)+nq2 @@ -366,121 +473,155 @@ def new_selector(*qqkk): q2 = qqkk[nq1:nq2] k1 = qqkk[nq2:nk1] k2 = qqkk[nk1:] - return compose_op(select1.orig_selector(*q1,*k1), select2.orig_selector(*q2,*k2)) + return compose_op(select1.orig_selector(*q1, *k1), + select2.orig_selector(*q2, *k2)) return select(select1.q_vars+select2.q_vars, select1.k_vars+select2.k_vars, - new_selector,name=name,compare_string=compare_string) + new_selector, name=name, compare_string=compare_string) + -def _compose_select(select1,compose_op=None,name=None,compare_string=None): +def _compose_select(select1, compose_op=None, name=None, compare_string=None): def new_selector(*qk): return compose_op(select1.orig_selector(*qk)) return select(select1.q_vars, select1.k_vars, - new_selector,name=name,compare_string=compare_string) + new_selector, name=name, compare_string=compare_string) + + +def not_select(select, name=None, 
compare_string=None): + return _compose_select(select, lambda a: not a, + name=name, compare_string=compare_string) + -def not_select(select,name=None,compare_string=None): - return _compose_select(select,lambda a:not a, - name=name,compare_string=compare_string) +def and_selects(select1, select2, name=None, compare_string=None): + return _compose_selects(select1, select2, lambda a, b: a and b, + name=name, compare_string=compare_string) -def and_selects(select1,select2,name=None,compare_string=None): - return _compose_selects(select1,select2,lambda a,b:a and b, - name=name,compare_string=compare_string) -def or_selects(select1,select2,name=None,compare_string=None): - return _compose_selects(select1,select2,lambda a,b:a or b, - name=name,compare_string=compare_string) +def or_selects(select1, select2, name=None, compare_string=None): + return _compose_selects(select1, select2, lambda a, b: a or b, + name=name, compare_string=compare_string) -def format_output(parents_tuple,parents2res,name,elementwise_function=None, - default=None,min_poss_depth=0,from_zipmap=False, - definitely_uses_identity_function=False): - return UnfinishedSequence(parents_tuple,parents2res, - elementwise_function=elementwise_function,default=default, - name=name,min_poss_depth=min_poss_depth,from_zipmap=from_zipmap, - definitely_uses_identity_function=definitely_uses_identity_function) +def format_output(parents_tuple, parents2res, name, elementwise_function=None, + default=None, min_poss_depth=0, from_zipmap=False, + definitely_uses_identity_function=False): + def_uses = definitely_uses_identity_function + return UnfinishedSequence(parents_tuple, parents2res, + elementwise_function=elementwise_function, + default=default, name=name, + min_poss_depth=min_poss_depth, + from_zipmap=from_zipmap, + definitely_uses_identity_function=def_uses) def get_identity_function(num_params): def identity1(a): return a + def identityx(*a): return a - return identity1 if num_params==1 else identityx + return identity1 if num_params == 1 else identityx -def zipmap(sequences_tuple,elementwise_function,name=plain_unfinished_sequence_name): +def zipmap(sequences_tuple, elementwise_function, + name=plain_unfinished_sequence_name): sequences_tuple = tupleise(sequences_tuple) - unfinished_parents_tuple = UnfinishedSequencesTuple(sequences_tuple) # this also takes care of turning the - # value in sequences_tuple to indeed a tuple of sequences and not eg a single sequence which will - # cause weird behaviour later - - parents_tuple = (_input,unfinished_parents_tuple) - parents2res = lambda w,vt: _zipmap(len(w),vt,elementwise_function) - min_poss_depth = _min_poss_depth(sequences_tuple) # feedforward doesn't increase layer - # new assumption, to be revised later: can do arbitrary zipmap even before first feed-forward, - # i.e. in build up to first attention. truth is can do 'simple' zipmap towards first attention - # (no xor, but yes things like 'and' or 'indicator for ==' or whatever) based on initial linear - # translation done for Q,K in attention (not deep enough for xor, but deep enough for simple stuff) - # alongside use of initial embedding. 
honestly literally can just put everything in initial embedding - # if need it so bad its the first layer and its zipmap its only a function of the token and indices, - # so long as its not computing any weird combination between them you can do it in the embedding - # if len(sequences_tuple)>0: - # min_poss_depth = max(min_poss_depth,1) # except for the very specific case where - # # it is the very first thing to be done, in which case we do have to go through - # # one layer to get to the first feedforward. - # # the 'if' is there to rule out increasing when doing a feedforward on nothing, - # # ie, when making a constant. constants are allowed to be created on layer 0, they're - # # part of the embedding or the weights that will use them later or whatever, it's fine - return format_output(parents_tuple,parents2res,name, - min_poss_depth=min_poss_depth,elementwise_function=elementwise_function, - from_zipmap=True) # at least as deep as needed MVs, but no - # deeper cause FF (which happens at end of layer) - -def aggregate(select,sequences_tuple,elementwise_function=None, - default=None,name=plain_unfinished_sequence_name): + unfinished_parents_tuple = UnfinishedSequencesTuple( + sequences_tuple) # this also takes care of turning the + # value in sequences_tuple to indeed a tuple of sequences and not eg a + # single sequence which will cause weird behaviour later + + parents_tuple = (_input, unfinished_parents_tuple) + def parents2res(w, vt): return _zipmap(len(w), vt, elementwise_function) + # feedforward doesn't increase layer + min_poss_depth = _min_poss_depth(sequences_tuple) + # new assumption, to be revised later: can do arbitrary zipmap even before + # first feed-forward, i.e. in build up to first attention. truth is can do + # 'simple' zipmap towards first attention (no xor, but yes things like + # 'and' or 'indicator for ==' or whatever) based on initial linear + # translation done for Q,K in attention (not deep enough for xor, but deep + # enough for simple stuff) alongside use of initial embedding. honestly + # literally can just put everything in initial embedding if need it so bad + # its the first layer and its zipmap its only a function of the token and + # indices, so long as its not computing any weird combination between them + # you can do it in the embedding + # if len(sequences_tuple)>0: + # min_poss_depth = max(min_poss_depth,1) # except for the very specific + # # case where it is the very first thing to be done, in which case we do + # # have to go through one layer to get to the first feedforward. + # # the 'if' is there to rule out increasing when doing a feedforward on + # # nothing, ie, when making a constant. 
constants are allowed to be + # # created on layer 0, they're part of the embedding or the weights that + # # will use them later or whatever, it's fine + + # at least as deep as needed MVs, but no deeper cause FF + # (which happens at end of layer) + return format_output(parents_tuple, parents2res, name, + min_poss_depth=min_poss_depth, + elementwise_function=elementwise_function, + from_zipmap=True) + + +def aggregate(select, sequences_tuple, elementwise_function=None, + default=None, name=plain_unfinished_sequence_name): sequences_tuple = tupleise(sequences_tuple) - definitely_uses_identity_function = None is elementwise_function + definitely_uses_identity_function = None is elementwise_function if definitely_uses_identity_function: elementwise_function = get_identity_function(len(sequences_tuple)) unfinished_parents_tuple = UnfinishedSequencesTuple(sequences_tuple) - parents_tuple = (select,unfinished_parents_tuple) - parents2res = lambda s,vt:_aggregate(s,vt,elementwise_function,default=default) - return format_output(parents_tuple,parents2res,name, - elementwise_function=elementwise_function,default=default, - min_poss_depth=max(_min_poss_depth(sequences_tuple)+1,select.min_poss_depth), - definitely_uses_identity_function=definitely_uses_identity_function) - # at least as deep as needed attention and at least one deeper than needed MVs + parents_tuple = (select, unfinished_parents_tuple) + def parents2res(s, vt): return _aggregate( + s, vt, elementwise_function, default=default) + def_uses = definitely_uses_identity_function + + # at least as deep as needed attention and at least one deeper than needed + # MVs + return format_output(parents_tuple, parents2res, name, + elementwise_function=elementwise_function, + default=default, + min_poss_depth=max(_min_poss_depth( + sequences_tuple)+1, select.min_poss_depth), + definitely_uses_identity_function=def_uses) + + +# up to here was just plain transformer 'assembly'. any addition is a lie +# now begin the bells and whistles -########### up to here was just plain transformer 'assembly'. any addition is a lie ############## -##################### now begin the bells and whistles ########################################### def UnfinishedSequenceFunc(f): - setattr(UnfinishedSequence,f.__name__,f) + setattr(UnfinishedSequence, f.__name__, f) -def UnfinishedFunc(f): - setattr(Unfinished,f.__name__,f) +def UnfinishedFunc(f): + setattr(Unfinished, f.__name__, f) @UnfinishedSequenceFunc def allow_suppressing_display(self): self.always_display = False - return self # return self to allow chaining with other calls and throwing straight into a return statement etc + return self # return self to allow chaining with other calls and throwing + # straight into a return statement etc -# later, we will overload == for unfinished sequences, such that it always returns another -# unfinished sequence. unfortunately this creates the following upsetting behaviour: -# "a in l" and "a==b" always evaluates to true for any unfinishedsequences a,b and non-empty -# list l, and any item a and list l containing at least one unfinished sequence. hence, to -# check if a sequence is really in a list we have to do it ourselves, some other way. +# later, we will overload == for unfinished sequences, such that it always +# returns another unfinished sequence. 
unfortunately this creates the following +# upsetting behaviour: +# "a in l" and "a==b" always evaluates to true for any unfinishedsequences a,b +# and non-empty list l, and any item a and list l containing at least one +# unfinished sequence. hence, to check if a sequence is really in a list we +# have to do it ourselves, some other way. -def guarded_compare(seq1,seq2): - if isinstance(seq1,UnfinishedSequence) or isinstance(seq2,UnfinishedSequence): + +def guarded_compare(seq1, seq2): + if isinstance(seq1, UnfinishedSequence) \ + or isinstance(seq2, UnfinishedSequence): return seq1 is seq2 return seq1 == seq2 -def guarded_contains(l,a): - if isinstance(a,Unfinished): - return True in [(a is e) for e in l] + +def guarded_contains(ll, a): + if isinstance(a, Unfinished): + return True in [(a is e) for e in ll] else: - l = [e for e in l if not isinstance(e,Unfinished)] - return a in l + ll = [e for e in ll if not isinstance(e, Unfinished)] + return a in ll diff --git a/RASP_support/REPL.py b/RASP_support/REPL.py index ff493e4..9bff3a7 100644 --- a/RASP_support/REPL.py +++ b/RASP_support/REPL.py @@ -1,56 +1,65 @@ +from antlr4.error.ErrorListener import ErrorListener from antlr4 import CommonTokenStream, InputStream from collections.abc import Iterable - from zzantlr.RASPLexer import RASPLexer from zzantlr.RASPParser import RASPParser -from zzantlr.RASPVisitor import RASPVisitor - from Environment import Environment, UndefinedVariable, ReservedName from FunctionalSupport import UnfinishedSequence, UnfinishedSelect, Unfinished from Evaluator import Evaluator, NamedVal, NamedValList, JustVal, \ - RASPFunction, ArgsError, RASPTypeError, RASPValueError + RASPFunction, ArgsError, RASPTypeError, RASPValueError from Support import Select, Sequence, lazy_type_check -encoder_name = "s-op" +ENCODER_NAME = "s-op" + class ResultToPrint: - def __init__(self,res,to_print): + def __init__(self, res, to_print): self.res, self.print = res, to_print + class LazyPrint: - def __init__(self,*a,**kw): + def __init__(self, *a, **kw): self.a, self.kw = a, kw + def print(self): - print(*self.a,**self.kw) + print(*self.a, **self.kw) + class StopException(Exception): def __init__(self): super().__init__() -debug = False -def debprint(*a,**kw): - if debug: - print(*a,**kw) +DEBUG = False + + +def debprint(*a, **kw): + if DEBUG: + print(*a, **kw) + class ReturnExample: - def __init__(self,subset): + def __init__(self, subset): self.subset = subset + class LoadError(Exception): - def __init__(self,msg): + def __init__(self, msg): super().__init__(msg) + def is_comment(line): - if not isinstance(line,str): + if not isinstance(line, str): return False return line.strip().startswith("#") + def formatstr(res): - if isinstance(res,str): + if isinstance(res, str): return "\""+res+"\"" return str(res) + class REPL: def __init__(self): self.env = Environment(name="console") @@ -65,307 +74,339 @@ def __init__(self): def load_base_libraries_and_make_base_env(self): self.silent = True - self.base_env = self.env.snapshot() # base env: the env from which every load begins - # bootstrap base_env with current (basically empty except indices etc) env, then load - # the base libraries to build the actual base env - self.env.storing_in_constants = True # make the library-loaded variables and functions not-overwriteable - for l in ["RASP_support/rasplib"]: - self.run_given_line("load \""+l+"\";") + # base env: the env from which every load begins + self.base_env = self.env.snapshot() + # bootstrap base_env with current (basically empty except 
indices etc) + # env, then load the base libraries to build the actual base env + # make the library-loaded variables and functions not-overwriteable + self.env.storing_in_constants = True + for lib in ["RASP_support/rasplib"]: + self.run_given_line("load \"" + lib + "\";") self.base_env = self.env.snapshot() self.env.storing_in_constants = False self.run_given_line("tokens=tokens_str;") self.base_env = self.env.snapshot() self.silent = False - - def set_running_example(self,example,which="both"): - if which in ["both",encoder_name]: + def set_running_example(self, example, which="both"): + if which in ["both", ENCODER_NAME]: self.sequence_running_example = example - if which in ["both","selector"]: + if which in ["both", "selector"]: self.selector_running_example = example def print_welcome(self): print("RASP 0.0") - print("running example is:",self.sequence_running_example) + print("running example is:", self.sequence_running_example) - def print_just_val(self,justval): + def print_just_val(self, justval): val = justval.val if None is val: return - if isinstance(val,Select): + if isinstance(val, Select): print("\t = ") - print_select(val.created_from_input,val) - elif isinstance(val,Sequence) and self.sequence_prints_verbose: - print("\t = ",end="") - print_seq(val.created_from_input,val,still_on_prev_line=True) + print_select(val.created_from_input, val) + elif isinstance(val, Sequence) and self.sequence_prints_verbose: + print("\t = ", end="") + print_seq(val.created_from_input, val, still_on_prev_line=True) else: - print("\t = ",str(val).replace("\n","\n\t\t\t")) + print("\t = ", str(val).replace("\n", "\n\t\t\t")) - def print_named_val(self,name,val,ntabs=0,extra_first_pref=""): - pref="\t"*ntabs - if (None is name) and isinstance(val,Unfinished): + def print_named_val(self, name, val, ntabs=0, extra_first_pref=""): + pref = "\t"*ntabs + if (None is name) and isinstance(val, Unfinished): name = val.name - if isinstance(val,UnfinishedSequence): - print(pref,extra_first_pref," "+encoder_name+":",name) + if isinstance(val, UnfinishedSequence): + print(pref, extra_first_pref, " "+ENCODER_NAME+":", name) if self.show_sequence_examples: if self.sequence_prints_verbose: - print(pref,"\t Example:",end="") - optional_exampledesc = name+"("+formatstr(self.sequence_running_example)+") =" - print_seq(self.selector_running_example,val(self.sequence_running_example),still_on_prev_line=True, - extra_pref=pref,lastpref_if_shortprint=optional_exampledesc) + print(pref, "\t Example:", end="") + optional_exampledesc = name + \ + "("+formatstr(self.sequence_running_example)+") =" + print_seq(self.selector_running_example, + val(self.sequence_running_example), + still_on_prev_line=True, + extra_pref=pref, + lastpref_if_shortprint=optional_exampledesc) else: - print(pref,"\t Example:",name+"("+formatstr(self.sequence_running_example)+\ - ") =",val(self.sequence_running_example)) - elif isinstance(val,UnfinishedSelect): - print(pref,extra_first_pref," selector:",name) + print(pref, "\t Example:", name + "(" + + formatstr(self.sequence_running_example) + ") =", + val(self.sequence_running_example)) + elif isinstance(val, UnfinishedSelect): + print(pref, extra_first_pref, " selector:", name) if self.show_selector_examples: - print(pref,"\t Example:")#,name+"("+formatstr(self.selector_running_example)+") =") - print_select(self.selector_running_example,val(self.selector_running_example),extra_pref=pref) - elif isinstance(val,RASPFunction): - print(pref,extra_first_pref," "+str(val)) - elif 
isinstance(val,list): - named = " list: "+((name+" = ") if not None is name else "") - print(pref,extra_first_pref,named,end="") - flat = True not in [isinstance(v,list) or isinstance(v,dict) or isinstance(v,Unfinished) for v in val] + print(pref, "\t Example:") + print_select(self.selector_running_example, val( + self.selector_running_example), extra_pref=pref) + elif isinstance(val, RASPFunction): + print(pref, extra_first_pref, " "+str(val)) + elif isinstance(val, list): + named = " list: "+((name+" = ") if name is not None else "") + print(pref, extra_first_pref, named, end="") + flat = True not in [isinstance(v, list) or isinstance( + v, dict) or isinstance(v, Unfinished) for v in val] if flat: print(val) else: - print(pref,"[") + print(pref, "[") for v in val: - self.print_named_val(None,v,ntabs=ntabs+2) - print(pref," "*len(named),"]") - elif isinstance(val,dict): - named = " dict: "+((name+" = ") if not None is name else "") - print(pref,extra_first_pref,named,end="") - flat = True not in [isinstance(val[v],list) or isinstance(val[v],dict) or isinstance(val[v],Unfinished) for v in val] + self.print_named_val(None, v, ntabs=ntabs+2) + print(pref, " "*len(named), "]") + elif isinstance(val, dict): + named = " dict: "+((name+" = ") if name is not None else "") + print(pref, extra_first_pref, named, end="") + flat = True not in [isinstance(val[v], list) or isinstance( + val[v], dict) or isinstance(val[v], Unfinished) for v in val] if flat: print(val) else: - print(pref,"{") + print(pref, "{") for v in val: - self.print_named_val(None,val[v],ntabs=ntabs+3,extra_first_pref=formatstr(v)+" : ") - print(pref," "*len(named),"}") + self.print_named_val(None, val[v], ntabs=ntabs + 3, + extra_first_pref=formatstr(v) + " : ") + print(pref, " "*len(named), "}") else: - print(pref," value:",((name+" = ") if not None is name else ""),formatstr(val)) - - def print_example(self,nres): - if nres.subset in ["both",encoder_name]: - print("\t"+encoder_name+" example:",formatstr(self.sequence_running_example)) - if nres.subset in ["both","selector"]: - print("\tselector example:",formatstr(self.selector_running_example)) - - def print_result(self,rp): + print(pref, " value:", ((name+" = ") + if name is not None else ""), formatstr(val)) + + def print_example(self, nres): + if nres.subset in ["both", ENCODER_NAME]: + print("\t"+ENCODER_NAME+" example:", + formatstr(self.sequence_running_example)) + if nres.subset in ["both", "selector"]: + print("\tselector example:", formatstr( + self.selector_running_example)) + + def print_result(self, rp): if self.silent: return - if isinstance(rp,LazyPrint): + if isinstance(rp, LazyPrint): return rp.print() - if isinstance(rp,list): # a list of multiple ResultToPrint s -- probably the result of a multi-assignment + # a list of multiple ResultToPrint s -- probably the result of a + # multi-assignment + if isinstance(rp, list): for v in rp: - self.print_result(v) + self.print_result(v) return if not rp.print: return res = rp.res - if isinstance(res,NamedVal): - self.print_named_val(res.name,res.val) - elif isinstance(res,ReturnExample): + if isinstance(res, NamedVal): + self.print_named_val(res.name, res.val) + elif isinstance(res, ReturnExample): self.print_example(res) - elif isinstance(res,JustVal): + elif isinstance(res, JustVal): self.print_just_val(res) - def evaluate_replstatement(self,ast): + def evaluate_replstatement(self, ast): if ast.setExample(): return ResultToPrint(self.setExample(ast.setExample()), False) if ast.showExample(): return 
ResultToPrint(self.showExample(ast.showExample()), True) if ast.toggleExample(): - return ResultToPrint(self.toggleExample(ast.toggleExample()), False) + return ResultToPrint(self.toggleExample(ast.toggleExample()), + False) if ast.toggleSeqVerbose(): - return ResultToPrint(self.toggleSeqVerbose(ast.toggleSeqVerbose()), False) + return ResultToPrint(self.toggleSeqVerbose(ast.toggleSeqVerbose()), + False) if ast.exit(): raise StopException() - def toggleSeqVerbose(self,ast): + def toggleSeqVerbose(self, ast): switch = ast.switch.text self.sequence_prints_verbose = switch == "on" - def toggleExample(self,ast): + def toggleExample(self, ast): subset = ast.subset subset = "both" if not subset else subset.text switch = ast.switch.text - examples_on = switch=="on" - if subset in ["both",encoder_name]: + examples_on = switch == "on" + if subset in ["both", ENCODER_NAME]: self.show_sequence_examples = examples_on - if subset in ["both","selector"]: + if subset in ["both", "selector"]: self.show_selector_examples = examples_on - def showExample(self,ast): + def showExample(self, ast): subset = ast.subset subset = "both" if not subset else subset.text return ReturnExample(subset) - def setExample(self,ast): - example = Evaluator(self.env,self).evaluateExpr(ast.example) - if not isinstance(example,Iterable): + def setExample(self, ast): + example = Evaluator(self.env, self).evaluateExpr(ast.example) + if not isinstance(example, Iterable): raise RASPTypeError("example not iterable: "+str(example)) subset = ast.subset subset = "both" if not subset else subset.text - self.set_running_example(example,subset) + self.set_running_example(example, subset) return ReturnExample(subset) - def loadFile(self,ast,calling_env=None): + def loadFile(self, ast, calling_env=None): if None is calling_env: calling_env = self.env libname = ast.filename.text[1:-1] filename = libname + ".rasp" try: - with open(filename,"r") as f: - prev_example_settings = self.show_sequence_examples, self.show_selector_examples - self.show_sequence_examples, self.show_selector_examples = False, False - self.run(fromfile=f,env = Environment(name=libname,parent_env=self.base_env,stealing_env=calling_env),store_prints=True) - self.filter_and_dump_prints() - self.show_sequence_examples, self.show_selector_examples = prev_example_settings + with open(filename, "r") as f: + prev_example_settings = self.show_sequence_examples, \ + self.show_selector_examples + self.show_sequence_examples = False + self.show_selector_examples = False + self.run(fromfile=f, + env=Environment(name=libname, + parent_env=self.base_env, + stealing_env=calling_env), + store_prints=True) + self.filter_and_dump_prints() + self.show_sequence_examples, self.show_selector_examples = \ + prev_example_settings except FileNotFoundError: raise LoadError("could not find file: "+filename) - def get_tree(self,fromfile=None): + def get_tree(self, fromfile=None): try: return LineReader(fromfile=fromfile).get_input_tree() except AntlrException as e: - print("\t!! antlr exception:",e.msg,"\t-- ignoring input") + print("\t!! antlr exception:", e.msg, "\t-- ignoring input") return None - def run_given_line(self,line): + def run_given_line(self, line): try: tree = LineReader(given_line=line).get_input_tree() - if isinstance(tree,Stop): + if isinstance(tree, Stop): return None rp = self.evaluate_tree(tree) - if isinstance(rp,LazyPrint): - rp.print() # error messages get raised, but ultimately have to be printed somewhere if not caught? 
idk + if isinstance(rp, LazyPrint): + # error messages get raised, but ultimately have to be printed + # somewhere if not caught? idk + rp.print() except AntlrException as e: - print("\t!! REPL failed to run initiating line:",line) - print("\t --got antlr exception:",e.msg) + print("\t!! REPL failed to run initiating line:", line) + print("\t --got antlr exception:", e.msg) return None - def assigned_to_top(self,res,env): + def assigned_to_top(self, res, env): if env is self.env: return True - # we are now definitely inside some file, the question is whether we have taken - # the result and kept it in the top level too, i.e., whether we have imported a non-private value. - # checking whether it is also in self.env, even identical, will not tell us much as it may have been here and the same - # already. so we have to replicate the logic here. - if not isinstance(res,NamedVal): - return False # only namedvals get set to begin with - if res.name.startswith("_") or (res.name=="out"): + # we are now definitely inside some file, the question is whether we + # have taken the result and kept it in the top level too, i.e., whether + # we have imported a non-private value. checking whether it is also in + # self.env, even identical, will not tell us much as it may have been + # here and the same already. so we have to replicate the logic here. + if not isinstance(res, NamedVal): + return False # only namedvals get set to begin with + if res.name.startswith("_") or (res.name == "out"): return False return True - def evaluate_tree(self,tree,env=None): - if None is env: - env = self.env # otherwise, can pass custom env - # (e.g. when loading from a file, make env for that file, - # to keep that file's private (i.e. underscore-prefixed) variables to itself) + def evaluate_tree(self, tree, env=None): + if None is env: + env = self.env # otherwise, can pass custom env + # (e.g. when loading from a file, make env for that file, + # to keep that file's private (i.e. underscore-prefixed) variables + # to itself) if None is tree: - return ResultToPrint(None,False) + return ResultToPrint(None, False) try: if tree.replstatement(): return self.evaluate_replstatement(tree.replstatement()) - elif tree.raspstatement(): - res = Evaluator(env,self).evaluate(tree.raspstatement()) - if isinstance(res,NamedValList): - return [ResultToPrint(r,self.assigned_to_top(r,env)) for r in res.nvs] - return ResultToPrint(res, self.assigned_to_top(res,env)) + elif tree.raspstatement(): + res = Evaluator(env, self).evaluate(tree.raspstatement()) + if isinstance(res, NamedValList): + return [ResultToPrint(r, self.assigned_to_top(r, env)) for + r in res.nvs] + return ResultToPrint(res, self.assigned_to_top(res, env)) except (UndefinedVariable, ReservedName) as e: - return LazyPrint("\t\t!!ignoring input:\n\t",e) + return LazyPrint("\t\t!!ignoring input:\n\t", e) except NotImplementedError: return LazyPrint("not implemented this command yet! 
ignoring") - except (ArgsError,RASPTypeError,LoadError,RASPValueError) as e: - return LazyPrint("\t\t!!ignoring input:\n\t",e) + except (ArgsError, RASPTypeError, LoadError, RASPValueError) as e: + return LazyPrint("\t\t!!ignoring input:\n\t", e) # if not replstatement or raspstatement, then comment - return ResultToPrint(None,False) + return ResultToPrint(None, False) def filter_and_dump_prints(self): - # TODO: some error messages are still rising up and getting printed before reaching this position :( + # TODO: some error messages are still rising up and getting printed + # before reaching this position :( def filter_named_val_reps(rps): - # do the filtering. no namedvallists here - those are converted into a list of ResultToPrint s - # containing NamedVal s immediately after receiving them in evaluate_tree + # do the filtering. no namedvallists here - those are converted + # into a list of ResultToPrint s containing NamedVal s immediately + # after receiving them in evaluate_tree res = [] names = set() - for r in rps[::-1]: # go backwards - want to print the last occurence of each named item, not first, so filter works backwards - if isinstance(r.res,NamedVal): + # go backwards - want to print the last occurence of each named + # item, not first, so filter works backwards + for r in rps[::-1]: + if isinstance(r.res, NamedVal): if r.res.name in names: continue names.add(r.res.name) res.append(r) - return res[::-1] # flip back forwards + return res[::-1] # flip back forwards - if not True in [isinstance(v,LazyPrint) for v in self.results_to_print]: - self.results_to_print = filter_named_val_reps(self.results_to_print) + if True not in [isinstance(v, LazyPrint) for + v in self.results_to_print]: + self.results_to_print = filter_named_val_reps( + self.results_to_print) # if isinstance(res,NamedVal): # self.print_named_val(res.name,res.val) # # print all that needs to be printed: for r in self.results_to_print: - if isinstance(r,LazyPrint): + if isinstance(r, LazyPrint): r.print() else: self.print_result(r) # clear the list self.results_to_print = [] - - - - def run(self,fromfile=None,env=None,store_prints=False): - def careful_print(*a,**kw): + def run(self, fromfile=None, env=None, store_prints=False): + def careful_print(*a, **kw): if store_prints: - self.results_to_print.append(LazyPrint(*a,**kw)) + self.results_to_print.append(LazyPrint(*a, **kw)) else: - print(*a,**kw) + print(*a, **kw) while True: try: tree = self.get_tree(fromfile) - if isinstance(tree,Stop): + if isinstance(tree, Stop): break - rp = self.evaluate_tree(tree,env) + rp = self.evaluate_tree(tree, env) if store_prints: - if isinstance(rp,list): - self.results_to_print += rp # multiple results given - a multi-assignment + if isinstance(rp, list): + # multiple results given - a multi-assignment + self.results_to_print += rp else: self.results_to_print.append(rp) else: self.print_result(rp) except RASPTypeError as e: - careful_print("\t!!statement executed, but result fails on evaluation:\n\t\t",e) + msg = "\t!!statement executed, but result fails on evaluation:" + msg += "\n\t\t" + careful_print(msg, e) except EOFError: careful_print("") break except StopException: break except KeyboardInterrupt: - careful_print("") # makes newline + careful_print("") # makes newline except Exception as e: - if debug: + if DEBUG: raise e - careful_print("something went wrong:",e) - - - - -from antlr4.error.ErrorListener import ErrorListener + careful_print("something went wrong:", e) class AntlrException(Exception): - def 
__init__(self,msg): + def __init__(self, msg): self.msg = msg + class InputNotFinished(Exception): def __init__(self): pass -class MyErrorListener( ErrorListener ): + +class MyErrorListener(ErrorListener): def __init__(self): super(MyErrorListener, self).__init__() @@ -374,99 +415,118 @@ def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): raise InputNotFinished() if msg.startswith("missing ';' at"): raise InputNotFinished() - if "mismatched input" in msg: - a=str(offendingSymbol) - b=a[a.find("=")+2:] - c=b[:b.find(",<")-1] + # TODO: why did this do nothing? + # if "mismatched input" in msg: + # a = str(offendingSymbol) + # b = a[a.find("=")+2:] + # c = b[:b.find(",<")-1] ae = AntlrException(msg) - ae.recognizer, ae.offendingSymbol, ae.line, ae.column, ae.msg, ae.e = recognizer, offendingSymbol, line, column, msg, e + ae.recognizer = recognizer + ae.offendingSymbol = offendingSymbol + ae.line = line + ae.column = column + ae.msg = msg + ae.e = e raise ae - # def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs): + # def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, + # ambigAlts, configs): # raise AntlrException("ambiguity") - # def reportAttemptingFullContext(self, recognizer, dfa, startIndex, stopIndex, conflictingAlts, configs): + # def reportAttemptingFullContext(self, recognizer, dfa, startIndex, + # stopIndex, conflictingAlts, configs): # we're ok with this: happens with func defs it seems - # def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): + # def reportContextSensitivity(self, recognizer, dfa, startIndex, + # stopIndex, prediction, configs): # we're ok with this: happens with func defs it seems - + + class Stop: def __init__(self): pass + class LineReader: - def __init__(self,prompt=">>",fromfile=None,given_line=None): + def __init__(self, prompt=">>", fromfile=None, given_line=None): self.fromfile = fromfile self.given_line = given_line self.prompt = prompt + " " self.cont_prompt = "."*len(prompt)+" " - def str_to_antlr_parser(self,s): + def str_to_antlr_parser(self, s): antlrinput = InputStream(s) - lexer = RASPLexer(antlrinput) + lexer = RASPLexer(antlrinput) lexer.removeErrorListeners() - lexer.addErrorListener( MyErrorListener() ) - stream = CommonTokenStream(lexer) - parser = RASPParser(stream) - parser.removeErrorListeners() - parser.addErrorListener( MyErrorListener() ) + lexer.addErrorListener(MyErrorListener()) + stream = CommonTokenStream(lexer) + parser = RASPParser(stream) + parser.removeErrorListeners() + parser.addErrorListener(MyErrorListener()) return parser - - def read_line(self,continuing=False,nest_depth=0): + def read_line(self, continuing=False, nest_depth=0): prompt = self.cont_prompt if continuing else self.prompt - if not None is self.fromfile: + if self.fromfile is not None: res = self.fromfile.readline() - if not res: # python files return "" on last line (as opposed to "\n" on empty lines) + # python files return "" on last line (as opposed to "\n" on empty + # lines) + if not res: return Stop() return res - if not None is self.given_line: + if self.given_line is not None: res = self.given_line self.given_line = Stop() return res else: return input(prompt+(" "*nest_depth)) - def get_input_tree(self): - pythoninput="" + pythoninput = "" multiline = False while True: + nest_depth = pythoninput.split().count("def") newinput = self.read_line(continuing=multiline, - nest_depth=pythoninput.split().count("def")) - if 
isinstance(newinput,Stop): # input stream ended + nest_depth=nest_depth) + if isinstance(newinput, Stop): # input stream ended return Stop() if is_comment(newinput): - newinput = "" # don't let comments get in and ruin things somehow - pythoninput += newinput # don't replace newlines! this is how in-function comments get broken .replace("\n","")+" " + # don't let comments get in and ruin things somehow + newinput = "" + # don't replace newlines here! this is how in-function comments get + # broken + pythoninput += newinput parser = self.str_to_antlr_parser(pythoninput) try: res = parser.r().statement() - if isinstance(res,list): - # TODO: this seems to happen when there's ambiguity. figure out what is going on!! - assert len(res)==1 + if isinstance(res, list): + # TODO: this seems to happen when there's ambiguity. figure + # out what is going on!! + assert len(res) == 1 res = res[0] return res except InputNotFinished: multiline = True - pythoninput+=" " + pythoninput += " " - -def print_seq(example,seq,still_on_prev_line=False,extra_pref="",lastpref_if_shortprint=""): - if len(set(seq.get_vals()))==1: +def print_seq(example, seq, still_on_prev_line=False, extra_pref="", + lastpref_if_shortprint=""): + if len(set(seq.get_vals())) == 1: print(extra_pref if not still_on_prev_line else "", - lastpref_if_shortprint, - str(seq),end=" ") - print("[skipped full display: identical values]")# when there is only one value, it's nicer to just print that than the full list, verbosity be damned + lastpref_if_shortprint, + str(seq), end=" ") + # when there is only one value, it's nicer to just print that than the + # full list, verbosity be damned + print("[skipped full display: identical values]") return if still_on_prev_line: print("") seq = seq.get_vals() + def cleanboolslist(seq): - if isinstance(seq[0],bool): + if isinstance(seq[0], bool): tstr = "T" if seq.count(True) <= seq.count(False) else "" fstr = "F" if seq.count(False) <= seq.count(True) else "" return [tstr if v else fstr for v in seq] @@ -480,21 +540,23 @@ def cleanboolslist(seq): seq = [str(v) for v in seq] maxlen = max(len(v) for v in example+seq) - def neatline(seq): def padded(s): return " "*(maxlen-len(s))+s return " ".join(padded(v) for v in seq) - print(extra_pref,"\t\tinput: ",neatline(example),"\t","("+lazy_type_check(example)+"s)") - print(extra_pref,"\t\toutput: ",neatline(seq),"\t","("+seqtype+"s)") + print(extra_pref, "\t\tinput: ", neatline(example), + "\t", "("+lazy_type_check(example)+"s)") + print(extra_pref, "\t\toutput: ", neatline(seq), "\t", "("+seqtype+"s)") + -def print_select(example,select,extra_pref=""): +def print_select(example, select, extra_pref=""): # .replace("\n","\n\t\t\t") def nice_matrix_line(m): return " ".join("1" if v else " " for v in m) - print(extra_pref,"\t\t\t "," ".join(str(v) for v in example)) + print(extra_pref, "\t\t\t ", " ".join(str(v) for v in example)) matrix = select.get_vals() - [print(extra_pref,"\t\t\t",v,"|",nice_matrix_line(matrix[m])) for v,m in zip(example,matrix)] + [print(extra_pref, "\t\t\t", v, "|", nice_matrix_line(matrix[m])) + for v, m in zip(example, matrix)] if __name__ == "__main__": @@ -514,5 +576,5 @@ def runner(): a.run() except Exception as e: print(e) - return a,e - return a,None \ No newline at end of file + return a, e + return a, None diff --git a/RASP_support/Sugar.py b/RASP_support/Sugar.py index c9645d3..0923e75 100644 --- a/RASP_support/Sugar.py +++ b/RASP_support/Sugar.py @@ -1,80 +1,98 @@ -from FunctionalSupport import indices, tokens_str, tokens_int, 
tokens_float, tokens_asis, \ -tokens_bool, or_selects, and_selects, not_select -from FunctionalSupport import select, aggregate, zipmap -from FunctionalSupport import UnfinishedSequence as _UnfinishedSequence, Unfinished as _Unfinished -from FunctionalSupport import guarded_compare as _guarded_compare -from FunctionalSupport import guarded_contains as _guarded_contains -import DrawCompFlow # not at all necessary for sugar, but sugar is really the top-level tpl file we import, -# and nice to have draw_comp_flow added into the sequences already on load -from collections.abc import Iterable +from FunctionalSupport import Unfinished as _Unfinished +from FunctionalSupport import UnfinishedSequence as _UnfinishedSequence +from FunctionalSupport import select, zipmap from make_operators import add_ops +import DrawCompFlow +# DrawCompFlow is not at all necessary for sugar, but sugar is really the +# top-level rasp file we import, and nice to have draw_comp_flow added into +# the sequences already on load -def _apply_unary_op(self,f): - return zipmap(self,f) +def _apply_unary_op(self, f): + return zipmap(self, f) -def _apply_binary_op(self,other,f): - def seq_and_other_op(self,other,f): - return zipmap(self,lambda a:f(a,other)) - def seq_and_seq_op(self,other_seq,f): - return zipmap((self,other_seq),f) - if isinstance(other,_UnfinishedSequence): - return seq_and_seq_op(self,other,f) + +def _apply_binary_op(self, other, f): + def seq_and_other_op(self, other, f): + return zipmap(self, lambda a: f(a, other)) + + def seq_and_seq_op(self, other_seq, f): + return zipmap((self, other_seq), f) + if isinstance(other, _UnfinishedSequence): + return seq_and_seq_op(self, other, f) else: - return seq_and_other_op(self,other,f) + return seq_and_other_op(self, other, f) + -add_ops(_UnfinishedSequence,_apply_unary_op,_apply_binary_op) +add_ops(_UnfinishedSequence, _apply_unary_op, _apply_binary_op) -def _addname(seq,name,default_name,always_display_when_named=True): - if None is name: - res = seq.setname(default_name,always_display_when_named=always_display_when_named).allow_suppressing_display() +def _addname(seq, name, default_name, always_display_when_named=True): + if name is None: + res = seq.setname(default_name, + always_display_when_named=always_display_when_named) + res = res.allow_suppressing_display() else: - res = seq.setname(name,always_display_when_named=always_display_when_named) + res = seq.setname(name, + always_display_when_named=always_display_when_named) return res -full_s = select((),(),lambda :True,name="full average",compare_string="full average") +full_s = select((), (), lambda: True, name="full average", + compare_string="full average") -def tplconst(v,name=None): - return _addname(zipmap((),lambda :v),name,"constant: "+str(v),always_display_when_named=False).mark_as_constant() - # always_display_when_named = False : constants aren't worth displaying, but still going to name them in background, - # in case change mind about this -# allow suppressing display for bool, not, and, or : all of these would have been boring operators if -# only python let me overload them +def tplconst(v, name=None): + return _addname(zipmap((), lambda: v), name, "constant: " + str(v), + always_display_when_named=False).mark_as_constant() + # always_display_when_named = False : constants aren't worth displaying, + # but still going to name them in background, in case I change my mind + +# allow suppressing display for bool, not, and, or : all of these would have +# been boring operators if only python let me 
overload them + +# always have to call allow_suppressing_display after setname because setname +# marks the variable as crucial to display under assumption user named it -# always have to call allow_suppressing_display after setname because setname marks the variable as -# crucial to display under assumption user named it def toseq(seq): - if not isinstance(seq,_UnfinishedSequence): - seq = tplconst(seq,str(seq)) + if not isinstance(seq, _UnfinishedSequence): + seq = tplconst(seq, str(seq)) return seq + def asbool(seq): - res = zipmap(seq,lambda a:bool(a)) - return _addname(res,None,"bool("+seq.name+")") - # would do res = seq==True but it seems this has different behaviour to bool eg 'bool(2)' - # is True but '2==True' returns False + res = zipmap(seq, lambda a: bool(a)) + return _addname(res, None, "bool(" + seq.name + ")") + # would do res = seq==True but it seems this has different behaviour to + # bool eg 'bool(2)' is True but '2==True' returns False -def tplnot(seq,name=None): - res = asbool(seq) == False # this one does correct conversion using asbool and then we really can just do ==False - return _addname(res,name,"( not "+str(seq.name)+" )") -def _num_trues(l,r): - l,r = toseq(l),toseq(r) - return (1*asbool(l)) + (1*asbool(r)) +def tplnot(seq, name=None): + # this one does correct conversion using asbool and then we really can just + # do ==False + res = asbool(seq) == False + return _addname(res, name, "( not " + str(seq.name) + " )") + + +def _num_trues(left, right): + l, r = toseq(left), toseq(right) + return (1 * asbool(l)) + (1 * asbool(r)) + def quickname(v): - if isinstance(v,_Unfinished): + if isinstance(v, _Unfinished): return v.name else: return str(v) -def tpland(l,r): - res = _num_trues(l,r) == 2 - return _addname(res,None,"( "+quickname(l)+" and "+quickname(r)+")") -def tplor(l,r): - res = _num_trues(l,r) >= 1 - return _addname(res,None,"( "+quickname(l)+" or "+quickname(r)+")") +def tpland(left, right): + res = _num_trues(left, right) == 2 + return _addname(res, None, "( " + quickname(left) + " and " + + quickname(right) + ")") + + +def tplor(left, right): + res = _num_trues(left, right) >= 1 + return _addname(res, None, "( " + quickname(left) + " or " + + quickname(right) + ")") diff --git a/RASP_support/Support.py b/RASP_support/Support.py index bf32667..feec2e6 100644 --- a/RASP_support/Support.py +++ b/RASP_support/Support.py @@ -1,74 +1,86 @@ -import functools -from copy import deepcopy # dont let them mutate the things i'm allowing them to have as vals +# dont let them mutate the things i'm allowing them to have as vals +from copy import deepcopy import pprint + class RASPError(Exception): - def __init__(self,*a): + def __init__(self, *a): super().__init__(" ".join([str(b) for b in a])) + class RASPTypeError(RASPError): - def __init__(self,*a): + def __init__(self, *a): super().__init__(*a) -def clean_val(num,digits=3): # taken from my helper functions - res = round(num,digits) + +def clean_val(num, digits=3): # taken from my helper functions + res = round(num, digits) if digits == 0: res = int(res) return res + class SupportException(Exception): - def __init__(self,m): - Exception.__init__(self,m) + def __init__(self, m): + Exception.__init__(self, m) + TBANNED = "banned" TMISMATCHED = "mismatched" -TNAME = {bool:"bool",str:"string",int:"int",float:"float"} -NUMTYPES = [TNAME[int],TNAME[float]] +TNAME = {bool: "bool", str: "string", int: "int", float: "float"} +NUMTYPES = [TNAME[int], TNAME[float]] sorted_typenames_list = sorted(list(TNAME.values())) 
-legal_types_list_string = ", ".join(sorted_typenames_list[:-1])+" or "+sorted_typenames_list[-1] +legal_types_list_string = ", ".join( + sorted_typenames_list[:-1])+" or "+sorted_typenames_list[-1] + + +def is_in_types(v, tlist): + for t in tlist: + if isinstance(v, t): + return True + return False -def is_in_types(v,tlist): - for t in tlist: - if isinstance(v,t): - return True - return False def lazy_type_check(vals): - legal_val_types = [str,bool,int,float] - number_types = [int,float] + legal_val_types = [str, bool, int, float] + number_types = [int, float] # all vals are same, legal, type: for t in legal_val_types: - b = [isinstance(v,t) for v in vals] + b = [isinstance(v, t) for v in vals] if False not in b: return TNAME[t] - # allow vals to also be mixed integers and ints, treat those as floats + # allow vals to also be mixed integers and ints, treat those as floats # (but don't actually change the ints to floats, want neat printouts) - b = [is_in_types(v,number_types) for v in vals] + b = [is_in_types(v, number_types) for v in vals] if False not in b: return TNAME[float] # from here it's all bad, but lets have some clear error messages - b = [is_in_types(v,legal_val_types) for v in vals] + b = [is_in_types(v, legal_val_types) for v in vals] if False not in b: - return TMISMATCHED # all legal types, but mismatched + return TMISMATCHED # all legal types, but mismatched else: - return TBANNED + return TBANNED class Sequence: - def __init__(self,vals): + def __init__(self, vals): self.type = lazy_type_check(vals) if self.type == TMISMATCHED: - raise RASPTypeError(f"attempted to create sequence with vals of different types:\n\t\t {vals}") + raise RASPTypeError( + "attempted to create sequence with vals of different types:" + + f"\n\t\t {vals}") if self.type == TBANNED: - raise RASPTypeError(f"attempted to create sequence with illegal val types (vals must be {legal_types_list_string}):\n\t\t {vals}") + raise RASPTypeError( + "attempted to create sequence with illegal val types " + + f"(vals must be {legal_types_list_string}):\n\t\t {vals}") self._vals = vals def __str__(self): # return "Sequence"+str([small_str(v) for v in self._vals]) - if (len(set(self._vals))==1) and (len(self._vals)>1): + if (len(set(self._vals)) == 1) and (len(self._vals) > 1): res = "["+small_str(self._vals[0])+"]*"+str(len(self._vals)) else: res = "["+", ".join(small_str(v) for v in self._vals)+"]" @@ -84,34 +96,36 @@ def get_vals(self): return deepcopy(self._vals) -def dims_match(seqs,expected_dim): +def dims_match(seqs, expected_dim): return False not in [expected_dim == len(seq) for seq in seqs] + class Select: - def __init__(self, n, q_vars, k_vars, f): + def __init__(self, n, q_vars, k_vars, f): self.n = n - self.makeselect(q_vars,k_vars,f) + self.makeselect(q_vars, k_vars, f) self.niceprint = None def get_vals(self): - if None is self.select: + if self.select is None: self.makeselect() return deepcopy(self.select) - def makeselect(self,q_vars=None,k_vars=None,f=None): + def makeselect(self, q_vars=None, k_vars=None, f=None): if None is q_vars: assert (None is k_vars) and (None is f) - q_vars = (Sequence(self.target_index),) + q_vars = (Sequence(self.target_index),) k_vars = (Sequence(list(range(self.n))),) - f = lambda t,i:t==i - self.select = {i:[f(*get(q_vars,i),*get(k_vars,j)) for j in range(self.n)] - for i in range(self.n)} # outputs of f should be - # True or False. 
j goes along input dim, i along output + def f(t, i): return t == i + self.select = {i: [f(*get(q_vars, i), *get(k_vars, j)) + for j in range(self.n)] + for i in range(self.n)} # outputs of f should be + # True or False. j goes along input dim, i along output def __str__(self): - select = self.get_vals() + self.get_vals() if None is self.niceprint: - d = {i:list(map(int,self.select[i])) for i in self.select} + d = {i: list(map(int, self.select[i])) for i in self.select} self.niceprint = str(self.niceprint) if len(str(d)) > 40: starter = "\n" @@ -120,119 +134,144 @@ def __str__(self): starter = "" self.niceprint = str(d) self.niceprint = starter + self.niceprint - return self.niceprint + return self.niceprint def __repr__(self): return str(self) -def select(n,q_vars,k_vars,f): - return Select(n,q_vars,k_vars,f) -## applying selects or feedforward (map) -def aggregate(select,k_vars,func,default=None): - return to_sequences(apply_average_select(select,k_vars,func,default)) +def select(n, q_vars, k_vars, f): + return Select(n, q_vars, k_vars, f) + +# applying selects or feedforward (map) + + +def aggregate(select, k_vars, func, default=None): + return to_sequences(apply_average_select(select, k_vars, func, default)) + def to_sequences(results_by_index): def totup(r): - if not isinstance(r,tuple): + if not isinstance(r, tuple): return (r,) return r - results_by_index = list(map(totup,results_by_index)) # convert scalar results to tuples of length 1 - results_by_output_val = list(zip(*results_by_index)) # one list (sequence) per output value - res = tuple(map(Sequence,results_by_output_val)) + # convert scalar results to tuples of length 1 + results_by_index = list(map(totup, results_by_index)) + # one list (sequence) per output value + results_by_output_val = list(zip(*results_by_index)) + res = tuple(map(Sequence, results_by_output_val)) if len(res) == 1: return res[0] else: return res -def zipmap(n,k_vars,func): - # assert len(k_vars) >= 1, "dont make a whole sequence for a plain constant you already know the value of.." - results_by_index = [func(*get(k_vars,i)) for i in range(n)] + +def zipmap(n, k_vars, func): + # assert len(k_vars) >= 1, "dont make a whole sequence for a plain constant + # you already know the value of.." 
+ results_by_index = [func(*get(k_vars, i)) for i in range(n)] return to_sequences(results_by_index) -def verify_default_size(default,num_output_vars): + +def verify_default_size(default, num_output_vars): assert num_output_vars > 0 if num_output_vars == 1: - assert not isinstance(default,tuple), "aggregates on functions with single output should have scalar default" + errnote = "aggregates on functions with single output should have" \ + + " scalar default" + assert not isinstance(default, tuple), errnote elif num_output_vars > 1: - assert isinstance(default,tuple) and len(default)==num_output_vars,\ - "for function with >1 output values, default should be tuple of default \ - values, of equal length to passed function's output values (for function \ - with single output value, default should be single value too)" + errnote = "for function with >1 output values, default should be" \ + + " tuple of default values, of equal length to passed" \ + + " function's output values (for function with single output" \ + + " value, default should be single value too)" + check = isinstance(default, tuple) and len(default) == num_output_vars + assert check, errnote -def apply_average_select(select,k_vars,func,default=0): + +def apply_average_select(select, k_vars, func, default=0): def apply_func_to_each_index(): - kvs = [get(k_vars,i) for i in list(range(select.n))] # kvs is list [by index] of lists [by varname] of values - candidate_i = [func(*kvi) for kvi in kvs] # candidate output per index + # kvs is list [by index] of lists [by varname] of values + kvs = [get(k_vars, i) for i in list(range(select.n))] + candidate_i = [func(*kvi) for kvi in kvs] # candidate output per index if num_output_vars > 1: candidates_by_varname = list(zip(*candidate_i)) else: - candidates_by_varname = (candidate_i,) # expect tuples of values for conversions in return_sequences + # expect tuples of values for conversions in return_sequences + candidates_by_varname = (candidate_i,) return candidates_by_varname - def prep_default(default,num_output_vars): + def prep_default(default, num_output_vars): if None is default: default = 0 - # output of average is always floats, so will be converting all - # to floats here else we'll fail the lazy type check in the Sequences. - # (and float(None) doesn't 'compile' ) + # output of average is always floats, so will be converting all + # to floats here else we'll fail the lazy type check in the + # Sequences. (and float(None) doesn't 'compile' ) # TODO: maybe just lose the lazy type check? - if not isinstance(default,tuple) and (num_output_vars>1): - default = tuple([default]*num_output_vars) + if not isinstance(default, tuple) and (num_output_vars > 1): + default = tuple([default]*num_output_vars) # *specifically* in apply_average, where values have to be floats, - # allow default to be single val, - #that will be repeated for all wanted outputs - verify_default_size(default,num_output_vars) - if not isinstance(default,tuple): - default = (default,) # specifically with how we're going to do things here in the average aggregate, - # will help to actually have the outputs get passed around as tuples, even if they're scalars really. - # but do this after the size check for the scalar one so it doesn't get filled with weird ifs... 
this - # tupled scalar thing is only a convenience in this implementation in this here function + # allow default to be single val, + # that will be repeated for all wanted outputs + verify_default_size(default, num_output_vars) + if not isinstance(default, tuple): + # specifically with how we're going to do things here in the + # average aggregate, will help to actually have the outputs get + # passed around as tuples, even if they're scalars really. + # but do this after the size check for the scalar one so it doesn't + # get filled with weird ifs... this tupled scalar thing is only a + # convenience in this implementation in this here function + default = (default,) return default - def apply_and_average_single_index(outputs_by_varname,index, - index_scores,num_output_vars,default): - def mean(scores,vals): - n = scores.count(True) # already >0 by earlier + def apply_and_average_single_index(outputs_by_varname, index, + index_scores, num_output_vars, default): + def mean(scores, vals): + n = scores.count(True) # already >0 by earlier if n == 1: return vals[scores.index(True)] # else # n>1 if not (lazy_type_check(vals) in NUMTYPES): - raise Exception("asked to average multiple values, but they are non-numbers: "+str(vals)) - return sum([v for s,v in zip(scores,vals) if s])*1.0/n - + raise Exception( + "asked to average multiple values, but they are " + + "non-numbers: " + str(vals)) + return sum([v for s, v in zip(scores, vals) if s])*1.0/n + num_influencers = index_scores.count(True) if num_influencers == 0: return default else: - return tuple(mean(index_scores,o_by_i) for o_by_i in outputs_by_varname) # return_sequences expects multiple outputs to be in tuple form - num_output_vars = get_num_outputs(func(*get(k_vars,0))) + # return_sequences expects multiple outputs to be in tuple form + return tuple(mean(index_scores, o_by_i) + for o_by_i in outputs_by_varname) + num_output_vars = get_num_outputs(func(*get(k_vars, 0))) candidates_by_varname = apply_func_to_each_index() - default = prep_default(default,num_output_vars) + default = prep_default(default, num_output_vars) means_per_index = [apply_and_average_single_index(candidates_by_varname, - i,select.select[i],num_output_vars,default) - for i in range(select.n)] + i, select.select[i], + num_output_vars, default) + for i in range(select.n)] # list (per index) of all the new variable values (per varname) return means_per_index -def get_num_outputs(dummy_out): # user's responsibility to give functions that always have same number of outputs - if isinstance(dummy_out,tuple): + +# user's responsibility to give functions that always have same number of +# outputs +def get_num_outputs(dummy_out): + if isinstance(dummy_out, tuple): return len(dummy_out) return 1 + def small_str(v): - if isinstance(v,float): - return str(clean_val(v,3)) - if isinstance(v,bool): + if isinstance(v, float): + return str(clean_val(v, 3)) + if isinstance(v, bool): return "T" if v else "F" return str(v) -def get(vars_list,index): # index should be within range to access -# v._vals and if not absolutely should raise an error, as it will here -# by the attempted access +def get(vars_list, index): # index should be within range to access + # v._vals and if not absolutely should raise an error, as it will here + # by the attempted access res = deepcopy([v._vals[index] for v in vars_list]) return res - - - diff --git a/RASP_support/analyse.py b/RASP_support/analyse.py index ed6acce..a1f4477 100644 --- a/RASP_support/analyse.py +++ b/RASP_support/analyse.py @@ -1,83 +1,97 
@@ -from FunctionalSupport import Unfinished, UnfinishedSequence, UnfinishedSelect, \ -guarded_contains, guarded_compare, zipmap, is_real_unfinished # need these for actually comparing sequences and not just making more sequences +from FunctionalSupport import Unfinished, UnfinishedSequence, \ + UnfinishedSelect, guarded_contains, guarded_compare, zipmap from collections import defaultdict, Counter from copy import copy + def UnfinishedFunc(f): - setattr(Unfinished,f.__name__,f) + setattr(Unfinished, f.__name__, f) @UnfinishedFunc def get_parent_sequences(self): - # for UnfinishedSequences, this should get just the tuple of sequences the aggregate is applied to, - # and I think in order (as the parents will only be a select and a sequencestuple, and the seqs in the - # sequencestuple will be added in order and the select will be removed in this function) - return [p for p in self.get_parents() if isinstance(p,UnfinishedSequence)] # i.e. drop the selects + # for UnfinishedSequences, this should get just the tuple of sequences the + # aggregate is applied to, and I think in order (as the parents will only + # be a select and a sequencestuple, and the seqs in the sequencestuple will + # be added in order and the select will be removed in this function) + + # i.e. drop the selects + return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)] + Unfinished._full_seq_parents = None + + @UnfinishedFunc def get_full_seq_parents(self): - if None is self._full_seq_parents: - self._full_seq_parents = [u for u in self.get_full_parents() \ - if isinstance(u,UnfinishedSequence)] + if self._full_seq_parents is None: + self._full_seq_parents = [u for u in self.get_full_parents() + if isinstance(u, UnfinishedSequence)] return copy(self._full_seq_parents) + @UnfinishedFunc def get_parent_select(self): - if not hasattr(self,"parent_select"): + if not hasattr(self, "parent_select"): real_parents = self.get_parents() - self.parent_select = next((s for s in real_parents if \ - isinstance(s,UnfinishedSelect)), None) + self.parent_select = next((s for s in real_parents if + isinstance(s, UnfinishedSelect)), None) return self.parent_select + @UnfinishedFunc -def set_analysis_parent_select(self,options): - # doesn't really need to be a function but feels clearer visually to have it - # out here so i can see this variable is being registered to the unfinisheds +def set_analysis_parent_select(self, options): + # doesn't really need to be a function but feels clearer visually to have + # it out here so i can see this variable is being registered to the + # unfinisheds if None is self.parent_select: self.analysis_parent_select = self.parent_select else: - self.analysis_parent_select = next((ps for ps in options if \ - ps.compare_string==self.get_parent_select().compare_string), None) - assert not None is self.analysis_parent_select, "parent options given to seq: "+self.name+" did not"+\ - "include anything equivalent to actual seq's parent"+\ - " select ("+self.get_parent_select().compare_string+")" + getps = (ps for ps in options if + ps.compare_string == self.get_parent_select().compare_string) + self.analysis_parent_select = next(getps, None) + errnote = "parent options given to seq: " + self.name + " did not " \ + + "include anything equivalent to actual seq's parent select (" \ + + self.get_parent_select().compare_string + ")" + assert self.analysis_parent_select is not None, errnote + def squeeze_selects(selects): compstrs = set([s.compare_string for s in selects]) if len(compstrs) == len(selects): 
return selects - return [next(s for s in selects if s.compare_string==cs) for cs in compstrs] + return [next(s for s in selects if s.compare_string == cs) + for cs in compstrs] + @UnfinishedFunc -def schedule(self,scheduler='best',remove_minors=False): -# recall attentions can be created on level 1 but still generate seqs on level 3 etc -# hence width is number of *seqs* with different attentions per level. +def schedule(self, scheduler='best', remove_minors=False): + # recall attentions can be created on level 1 but still generate seqs on + # level 3 etc hence width is number of *seqs* with different attentions per + # level. def choose_scheduler(scheduler): if scheduler == 'best': - return 'greedy' - # TODO: implement lastminute, maybe others, and choose narrowest + return 'greedy' + # TODO: implement lastminute, maybe others, and choose narrowest # result of all options return scheduler scheduler = choose_scheduler(scheduler) seq_layers = self.greedy_seq_scheduler() if scheduler == 'greedy' \ - else self.lastminute_seq_scheduler() + else self.lastminute_seq_scheduler() if remove_minors: for i in seq_layers: seq_layers[i] = [seq for seq in seq_layers[i] if not seq.is_minor] - num_layers = max(seq_layers.keys()) - def get_seqs_selects(seqs): # all the selects needed to compute a set of seqs all_selects = set(seq.get_parent_select() for seq in seqs) - all_selects -= set([None]) # some of the seqs may not have parent matches, + # some of the seqs may not have parent matches, # eg, indices. these will return None, which we don't want to count - return squeeze_selects(all_selects) # squeeze identical parents + all_selects -= set([None]) + return squeeze_selects(all_selects) # squeeze identical parents - - layer_selects = { i:get_seqs_selects(seq_layers[i]) for i in seq_layers } + layer_selects = {i: get_seqs_selects(seq_layers[i]) for i in seq_layers} # mark remaining parent select after squeeze for i in seq_layers: @@ -86,22 +100,26 @@ def get_seqs_selects(seqs): return seq_layers, layer_selects + @UnfinishedFunc def greedy_seq_scheduler(self): all_seqs = sorted(self.get_full_seq_parents(), - key = lambda seq:seq.creation_order_id) - # sorting in order of creation automatically sorts by order of in-layer - # dependencies (i.e. things got through feedforwards), makes prints clearer + key=lambda seq: seq.creation_order_id) + # sorting in order of creation automatically sorts by order of in-layer + # dependencies (i.e. 
things got through feedforwards), makes prints clearer # and eventually is helpful for drawcompflow - levels = defaultdict(lambda :[]) + levels = defaultdict(lambda: []) for seq in all_seqs: - levels[seq.min_poss_depth].append(seq) # schedule all seqs as early as possible + # schedule all seqs as early as possible + levels[seq.min_poss_depth].append(seq) return levels -Unfinished.max_poss_depth_for_seq = (None,None) +Unfinished.max_poss_depth_for_seq = (None, None) + + @UnfinishedFunc -def lastminute_for_seq(self,seq): +def lastminute_for_seq(self, seq): raise NotImplementedError @@ -112,168 +130,196 @@ def lastminute_seq_scheduler(self): @UnfinishedFunc def typestr(self): - if isinstance(self,UnfinishedSelect): + if isinstance(self, UnfinishedSelect): return "select" - elif isinstance(self,UnfinishedSequence): + elif isinstance(self, UnfinishedSequence): return "seq" else: return "internal" + @UnfinishedFunc -def width_and_depth(self,scheduler='greedy',loud=True,print_tree_too=False,remove_minors=False): - seq_layers, layer_selects = self.schedule(scheduler=scheduler,remove_minors=remove_minors) - widths = {i:len(layer_selects[i]) for i in layer_selects} +def width_and_depth(self, scheduler='greedy', loud=True, print_tree_too=False, + remove_minors=False): + seq_layers, layer_selects = self.schedule( + scheduler=scheduler, remove_minors=remove_minors) + widths = {i: len(layer_selects[i]) for i in layer_selects} n_layers = max(seq_layers.keys()) max_width = max(widths[i] for i in widths) if loud: - print("analysing unfinished",self.typestr()+":",self.name) - print("using scheduler:",scheduler) - print("num layers:",n_layers,"max width:",max_width) + print("analysing unfinished", self.typestr()+":", self.name) + print("using scheduler:", scheduler) + print("num layers:", n_layers, "max width:", max_width) print("width per layer:") - print("\n".join( str(i)+"\t: "+str(widths[i]) \ - for i in range(1,n_layers+1) )) - # start from 1 to skip layer 0, which has width 0 - # and is just the inputs (tokens and indices) + print("\n".join(str(i)+"\t: "+str(widths[i]) + for i in range(1, n_layers+1))) + # start from 1 to skip layer 0, which has width 0 + # and is just the inputs (tokens and indices) if print_tree_too: - def print_layer(i,d): - print(i,"\t:",", ".join(seq.name for seq in d[i])) + def print_layer(i, d): + print(i, "\t:", ", ".join(seq.name for seq in d[i])) print("==== seqs at each layer: ====") - [print_layer(i,seq_layers) for i in range(1,n_layers+1)] + [print_layer(i, seq_layers) for i in range(1, n_layers+1)] print("==== selects at each layer: ====") - [print_layer(i,layer_selects) for i in range(1,n_layers+1)] + [print_layer(i, layer_selects) for i in range(1, n_layers+1)] return n_layers, max_width, widths + @UnfinishedFunc -def schedule_comp_depth(self,d): +def schedule_comp_depth(self, d): self.scheduled_comp_depth = d + @UnfinishedFunc -def get_all_ancestor_heads_and_ffs(self,remove_minors=False): +def get_all_ancestor_heads_and_ffs(self, remove_minors=False): class Head: - def __init__(self,select,sequences,comp_depth): + def __init__(self, select, sequences, comp_depth): self.comp_depth = comp_depth self.name = str([m.name for m in sequences]) self.sequences = sequences self.select = select - seq_layers, layer_selects = self.schedule('best',remove_minors=remove_minors) + seq_layers, layer_selects = self.schedule( + 'best', remove_minors=remove_minors) all_ffs = [m for m in self.get_full_seq_parents() if m.from_zipmap] if remove_minors: all_ffs = [ff for ff in all_ffs if not 
ff.is_minor] - for i in seq_layers: for m in seq_layers[i]: - if guarded_contains(all_ffs,m): - m.schedule_comp_depth(i) # mark comp depths of the ffs... drawcompflow wants to know + if guarded_contains(all_ffs, m): + # mark comp depths of the ffs... drawcompflow wants to know + m.schedule_comp_depth(i) heads = [] for i in layer_selects: for s in layer_selects[i]: - seqs = [m for m in seq_layers[i] if m.analysis_parent_select==s] - heads.append(Head(s,seqs,i)) + seqs = [m for m in seq_layers[i] if m.analysis_parent_select == s] + heads.append(Head(s, seqs, i)) + + return heads, all_ffs - return heads,all_ffs @UnfinishedFunc -def set_display_name(self,display_name): +def set_display_name(self, display_name): self.display_name = display_name - # again just making it more visible??? that there's an attribute being set somewhere + # again just making it more visible??? that there's an attribute being set + # somewhere + @UnfinishedFunc -def make_display_names_for_all_parents(self,skip_minors=False): +def make_display_names_for_all_parents(self, skip_minors=False): all_unfs = self.get_full_parents() - all_seqs = [u for u in set(all_unfs) if isinstance(u,UnfinishedSequence)] - all_selects = [u for u in set(all_unfs) if isinstance(u,UnfinishedSelect)] + all_seqs = [u for u in set(all_unfs) if isinstance(u, UnfinishedSequence)] + all_selects = [u for u in set(all_unfs) if isinstance(u, UnfinishedSelect)] if skip_minors: num_orig = len(all_seqs) all_seqs = [seq for seq in all_seqs if not seq.is_minor] name_counts = Counter([m.name for m in all_seqs]) name_suff = Counter() - for m in sorted(all_seqs+all_selects,key=lambda u:u.creation_order_id): - # yes, even the non-seqs need display names, albeit for now only worry about repeats in the seqs - # and sort by creation order to get name suffixes with chronological (and so non-confusing) order - if name_counts[m.name]>1: + for m in sorted(all_seqs+all_selects, key=lambda u: u.creation_order_id): + # yes, even the non-seqs need display names, albeit for now only worry + # about repeats in the seqs and sort by creation order to get name + # suffixes with chronological (and so non-confusing) order + if name_counts[m.name] > 1: m.set_display_name(m.name+"_"+str(name_suff[m.name])) name_suff[m.name] += 1 else: m.set_display_name(m.name) + @UnfinishedFunc def note_if_seeker(self): - if not isinstance(self,UnfinishedSequence): + if not isinstance(self, UnfinishedSequence): return - if (not self.get_parent_sequences()) and (not None is self.get_parent_select()): - # no parent sequences, but yes parent select: this value is a function - # of only its parent select, i.e., a seeker (marks whether select found something or not) + if (not self.get_parent_sequences()) \ + and (self.get_parent_select() is not None): + # no parent sequences, but yes parent select: this value is a function + # of only its parent select, i.e., a seeker (marks whether select found + # something or not) self.is_seeker = True self.seeker_flag = self.elementwise_function() self.seeker_default = self._default else: self.is_seeker = False + @UnfinishedFunc def mark_all_ancestor_seekers(self): [u.note_if_seeker() for u in self.get_full_parents()] -Unfinished._full_descendants_for_seq = (None,None) + +Unfinished._full_descendants_for_seq = (None, None) + + @UnfinishedFunc -def descendants_towards_seq(self,seq): - if not guarded_compare(self._full_descendants_for_seq[0],seq): - +def descendants_towards_seq(self, seq): + if not guarded_compare(self._full_descendants_for_seq[0], seq): + relevant = 
seq.get_full_parents() - res = [r for r in relevant if guarded_contains(r.get_parents(),self)] + res = [r for r in relevant if guarded_contains(r.get_parents(), self)] - self._full_descendants_for_seq = (seq,res) + self._full_descendants_for_seq = (seq, res) return self._full_descendants_for_seq[1] - + + @UnfinishedFunc -def is_minor_comp_towards_seq(self,seq): - if not isinstance(self,UnfinishedSequence): - return False # selects are always important - if self.never_display: # priority: never over always +def is_minor_comp_towards_seq(self, seq): + if not isinstance(self, UnfinishedSequence): + return False # selects are always important + if self.never_display: # priority: never over always return True if self.always_display: if self.is_constant(): - print("displaying constant:",self.name) - return False - if self.is_constant(): # e.g. 1 or "a" etc, just stuff created around constants by REPL behind the scenes + print("displaying constant:", self.name) + return False + if self.is_constant(): # e.g. 1 or "a" etc, just stuff created around + # constants by REPL behind the scenes return True children = self.descendants_towards_seq(seq) - if len(children)>1: - return False # this sequence was used twice -> must have been actually - # named as a real variable in the code (and not part of some bunch of operators) - # -> make it visible in the comp flow too - if len(children)==0: - return not guarded_compare(self,seq) # if it's the seq itself then clearly - # we're very interested in it. if it has no children and isnt the seq then we're checking out - # a weird dangly unused leaf, we shouldn't reach such a scenario through any of functions - # we'll be using to call this one, but might as well make this function complete just in case - # we forget + if len(children) > 1: + return False # this sequence was used twice -> must have been actually + # named as a real variable in the code (and not part of some bunch of + # operators) -> make it visible in the comp flow too + if len(children) == 0: + # if it's the seq itself then clearly we're very interested in it. 
if + # it has no children and isnt the seq then we're checking out a weird + # dangly unused leaf, we shouldn't reach such a scenario through any of + # functions we'll be using to call this one, but might as well make + # this function complete just in case we forget + return not guarded_compare(self, seq) child = children[0] - if isinstance(child,UnfinishedSelect): - return False # this thing feeds directly into a select, lets make it visible - return (child.from_zipmap and self.from_zipmap) # obtained through zipmap and feeds - # directly into another zipmap: minor operation as part of something more complicated + if isinstance(child, UnfinishedSelect): + return False # this thing feeds directly into a select, lets make it + # visible + # obtained through zipmap and feeds directly into another zipmap: minor + # operation as part of something more complicated + return (child.from_zipmap and self.from_zipmap) + Unfinished.is_minor = False + + @UnfinishedFunc -def set_minor_for_seq(self,seq): # another func just to be very explicit about an attribute that's getting set +# another func just to be very explicit about an attribute that's getting set +def set_minor_for_seq(self, seq): self.is_minor = self.is_minor_comp_towards_seq(seq) - + + @UnfinishedFunc def mark_all_minor_ancestors(self): all_ancestors = self.get_full_parents() for a in all_ancestors: a.set_minor_for_seq(self) + @UnfinishedFunc -def get_nonminor_parents(self): # assumes have already marked the minor parents -# according to current interests. -# otherwise, may remain marked according to a different seq, or possibly all on default value -# (none are minor, all are important) +def get_nonminor_parents(self): # assumes have already marked the minor + # parents according to current interests. + # otherwise, may remain marked according to a different seq, or possibly + # all on default value (none are minor, all are important) potentials = self.get_parents() nonminors = [] while potentials: @@ -284,19 +330,23 @@ def get_nonminor_parents(self): # assumes have already marked the minor parents potentials.update(p.get_parents()) return set(nonminors) + @UnfinishedFunc def get_nonminor_parent_sequences(self): - return [p for p in self.get_nonminor_parents() if isinstance(p,UnfinishedSequence)] + return [p for p in self.get_nonminor_parents() + if isinstance(p, UnfinishedSequence)] + @UnfinishedFunc -def get_immediate_parent_sequences(self): # gets both minor and nonminor sequences - return [p for p in self.get_parents() if isinstance(p,UnfinishedSequence)] +# gets both minor and nonminor sequences +def get_immediate_parent_sequences(self): + return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)] + @UnfinishedFunc def pre_aggregate_comp(seq): vvars = seq.get_parent_sequences() - vreal = zipmap(vvars,seq.elementwise_function) - if isinstance(vreal,tuple): # equivalently, if seq.output_index >= 0: + vreal = zipmap(vvars, seq.elementwise_function) + if isinstance(vreal, tuple): # equivalently, if seq.output_index >= 0: vreal = vreal[seq.output_index] return vreal - diff --git a/RASP_support/make_operators.py b/RASP_support/make_operators.py index d6a3e66..b1a26f7 100644 --- a/RASP_support/make_operators.py +++ b/RASP_support/make_operators.py @@ -1,146 +1,151 @@ -# extend UnfinishedSequence with a bunch of operators, +# extend UnfinishedSequence with a bunch of operators, # provided the unary and binary ops. 
# make them fully named functions instead of lambdas, even though # it's more lines, because the debug prints are so much clearer # this way -def add_ops(Class,apply_unary_op,apply_binary_op): +def add_ops(Class, apply_unary_op, apply_binary_op): - def addsetname(f,opname,rev): + def addsetname(f, opname, rev): def f_with_setname(*a): - assert len(a) in [1,2] - if len(a)==2: - a0,a1 = a if not rev else (a[1],a[0]) - name0 = a0.name if hasattr(a0,"name") else str(a0) - name1 = a1.name if hasattr(a1,"name") else str(a1) - # a0/a1 might not be an seq, just having an op on it with an seq. + assert len(a) in [1, 2] + if len(a) == 2: + a0, a1 = a if not rev else (a[1], a[0]) + name0 = a0.name if hasattr(a0, "name") else str(a0) + name1 = a1.name if hasattr(a1, "name") else str(a1) + # a0/a1 might not be a seq, just having an op on it with a seq. name = name0 + " " + opname + " " + name1 - else: # len(a)==1 - name = opname + " " +a[0].name - name = "( "+name+" )" # probably going to be composed with more ops, so... + else: # len(a)==1 + name = opname + " " + a[0].name + # probably going to be composed with more ops, so... + name = "( " + name + " )" return f(*a).setname(name).allow_suppressing_display() - # seqs created as parts of long sequences of operators - # may be suppressed in display, the final name of the whole composition will be - # sufficiently informative. - # have to set always_display to false *after* the setname, because setname marks - # always_display as True (under assumption it is normally being called by the user, - # who must clearly be naming some variable they care about) + # seqs created as parts of long sequences of operators may be + # suppressed in display, the final name of the whole composition + # will be sufficiently informative. 
Have to set always_display to
+			# false *after* the setname, because setname marks always_display
+			# as True (under assumption it is normally being called by the
+			# user, who must clearly be naming some variable they care about)
 			return f_with_setname
 
-	def listop(f,listing_name):
-		setattr(Class,listing_name,f)
+	def listop(f, listing_name):
+		setattr(Class, listing_name, f)
 
-	def addop(opname,rev=False):
-		return lambda f:listop(addsetname(f,opname,rev),f.__name__)
+	def addop(opname, rev=False):
+		return lambda f: listop(addsetname(f, opname, rev), f.__name__)
 
 	@addop("==")
-	def __eq__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a==b)
-	
+	def __eq__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a == b)
+
 	@addop("!=")
-	def __ne__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a!=b)
-	
+	def __ne__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a != b)
+
 	@addop("<")
-	def __lt__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a<b)
-	
+	def __lt__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a < b)
+
 	@addop(">")
-	def __gt__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a>b)
-	
+	def __gt__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a > b)
+
 	@addop("<=")
-	def __le__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a<=b)
-	
+	def __le__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a <= b)
+
 	@addop(">=")
-	def __ge__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a>=b)
-	
+	def __ge__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a >= b)
 
 	@addop("+")
-	def __add__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a+b)
-	
-	@addop("+",True)
-	def __radd__(self,other):
-		return apply_binary_op(self,other,lambda a,b:b+a)
-	
+	def __add__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a+b)
+
+	@addop("+", True)
+	def __radd__(self, other):
+		return apply_binary_op(self, other, lambda a, b: b+a)
+
 	@addop("-")
-	def __sub__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a-b)
-	
-	@addop("-",True)
-	def __rsub__(self,other):
-		return apply_binary_op(self,other,lambda a,b:b-a)
-	
+	def __sub__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a-b)
+
+	@addop("-", True)
+	def __rsub__(self, other):
+		return apply_binary_op(self, other, lambda a, b: b-a)
+
 	@addop("*")
-	def __mul__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a*b)
-	
-	@addop("*",True)
-	def __rmul__(self,other):
-		return apply_binary_op(self,other,lambda a,b:b*a)
-	
+	def __mul__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a*b)
+
+	@addop("*", True)
+	def __rmul__(self, other):
+		return apply_binary_op(self, other, lambda a, b: b*a)
+
 	@addop("//")
-	def __floordiv__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a//b)
-	
-	@addop("//",True)
-	def __rfloordiv__(self,other):
-		return apply_binary_op(self,other,lambda a,b:b//a)
-	
+	def __floordiv__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a//b)
+
+	@addop("//", True)
+	def __rfloordiv__(self, other):
+		return apply_binary_op(self, other, lambda a, b: b//a)
+
 	@addop("/")
-	def __truediv__(self,other):
-		return apply_binary_op(self,other,lambda a,b:a/b)
-	
-	@addop("/",True)
-	def __rtruediv__(self,other):
-		return apply_binary_op(self,other,lambda a,b:b/a)
-	
+	def __truediv__(self, other):
+		return apply_binary_op(self, other, lambda a, b: a/b)
+
+	@addop("/", True)
+	def __rtruediv__(self, other):
+		return apply_binary_op(self, other, lambda a, b: b/a)
+
 	@addop("%")
-	def 
__mod__(self,other): - return apply_binary_op(self,other,lambda a,b:a%b) - - @addop("%",True) - def __rmod__(self,other): - return apply_binary_op(self,other,lambda a,b:b%a) - + def __mod__(self, other): + return apply_binary_op(self, other, lambda a, b: a % b) + + @addop("%", True) + def __rmod__(self, other): + return apply_binary_op(self, other, lambda a, b: b % a) + @addop("divmod") - def __divmod__(self,other): - return apply_binary_op(self,other,lambda a,b:divmod(a,b)) - - @addop("divmod",True) - def __rdivmod__(self,other): - return apply_binary_op(self,other,lambda a,b:divmod(b,a)) - + def __divmod__(self, other): + return apply_binary_op(self, other, lambda a, b: divmod(a, b)) + + @addop("divmod", True) + def __rdivmod__(self, other): + return apply_binary_op(self, other, lambda a, b: divmod(b, a)) + @addop("pow") - def __pow__(self,other): - return apply_binary_op(self,other,lambda a,b:pow(a,b)) - - @addop("pow",True) - def __rpow__(self,other): - return apply_binary_op(self,other,lambda a,b:pow(b,a)) - - # skipping and, or, xor, which are bitwise and dont implement 'and' and 'or' but rather & and | - # similarly skipping lshift, rshift cause who wants them + def __pow__(self, other): + return apply_binary_op(self, other, lambda a, b: pow(a, b)) + + @addop("pow", True) + def __rpow__(self, other): + return apply_binary_op(self, other, lambda a, b: pow(b, a)) + + # skipping and, or, xor, which are bitwise and dont implement 'and' and + # 'or' but rather & and |. + # similarly skipping lshift, rshift cause who wants them. # wish i had not, and, or primitives, but can accept that dont. - # if people really want to do 'not' they can do '==False' instead, can do a little macro for it in the other sugar file or whatever + # if people really want to do 'not' they can do '==False' instead, can do a + # little macro for it in the other sugar file or whatever @addop("+") def __pos__(self): - return apply_unary_op(self,lambda a:+a) + return apply_unary_op(self, lambda a: +a) @addop("-") def __neg__(self): - return apply_unary_op(self,lambda a:-a) + return apply_unary_op(self, lambda a: -a) - @addop("abs") + @addop("abs") def __abs__(self): - return apply_unary_op(self,abs) + return apply_unary_op(self, abs) @addop("round") - def __round__(self): # not sure if python will get upset if round doesnt return an actual int tbh... will have to check. - return apply_unary_op(self,round) - - # defining floor, ceil, trunc showed up funny (green instead of blue), gonna go ahead and avoid + # not sure if python will get upset if round doesnt return an actual int + # tbh... will have to check. 
+ def __round__(self): + return apply_unary_op(self, round) + + # defining floor, ceil, trunc showed up funny (green instead of blue), + # gonna go ahead and avoid diff --git a/tests/make_tgts.py b/tests/make_tgts.py index c07f81c..c460797 100644 --- a/tests/make_tgts.py +++ b/tests/make_tgts.py @@ -15,38 +15,43 @@ curr_path_marker = "[current]" -REPL_path = "RASP_support/REPL.py" -rasplib_path = "RASP_support/rasplib.rasp" +REPL_PATH = "RASP_support/REPL.py" +RASPLIB_PATH = "RASP_support/rasplib.rasp" + def things_in_path(path): if not os.path.exists(path): return [] - return [p for p in os.listdir(path) if not p==".DS_Store"] + return [p for p in os.listdir(path) if not p == ".DS_Store"] + def joinpath(*a): return "/".join(a) -for p in [tgtpath,libtgtspath]: + +for p in [tgtpath, libtgtspath]: if not os.path.exists(p): os.makedirs(p) all_names = things_in_path(inpath) -def fix_file_paths(filename,curr_path_marker): - mypath = os.path.abspath(".") - with open(filename,"r") as f: +def fix_file_paths(filename, curr_path_marker): + mypath = os.path.abspath(".") + + with open(filename, "r") as f: filecontents = "".join(f) - filecontents = filecontents.replace(mypath,curr_path_marker) + filecontents = filecontents.replace(mypath, curr_path_marker) - with open(filename,"w") as f: - print(filecontents,file=f) + with open(filename, "w") as f: + print(filecontents, file=f) def run_input(name): - os.system("python3 "+REPL_path+" <"+inpath+"/"+name+" >"+tgtpath+"/"+name) - fix_file_paths(tgtpath+"/"+name,curr_path_marker) + os.system("python3 "+REPL_PATH+" <"+inpath+"/"+name+" >"+tgtpath+"/"+name) + fix_file_paths(tgtpath+"/"+name, curr_path_marker) + def run_inputs(): print("making the target outputs!") @@ -54,29 +59,34 @@ def run_inputs(): run_input(n) -def run_broken_lib(l): - os.system("cp "+joinpath(libspath,l)+" "+rasplib_path) - os.system("python3 "+REPL_path+" <"+joinpath(libtestspath,"empty.txt")+ " >"+joinpath(libtgtspath,l)) - +def run_broken_lib(lib): + os.system("cp "+joinpath(libspath, lib)+" "+RASPLIB_PATH) + os.system("python3 "+REPL_PATH+" <"+joinpath(libtestspath, + "empty.txt") + " >"+joinpath(libtgtspath, lib)) + real_rasplib_safe_place = "make_tgts_helper/temp" safe_rasplib_name = "safe_rasplib.rasp" - + + def save_rasplib(): if not os.path.exists(real_rasplib_safe_place): os.makedirs(real_rasplib_safe_place) - os.system("mv "+rasplib_path+" "+joinpath(real_rasplib_safe_place,safe_rasplib_name)) + os.system("mv "+RASPLIB_PATH+" " + + joinpath(real_rasplib_safe_place, safe_rasplib_name)) + def restore_rasplib(): - os.system("mv "+joinpath(real_rasplib_safe_place,safe_rasplib_name)+" "+rasplib_path) + os.system("mv "+joinpath(real_rasplib_safe_place, + safe_rasplib_name)+" "+RASPLIB_PATH) def run_broken_libs(): print("making the broken lib targets!") save_rasplib() all_libs = things_in_path(libspath) - for l in all_libs: - run_broken_lib(l) + for lib in all_libs: + run_broken_lib(lib) restore_rasplib() diff --git a/tests/test_all.py b/tests/test_all.py index e3ad4a8..8fe57e5 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,64 +1,68 @@ import os -from make_tgts import fix_file_paths, curr_path_marker, joinpath, things_in_path, \ -testpath, inpath, outpath, tgtpath, libtestspath, libspath, libtgtspath, liboutspath, \ -save_rasplib, restore_rasplib +from make_tgts import fix_file_paths, curr_path_marker, joinpath, \ + things_in_path, inpath, outpath, tgtpath, libtestspath, libspath, \ + libtgtspath, liboutspath, save_rasplib, restore_rasplib -def check_equal(f1,f2): +def 
check_equal(f1, f2): res = os.system("diff "+f1+" "+f2) - return res == 0 # 0 = diff found no differences + return res == 0 # 0 = diff found no differences -for p in [outpath,liboutspath]: +for p in [outpath, liboutspath]: if not os.path.exists(p): os.makedirs(p) + def run_input(name): - os.system("python3 RASP_support/REPL.py <"+joinpath(inpath,name)+" >"+joinpath(outpath,name)) - fix_file_paths(joinpath(outpath,name),curr_path_marker) - return check_equal(joinpath(outpath,name),joinpath(tgtpath,name)) + os.system("python3 RASP_support/REPL.py <" + + joinpath(inpath, name)+" >"+joinpath(outpath, name)) + fix_file_paths(joinpath(outpath, name), curr_path_marker) + return check_equal(joinpath(outpath, name), joinpath(tgtpath, name)) + def run_inputs(): all_names = things_in_path(inpath) passed = True for n in all_names: success = run_input(n) - print("input",n,"passed:",success) + print("input", n, "passed:", success) if not success: passed = False return passed -def test_broken_lib(l): - os.system("cp "+joinpath(libspath,l)+" RASP_support/rasplib.rasp") - os.system("python3 RASP_support/REPL.py <"+joinpath(libtestspath,"empty.txt")+ " >"+joinpath(liboutspath,l)) - return check_equal(joinpath(liboutspath,l),joinpath(libtgtspath,l)) + +def test_broken_lib(lib): + os.system("cp "+joinpath(libspath, lib)+" RASP_support/rasplib.rasp") + os.system("python3 RASP_support/REPL.py <"+joinpath(libtestspath, + "empty.txt") + " >"+joinpath(liboutspath, lib)) + return check_equal(joinpath(liboutspath, lib), joinpath(libtgtspath, lib)) + def run_broken_libs(): save_rasplib() all_libs = things_in_path(libspath) passed = True - for l in all_libs: - success = test_broken_lib(l) - print("lib",l,"passed (i.e., properly errored):",success) + for lib in all_libs: + success = test_broken_lib(lib) + print("lib", lib, "passed (i.e., properly errored):", success) if not success: passed = False restore_rasplib() return passed - - + if __name__ == "__main__": passed_inputs = run_inputs() - print("passed all inputs:",passed_inputs) + print("passed all inputs:", passed_inputs) print("=====\n\n=====") passed_broken_libs = run_broken_libs() - print("properly reports broken libs:",passed_broken_libs) + print("properly reports broken libs:", passed_broken_libs) print("=====\n\n=====") - - passed_everything = False not in [passed_inputs,passed_broken_libs] - print("=====\npassed everything:",passed_everything) + passed_everything = False not in [passed_inputs, passed_broken_libs] + print("=====\npassed everything:", passed_everything) if passed_everything: exit(0) else: - exit(1) \ No newline at end of file + exit(1)
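
Note on the add_ops helper reformatted above: it attaches named dunder methods (__eq__, __add__, __mod__, ...) to a class via setattr, given callbacks that lift unary and binary Python operators onto that class, and names each result so later displays stay readable. The sketch below is a minimal, self-contained illustration of that call pattern, not part of the patch; Box, lift_unary and lift_binary are hypothetical stand-ins for the real RASP types, and the import assumes RASP_support/ is on sys.path.

from make_operators import add_ops  # assumes RASP_support/ is on sys.path


class Box:
	# toy stand-in for UnfinishedSequence: just enough API for add_ops' wrappers
	def __init__(self, val, name="box"):
		self.val = val
		self.name = name

	def setname(self, name):  # add_ops names each result, e.g. "( box + 4 )"
		self.name = name
		return self

	def allow_suppressing_display(self):  # add_ops calls this on each result
		return self


def lift_unary(x, op):
	# plays the role of apply_unary_op: apply op to the wrapped value
	return Box(op(x.val))


def lift_binary(x, y, op):
	# plays the role of apply_binary_op: the right operand may be a plain value
	yval = y.val if isinstance(y, Box) else y
	return Box(op(x.val, yval))


add_ops(Box, lift_unary, lift_binary)

b = Box(3) + 4
print(b.val, b.name)  # 7 ( box + 4 )

The tests touched in the last two hunks can then be rerun as before (tests/test_all.py exits 0 only if all inputs and broken-lib checks pass), assuming the scripts are invoked from the repository root, as the relative path RASP_support/REPL.py in the os.system calls suggests.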