redesign algorithm for get_datatype()
lmcmicu committed Nov 20, 2023
1 parent 2e12dda commit 318da29
Showing 1 changed file with 82 additions and 78 deletions.
scripts/guess.py: 160 changes (82 additions, 78 deletions)
@@ -9,11 +9,13 @@
import sys
import time

from copy import deepcopy
from guess_grammar import grammar, TreeToDict

from argparse import ArgumentParser
from lark import Lark
from numbers import Number
from pprint import pformat


SPECIAL_TABLES = ["table", "column", "datatype", "rule", "history", "message"]
@@ -80,33 +82,62 @@ def get_parents(dt_name):
return [config["datatype"][primary_dt_name]] + get_parents(primary_dt_name)


def get_datatype_hierarchy(config):
def get_dt_hierarchies(config):
"""
Given a VALVE configuration, return a datatype hierarchy that looks like this:
{'dt_name_1': [{'datatype': 'dt_name_1',
'description': 'a description',
...},
{'datatype': 'parent datatype',
'description': 'a description',
...},
{'datatype': 'grandparent datatype',
'description': 'a description',
...},
...],
'dt_name_2': etc.
{0: {'dt_name_1': [{'datatype': 'dt_name_1',
'description': 'a description',
...},
{'datatype': 'parent datatype',
'description': 'a description',
...},
{'datatype': 'grandparent datatype',
'description': 'a description',
...},
...],
'dt_name_2': etc.},
1: ... etc.}
"""

def get_higher_datatypes(datatype_hierarchies, universals, depth):
current_datatypes = [dt_name for dt_name in datatype_hierarchies.get(depth, [])]
higher_datatypes = {}
if current_datatypes:
universals = [dt_name for dt_name in universals]
lower_datatypes = []
for i in range(0, depth):
lower_datatypes += [dt_name for dt_name in datatype_hierarchies.get(i, [])]
for dt_name in dt_hierarchies[depth]:
dt_hierarchy = dt_hierarchies[depth][dt_name]
if len(dt_hierarchy) > 1:
parent_hierarchy = dt_hierarchy[1:]
parent = parent_hierarchy[0]["datatype"]
if parent not in current_datatypes + lower_datatypes + universals:
higher_datatypes[parent] = parent_hierarchy
return higher_datatypes

dt_config = config["datatype"]
dt_names = [dt_name for dt_name in dt_config]
leaf_dts = []
for dt in dt_names:
children = [child for child in dt_names if dt_config[child].get("parent") == dt]
dt_hierarchies = {0: {}}
universals = {}
for dt_name in dt_names:
# Add all the leaf datatypes to dt_hierarchies at 0 depth:
children = [child for child in dt_names if dt_config[child].get("parent") == dt_name]
if not children:
leaf_dts.append(dt)

dt_hierarchy = {}
for leaf_dt in leaf_dts:
dt_hierarchy[leaf_dt] = get_hierarchy_for_dt(config, leaf_dt)
return dt_hierarchy
dt_hierarchies[0][dt_name] = get_hierarchy_for_dt(config, dt_name)
# Ungrounded and unconditioned datatypes go into the universals category, which are added
# to the top of dt_hierarchies later:
elif not dt_config[dt_name].get("parent") or not dt_config[dt_name].get("condition"):
universals[dt_name] = get_hierarchy_for_dt(config, dt_name)

depth = 0
higher_dts = get_higher_datatypes(dt_hierarchies, universals, depth)
while higher_dts:
depth += 1
dt_hierarchies[depth] = deepcopy(higher_dts)
higher_dts = get_higher_datatypes(dt_hierarchies, universals, depth)
dt_hierarchies[depth + 1] = universals
return dt_hierarchies


def get_sql_type(config, datatype):
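
For orientation, here is a hypothetical sketch (invented datatype names and conditions, not taken from this commit) of the levelled structure that the new get_dt_hierarchies() builds: leaf datatypes at depth 0, their parents at successive depths, and ungrounded or unconditioned "universal" datatypes appended as the final level.

# Hypothetical minimal VALVE-style datatype config in which
# text -> trimmed_line -> word and only "word" is a leaf:
config = {
    "datatype": {
        "text": {"datatype": "text"},  # no parent and no condition: a "universal"
        "trimmed_line": {
            "datatype": "trimmed_line",
            "parent": "text",
            "condition": "match(/\\S([^\\n]*\\S)*/)",
        },
        "word": {
            "datatype": "word",
            "parent": "trimmed_line",
            "condition": "exclude(/\\W/)",
        },
    }
}

# Expected shape of get_dt_hierarchies(config), as read from the diff above:
# {0: {'word': [<word row>, <trimmed_line row>, <text row>]},   # leaf datatypes
#  1: {'trimmed_line': [<trimmed_line row>, <text row>]},       # their parents
#  2: {'text': [<text row>]}}                                   # universals on top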
@@ -211,17 +242,7 @@ def has_duplicates(target, ignore_empties):
distinct_values = set(values)
return (len(values) - len(distinct_values)) > (error_rate * len(values))

def get_datatype(target, dt_hierarchy):
# For each tree in the hierarchy:
# Look for a match with the 0th element and possibly add it to matching_datatypes.
# If there are matches in matching_datatypes:
# Use the tiebreaker rules to find the best match and annotate the target with it.
# Else:
# Try again with the next highest element of each tree (if one exists)
#
# Note that this is guaranteed to work since the get_datatype_hierarchy() function includes
# the 'text' datatype which matches anything. So if no matches are found raise an error.

def get_datatype(target, dt_hierarchies):
def is_match(datatype):
# If the datatype has no associated condition then it matches anything:
if not datatype.get("condition"):
@@ -235,61 +256,44 @@ def is_match(datatype):
return success_rate

def tiebreak(datatypes):
# TODO: There is a problem with this algorithm, since it implicitly assumes that if two
# datatypes are of the same depth, then neither can be a parent of the other. But this
# is false. We could have, for example,
# leaf_1 -> non_space -> trimmed_line
# leaf_2 -> word -> non_space -> trimmed_line
# Even though non_space is a parent of word, the algorithm classifies both as depth 1.
# We need to have another check in this function to determine whether there are any
# parent-child dependencies between the datatypes in the tiebreaker list.
in_types = []
other_types = []
parents = set([dt["datatype"].get("parent") for dt in datatypes])
parents.discard(None)
for dt in datatypes:
if dt["datatype"].get("condition", "").lstrip().startswith("in("):
in_types.append(dt)
else:
other_types.append(dt)
sorted_types = sorted(
in_types, key=lambda k: (k["depth"], k["success_rate"]), reverse=True
) + sorted(other_types, key=lambda k: (k["depth"], k["success_rate"]), reverse=True)
return sorted_types[0]["datatype"]

curr_index = 0
while True:
matching_datatypes = []
datatypes_to_check = []
for dt_name in dt_hierarchy:
if (
len(dt_hierarchy[dt_name]) > curr_index
and dt_hierarchy[dt_name][curr_index] not in datatypes_to_check
):
datatypes_to_check.append(dt_hierarchy[dt_name][curr_index])
if len(datatypes_to_check) == 0:
print(f"Could not find a datatype match for column '{label}'")
if dt["datatype"]["datatype"] not in parents:
if dt["datatype"].get("condition", "").lstrip().startswith("in("):
in_types.append(dt)
else:
other_types.append(dt)

if len(in_types) == 1:
return in_types[0]["datatype"]
elif len(in_types) > 1:
in_types = sorted(in_types, key=lambda k: k["success_rate"], reverse=True)
return in_types[0]["datatype"]
elif len(other_types) == 1:
return other_types[0]["datatype"]
elif len(other_types) > 1:
other_types = sorted(other_types, key=lambda k: k["success_rate"], reverse=True)
return other_types[0]["datatype"]
else:
print(f"Error tiebreaking datatypes: {pformat(datatypes)}")
sys.exit(1)

for depth in range(0, len(dt_hierarchies)):
datatypes_to_check = [dt_hierarchies[depth][dt][0] for dt in dt_hierarchies[depth]]
matching_datatypes = []
for datatype in datatypes_to_check:
success_rate = is_match(datatype)
if success_rate:
matching_datatypes.append(
{
"datatype": datatype,
"depth": curr_index,
"success_rate": success_rate,
}
)

if len(matching_datatypes) == 0:
curr_index += 1
continue
elif len(matching_datatypes) == 1:
matching_datatypes.append({"datatype": datatype, "success_rate": success_rate})

if len(matching_datatypes) == 1:
return matching_datatypes[0]["datatype"]
else:
elif len(matching_datatypes) > 1:
return tiebreak(matching_datatypes)

curr_index += 1

def get_from(target, potential_foreign_columns):
candidate_froms = []
for foreign in potential_foreign_columns:
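
To make the new tiebreaking rule concrete, here is a small self-contained sketch of it as read from the diff (the candidate datatypes and success rates are hypothetical, not from the commit): candidates that are parents of other candidates are dropped, in(...) conditions beat other conditions, and remaining ties are broken by success rate.

def tiebreak_sketch(candidates):
    # Candidates that are parents of other candidates never win:
    parents = {c["datatype"].get("parent") for c in candidates}
    parents.discard(None)
    in_types, other_types = [], []
    for c in candidates:
        if c["datatype"]["datatype"] in parents:
            continue
        cond = c["datatype"].get("condition", "").lstrip()
        (in_types if cond.startswith("in(") else other_types).append(c)
    # in(...) conditions take precedence; ties are broken by success rate:
    pool = in_types or other_types
    return max(pool, key=lambda c: c["success_rate"])["datatype"]

candidates = [
    {"datatype": {"datatype": "word", "parent": "non_space",
                  "condition": "exclude(/\\W/)"}, "success_rate": 0.97},
    {"datatype": {"datatype": "prefix", "parent": "word",
                  "condition": "in('a', 'b', 'c')"}, "success_rate": 0.91},
]
# "word" is the parent of "prefix", so only "prefix" survives the first filter,
# and its in(...) condition makes it the winner despite a lower success rate:
assert tiebreak_sketch(candidates)["datatype"] == "prefix"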
@@ -322,8 +326,8 @@ def get_from(target, potential_foreign_columns):
target["nulltype"] = "empty"

# Use the valve config to retrieve the valve datatype hierarchy:
dt_hierarchy = get_datatype_hierarchy(config)
target["datatype"] = get_datatype(target, dt_hierarchy)["datatype"]
dt_hierarchies = get_dt_hierarchies(config)
target["datatype"] = get_datatype(target, dt_hierarchies)["datatype"]

# Use the valve config to get a list of columns already loaded to the database, then compare
# the contents of each column with the contents of the target column and possibly annotate the
