-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_test_data.py
88 lines (73 loc) · 2.86 KB
/
get_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
from abc import ABC, abstractmethod
from typing import List
from models import TestCase, TestQuery
TESTS_DIRECTORY = 'test_queries'


def get_queries_and_cases():
    """Discover and load every test query stored under TESTS_DIRECTORY.

    The directory is walked with ``os.walk`` using a relative path, so it is
    resolved against the current working directory. The discovered file paths
    are parsed by ``parse_paths_from_os_walk`` and their contents loaded from
    disk via ``RealFileContentProvider``.

    NOTE(review): ``os.walk`` ignores errors by default, so if
    TESTS_DIRECTORY does not exist this returns an empty list rather than
    failing.

    Returns:
        The list built by ``convert_parsed_paths_to_queries_and_cases``.
    """
    walk_output = os.walk(TESTS_DIRECTORY)
    parsed_paths = parse_paths_from_os_walk(walk_output)
    disk_reader = RealFileContentProvider()
    return convert_parsed_paths_to_queries_and_cases(
        parsed_paths,
        file_content_provider=disk_reader,
    )
class FileContentProvider(ABC):
    """Abstract source of file contents.

    ``convert_parsed_paths_to_queries_and_cases`` reads every file through
    this interface instead of calling ``open()`` directly, so an alternative
    implementation (e.g. an in-memory one) can be substituted.
    """

    @abstractmethod
    def get_file_content(self, path: str) -> str:
        """Return the full text stored at ``path``."""


class RealFileContentProvider(FileContentProvider):
    """FileContentProvider that reads files from the local filesystem."""

    def __init__(self, encoding: str = 'utf-8'):
        """Create a provider.

        Args:
            encoding: Text encoding used to decode every file. An explicit
                default (UTF-8) makes the result independent of the
                platform's locale, which is what ``open()`` falls back to
                when no encoding is given.
        """
        self.encoding = encoding

    def get_file_content(self, path: str) -> str:
        """Return the decoded contents of the file at ``path``.

        Raises:
            OSError: if the file cannot be opened (e.g. FileNotFoundError).
            UnicodeDecodeError: if the file is not valid ``self.encoding``.
        """
        with open(path, 'r', encoding=self.encoding) as f:
            return f.read()
def convert_parsed_paths_to_queries_and_cases(
        paths: dict,
        file_content_provider: FileContentProvider) -> List[TestQuery]:
    """Load the SQL files referenced by ``paths`` into TestQuery objects.

    Args:
        paths: Mapping of the shape produced by ``parse_paths_from_os_walk``::

                {query_name: {'sql_file': path,
                              'schema_set_up_file': path,
                              'cases': {case_name: {'data_set_up_file': path,
                                                    'target_set_up_file': path}}}}

        file_content_provider: Object whose ``get_file_content(path)`` is
            used to read every referenced file.

    Returns:
        One TestQuery per entry of ``paths`` (in iteration order), each holding
        one TestCase per entry of its ``'cases'`` mapping.

    Raises:
        KeyError: if a query or case lacks one of the file entries above (for
            example because ``parse_paths_from_os_walk`` found no
            ``query.sql`` in that directory). The message names the
            query/case and the missing key.
    """
    def read(entry: dict, key: str, where: str) -> str:
        # Look the path up explicitly so an incomplete test directory reports
        # *which* query/case is broken, not just "KeyError: 'sql_file'".
        try:
            path = entry[key]
        except KeyError:
            raise KeyError('{} is missing {!r}'.format(where, key)) from None
        return file_content_provider.get_file_content(path)

    queries = []
    for query_name, query_data in paths.items():
        query_label = 'query {!r}'.format(query_name)
        # Read order (sql, schema, then each case's data/target) matches the
        # order in which the files were previously read.
        sql = read(query_data, 'sql_file', query_label)
        schema_set_up_command = read(
            query_data, 'schema_set_up_file', query_label)

        cases = []
        for case_name, case_data in query_data['cases'].items():
            case_label = 'case {!r} of query {!r}'.format(case_name, query_name)
            cases.append(TestCase(
                name=case_name,
                data_set_up_command=read(
                    case_data, 'data_set_up_file', case_label),
                target_set_up_command=read(
                    case_data, 'target_set_up_file', case_label)))

        queries.append(TestQuery(
            name=query_name,
            sql=sql,
            schema_set_up_command=schema_set_up_command,
            cases=cases))
    return queries
def parse_paths_from_os_walk(os_walk_output):
    """Group the SQL files of a test tree by query and case.

    Expects ``os.walk``-style ``(root, dirs, files)`` triples for a tree laid
    out as ``<tests_dir>/<query>/<case>``:

    * ``<tests_dir>`` itself is skipped;
    * ``<tests_dir>/<query>`` may hold ``query.sql`` and ``setUpSchema.sql``;
    * ``<tests_dir>/<query>/<case>`` may hold ``setUpData.sql`` and
      ``setUpTarget.sql``.

    Other file names are ignored. Walk order does not matter, so the output of
    ``os.walk(..., topdown=False)`` is accepted as well.

    Args:
        os_walk_output: Iterable of ``(root, dirs, files)`` triples. Depth is
            counted on the normalised ``root``, so ``<tests_dir>`` must be a
            single relative path component (``./`` prefixes are fine).

    Returns:
        ``{query: {'sql_file': ..., 'schema_set_up_file': ...,
        'cases': {case: {'data_set_up_file': ..., 'target_set_up_file': ...}}}}``
        where each value is ``os.path.join(root, file_name)``. A key is only
        present if the corresponding file was listed.

    Raises:
        AssertionError: for a directory nested deeper than
            ``<tests_dir>/<query>/<case>``.
    """
    query_files = {
        'query.sql': 'sql_file',
        'setUpSchema.sql': 'schema_set_up_file',
    }
    case_files = {
        'setUpData.sql': 'data_set_up_file',
        'setUpTarget.sql': 'target_set_up_file',
    }

    result = {}
    for root, _, files in os_walk_output:
        # Normalise first so "./test_queries/q" or a trailing separator does
        # not shift the depth by one; joined paths still use the original root.
        parts = os.path.normpath(root).split(os.path.sep)
        depth = len(parts)
        if depth == 1:
            continue
        if depth > 3:
            raise AssertionError(
                'Unexpected directory {!r}: test trees must be laid out as '
                '<tests_dir>/<query>/<case>'.format(root))

        # setdefault (instead of assignment) keeps cases that were seen before
        # their query directory, e.g. with a bottom-up walk.
        query_entry = result.setdefault(parts[1], {'cases': {}})
        if depth == 2:
            target, wanted = query_entry, query_files
        else:
            target = query_entry['cases'].setdefault(parts[2], {})
            wanted = case_files

        for file_name in files:
            key = wanted.get(file_name)
            if key is not None:
                target[key] = os.path.join(root, file_name)
    return result