Skip to content

Commit

Permalink
write table and column config to db
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcmicu committed Nov 26, 2023
1 parent 0ea815f commit 4a4ea32
Showing 1 changed file with 105 additions and 16 deletions.
121 changes: 105 additions & 16 deletions scripts/guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from argparse import ArgumentParser
from lark import Lark
from numbers import Number
from pathlib import Path
from pprint import pformat
from textwrap import dedent


SPECIAL_TABLES = ["table", "column", "datatype", "rule", "history", "message"]
Expand Down Expand Up @@ -153,14 +155,14 @@ def get_higher_datatypes(datatype_hierarchies, universals, depth):

def get_sql_type(config, datatype):
"""Given the config map and the name of a datatype, climb the datatype tree (as required),
and return the first 'SQL type' found."""
and return the first 'SQLite type' found."""
if "datatype" not in config:
print("Missing datatypes in config")
sys.exit(1)
if datatype not in config["datatype"]:
return None
if config["datatype"][datatype].get("SQL type"):
return config["datatype"][datatype]["SQL type"]
if config["datatype"][datatype].get("SQLite type"):
return config["datatype"][datatype]["SQLite type"]
return get_sql_type(config, config["datatype"][datatype].get("parent"))


Expand Down Expand Up @@ -258,6 +260,10 @@ def is_match(datatype):
# If the datatype has no associated condition then it matches anything:
if not datatype.get("condition"):
return True
# If the SQLite type is NULL this datatype is ruled out:
sqlite_type = datatype.get("SQLite type")
if sqlite_type and sqlite_type.casefold() == "null":
return False

condition = get_compiled_condition(datatype["condition"], config["parser"])
num_values = len(target["values"])
Expand Down Expand Up @@ -372,7 +378,8 @@ def get_from(target, potential_foreign_columns):
"--error_rate",
type=float,
default=0.1,
help="Proportion of errors expected (default: 10%%)",
help="""A number between 0 and 1 (inclusive) representing the proportion of errors expected
(default: 0.1)""",
)
parser.add_argument(
"--enum_size",
Expand All @@ -383,6 +390,11 @@ def get_from(target, potential_foreign_columns):
parser.add_argument(
"--seed", type=int, help="Seed to use for random sampling (default: current epoch time)"
)
parser.add_argument(
"--yes",
action="store_true",
help="Do not ask for confirmation before writing suggested modifications to the database",
)
parser.add_argument(
"VALVE_TABLE", help="The VALVE table table from which to read the VALVE configuration"
)
Expand All @@ -407,29 +419,106 @@ def get_from(target, potential_foreign_columns):

# Get the valve configuration and database info:
config = get_valve_config(args.VALVE_TABLE)
if args.TABLE.removesuffix(".tsv") in config["table"]:
print(f"{args.TABLE.removesuffix('.tsv')} is already configured.", file=sys.stderr)
table_tsv = args.TABLE
table = Path(args.TABLE).stem
if table in config["table"]:
print(f"{table} is already configured.", file=sys.stderr)
sys.exit(0)
with sqlite3.connect(args.DATABASE) as conn:
config["db"] = conn

# Attach the condition parser to the config as well:
config["parser"] = Lark(grammar, parser="lalr", transformer=TreeToDict())

log(f"Getting random sample of {args.sample_size} rows from {args.TABLE} ...")
sample = get_random_sample(args.TABLE, args.sample_size)
log(f"Getting random sample of {args.sample_size} rows from {table_tsv} ...")
sample = get_random_sample(table_tsv, args.sample_size)
for i, label in enumerate(sample):
log(f"Annotating label '{label}' ...")
annotate(label, sample, config, args.error_rate, i == 0)
log("Done!")

# For debugging:
# pprint(sample)
table_table_headers = ["table", "path", "type", "description"]
column_table_headers = [
"table",
"column",
"label",
"nulltype",
"datatype",
"structure",
"description",
]
if not args.yes:
print()

print('The following row will be inserted to "table":')
data = [table_table_headers, [f"{table}", f"{table_tsv}", "", ""]]
# We add +2 for padding
col_width = max(len(word) for row in data for word in row) + 2
for row in data:
print("".join(word.ljust(col_width) for word in row))

print()

print('The following row will be inserted to "column":')
data = [column_table_headers]
for label in sample:
row = [
f"{table}",
f"{sample[label]['normalized']}",
f"{label}",
f"{sample[label].get('nulltype', '')}",
f"{sample[label]['datatype']}",
f"{sample[label].get('structure', '')}",
f"{sample[label].get('description', '')}",
]
data.append(row)
# We add +2 for padding
col_width = max(len(word) for row in data for word in row) + 2
for row in data:
print("".join(word.ljust(col_width) for word in row))

# For debugging without values:
for label in sample:
print(f"{label}: ", end="")
for annotation, data in sample[label].items():
if annotation != "values":
print(f"{annotation}: {data}, ", end="")
print()

answer = input("Do you want to write this updated configuration to the database? (y/n) ")
if answer.casefold() != "y":
print("Not writing updated configuration to the database.")
sys.exit(0)

log("Updating table configuration in database ...")
row_number = conn.execute('SELECT MAX(row_number) FROM "table"').fetchall()[0][0] + 1
sql = dedent(
f"""
INSERT INTO "table" ("row_number", {', '.join([f'"{k}"' for k in table_table_headers])})
VALUES ({row_number}, '{table}', '{table_tsv}', NULL, NULL)"""
)
log(sql, suppress_time=True)
log("", suppress_time=True)
conn.execute(sql)
conn.commit()

log("Updating column configuration in database ...")
row_number = conn.execute('SELECT MAX(row_number) FROM "column"').fetchall()[0][0] + 1
for label in sample:
values = ", ".join(
[
f"{row_number}",
f"'{table}'",
f"'{sample[label]['normalized']}'",
f"'{label}'",
f"'{sample[label]['nulltype']}'" if sample[label].get("nulltype") else "NULL",
f"'{sample[label]['datatype']}'",
f"'{sample[label]['structure']}'" if sample[label].get("structure") else "NULL",
f"'{sample[label]['description']}'" if sample[label].get("description") else "NULL",
]
)
sql = dedent(
f"""
INSERT INTO "column" ("row_number", {', '.join([f'"{k}"' for k in column_table_headers])})
VALUES ({values})"""
)
log(sql, suppress_time=True)
conn.execute(sql)
conn.commit()
row_number += 1
log("", suppress_time=True)
log("Done!")

0 comments on commit 4a4ea32

Please sign in to comment.