Skip to content

Commit

Permalink
Merge pull request #1322 from mito-ds/custom-import-type-expansion
Browse files Browse the repository at this point in the history
Custom import type expansion
  • Loading branch information
aarondr77 authored Aug 29, 2024
2 parents 045a7a8 + 0b03109 commit 065b348
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 43 deletions.
2 changes: 1 addition & 1 deletion mitosheet/mitosheet/mito_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def sheet(
try_create_user_json_file()

try:

# Every Mitosheet has a different comm target, so they each create
# a different channel to communicate over
comm_target_id = get_new_id()
Expand Down
26 changes: 13 additions & 13 deletions mitosheet/mitosheet/public/v3/sheet_functions/string_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ def LEFT(string: StringRestrictedInputType, num_chars: Optional[IntRestrictedInp

# otherwise, turn them into series for simplicity
index = get_index_from_series(string, num_chars)
string = get_series_from_primitive_or_series(string, index).fillna('')
num_chars = get_series_from_primitive_or_series(num_chars, index).fillna(0)
string_series = get_series_from_primitive_or_series(string, index).fillna('')
num_chars_series = get_series_from_primitive_or_series(num_chars, index).fillna(0)

return pd.Series(
[left_helper(s, nc) for s, nc in zip(string, num_chars)],
index=string.index
[left_helper(s, nc) for s, nc in zip(string_series, num_chars_series)],
index=string_series.index
)


Expand Down Expand Up @@ -342,12 +342,12 @@ def RIGHT(string: StringRestrictedInputType, num_chars: Optional[IntRestrictedIn
return right_helper(string, num_chars)

index = get_index_from_series(string, num_chars)
string = get_series_from_primitive_or_series(string, index).fillna('')
num_chars = get_series_from_primitive_or_series(num_chars, index).fillna(0)
string_series = get_series_from_primitive_or_series(string, index).fillna('')
num_chars_series = get_series_from_primitive_or_series(num_chars, index).fillna(0)

return pd.Series(
[right_helper(s, nc) for s, nc in zip(string, num_chars)],
index=string.index
[right_helper(s, nc) for s, nc in zip(string_series, num_chars_series)],
index=string_series.index
)


Expand Down Expand Up @@ -395,13 +395,13 @@ def SUBSTITUTE(string: StringRestrictedInputType, old_text: StringRestrictedInpu
return string.replace(old_text, new_text, count)

index = get_index_from_series(string, old_text, new_text, count)
string = get_series_from_primitive_or_series(string, index).fillna('')
old_text = get_series_from_primitive_or_series(old_text, index).fillna('')
new_text = get_series_from_primitive_or_series(new_text, index).fillna('')
count = get_series_from_primitive_or_series(count, index).fillna(0)
string_series = get_series_from_primitive_or_series(string, index).fillna('')
old_text_series = get_series_from_primitive_or_series(old_text, index).fillna('')
new_text_series = get_series_from_primitive_or_series(new_text, index).fillna('')
count_series = get_series_from_primitive_or_series(count, index).fillna(0)

return pd.Series(
[s.replace(ot, nt, c) for s, ot, nt, c in zip(string, old_text, new_text, count)],
[s.replace(ot, nt, c) for s, ot, nt, c in zip(string_series, old_text_series, new_text_series, count_series)],
index=index
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def get_user_defined_function_param_type(f: Callable, param_name: str) -> UserDe
return 'DataFrame'
elif param_type == ColumnHeader:
return 'ColumnHeader'
elif param_type == List[int]:
return 'List[int]'
elif param_type == Dict[str, str]:
return 'Dict[str, str]'
else:
return 'any'

Expand Down Expand Up @@ -264,6 +268,12 @@ def get_user_defined_function_param_type_and_execute_value_and_transpile_value(
elif param_type == 'bool':
execute_value = 'true' in param_value.lower()
user_defined_function_params[param_name] = (param_type, execute_value, get_column_header_as_transpiled_code(execute_value))
elif param_type == 'List[int]':
execute_value = [int(value) for value in param_value.split(',')]
user_defined_function_params[param_name] = (param_type, execute_value, get_column_header_as_transpiled_code(execute_value))
elif param_type == 'Dict[str, str]':
execute_value = {key: value for key, value in (item.split(':') for item in param_value.split(','))}
user_defined_function_params[param_name] = (param_type, execute_value, get_column_header_as_transpiled_code(execute_value))
else:
try:
# If we don't know the type (it's untyped), we try and convert the value to a python object -- using an eval. If that fails,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_integration_success():
'type': 'success',
'config_options': {
'roles': ['NO_PYTEST_TABLE_ACCESS', 'READONLY'],
'warehouses': ['COMPUTE_WH'],
'warehouses': ['COMPUTE_WH', 'SYSTEM$STREAMLIT_NOTEBOOK_WH'],
'databases': ['PYTESTDATABASE', 'SNOWFLAKE', 'SNOWFLAKE_SAMPLE_DATA'],
'schemas': ['INFORMATION_SCHEMA', 'PYTESTSCHEMA'],
'tables_and_views': ['COLUMNHEADER_TEST', 'NOROWS', 'SIMPLE_PYTEST_TABLE', 'TYPETEST', 'TYPETEST_SIMPLE', 'YOUR_TABLE_NAME', 'SIMPLE_PYTEST_TABLE_VIEW'],
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_switch_roles_updates_defaults():
'type': 'success',
'config_options': {
'roles': ['NO_PYTEST_TABLE_ACCESS', 'READONLY'],
'warehouses': ['COMPUTE_WH'],
'warehouses': ['COMPUTE_WH', 'SYSTEM$STREAMLIT_NOTEBOOK_WH'],
'databases': ['PYTESTDATABASE', 'SNOWFLAKE', 'SNOWFLAKE_SAMPLE_DATA'],
'schemas': ['INFORMATION_SCHEMA', 'PYTESTSCHEMA'],
'tables_and_views': ['COLUMNHEADER_TEST', 'NOROWS', 'SIMPLE_PYTEST_TABLE', 'TYPETEST', 'TYPETEST_SIMPLE', 'YOUR_TABLE_NAME', 'SIMPLE_PYTEST_TABLE_VIEW'],
Expand Down Expand Up @@ -114,19 +114,19 @@ def test_switch_roles_updates_defaults():
'type': 'success',
'config_options': {
'roles': ['NO_PYTEST_TABLE_ACCESS', 'READONLY'],
'warehouses': [],
'warehouses': ['SYSTEM$STREAMLIT_NOTEBOOK_WH'],
'databases': ['SNOWFLAKE', 'SNOWFLAKE_SAMPLE_DATA'],
'schemas': ['ALERT', 'CORE', 'CORTEX', 'IMAGES', 'INFORMATION_SCHEMA', 'ML', 'NOTIFICATION'],
'tables_and_views': [],
'columns': []
},
'default_values': {
'role': 'NO_PYTEST_TABLE_ACCESS',
'warehouse': None,
'warehouse': 'SYSTEM$STREAMLIT_NOTEBOOK_WH',
'database': 'SNOWFLAKE',
'schema': 'ALERT',
'table_or_view': None,
}
}

assert expected_return == response
2 changes: 1 addition & 1 deletion mitosheet/mitosheet/tests/operators/test_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
('=A / 0', np.inf),
('=A / B / 0', np.inf),
('=A / 0 * 10', np.inf),
('=-1 * A / 0', np.NINF),
('=-1 * A / 0', -np.inf),
('=A / B / 0 * 10', np.inf),
]
@pytest.mark.parametrize("formula,mulitple", DIV_TESTS_VALID)
Expand Down
22 changes: 5 additions & 17 deletions mitosheet/mitosheet/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,7 @@ def test_transpile_fully_parameterized_function_string_no_df_names(tmp_path, df_
"",
f"txt = pd.read_csv(r'{tmp_file1}')",
"",
f"sheet_df_dictonary = "
f"pd.read_excel(r'{tmp_file2}', "
"engine='openpyxl', sheet_name=[\n"
f"{TAB}'Sheet1'{NEWLINE}"
'], skiprows=0)',
f"sheet_df_dictonary = pd.read_excel(r'{tmp_file2}', engine='openpyxl', sheet_name=['Sheet1'], skiprows=0)",
"Sheet1 = sheet_df_dictonary['Sheet1']",
"",
f"{expected_in_transpile}.to_csv(r'{tmp_exportfile1}', "
Expand All @@ -602,9 +598,7 @@ def test_transpile_fully_parameterized_function_string_no_df_names(tmp_path, df_
def function({expected_in_function}, file_name_import_csv_0, file_name_import_excel_0, file_name_export_csv_0, file_name_export_excel_0):
txt = pd.read_csv(file_name_import_csv_0)
sheet_df_dictonary = pd.read_excel(file_name_import_excel_0, engine='openpyxl', sheet_name=[
'Sheet1'
], skiprows=0)
sheet_df_dictonary = pd.read_excel(file_name_import_excel_0, engine='openpyxl', sheet_name=['Sheet1'], skiprows=0)
Sheet1 = sheet_df_dictonary['Sheet1']
{expected_in_function}.to_csv(file_name_export_csv_0, index=False)
Expand Down Expand Up @@ -682,9 +676,7 @@ def test_transpile_fully_parameterized_function_string_no_df_name_param(tmp_path
"",
f"sheet_df_dictonary = "
f"pd.read_excel(r'{tmp_file2}', "
"engine='openpyxl', sheet_name=[\n"
f"{TAB}'Sheet1'{NEWLINE}"
'], skiprows=0)',
"engine='openpyxl', sheet_name=['Sheet1'], skiprows=0)",
"Sheet1 = sheet_df_dictonary['Sheet1']",
"",
f"txt.to_csv(r'{tmp_exportfile1}', "
Expand All @@ -703,9 +695,7 @@ def test_transpile_fully_parameterized_function_string_no_df_name_param(tmp_path
def function(file_name_import_csv_0, file_name_import_excel_0, file_name_export_csv_0, file_name_export_excel_0):
txt = pd.read_csv(file_name_import_csv_0)
sheet_df_dictonary = pd.read_excel(file_name_import_excel_0, engine='openpyxl', sheet_name=[
'Sheet1'
], skiprows=0)
sheet_df_dictonary = pd.read_excel(file_name_import_excel_0, engine='openpyxl', sheet_name=['Sheet1'], skiprows=0)
Sheet1 = sheet_df_dictonary['Sheet1']
txt.to_csv(file_name_export_csv_0, index=False)
Expand Down Expand Up @@ -818,9 +808,7 @@ def test_transpile_parameterize_excel_imports(tmp_path):
"import pandas as pd",
"",
"def function(var_name):",
f"{TAB}sheet_df_dictonary = pd.read_excel(var_name, engine='openpyxl', sheet_name=[\n"
f"{TAB*2}'Sheet1'\n"
f"{TAB}], skiprows=0)",
f"{TAB}sheet_df_dictonary = pd.read_excel(var_name, engine='openpyxl', sheet_name=['Sheet1'], skiprows=0)",
f"{TAB}Sheet1 = sheet_df_dictonary['Sheet1']",
f'{TAB}',
f"{TAB}dataframe_1 = pd.read_excel(var_name, sheet_name='Sheet1', skiprows=0, nrows=1, usecols='A:B')",
Expand Down
11 changes: 9 additions & 2 deletions mitosheet/mitosheet/transpiler/transpile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,15 @@ def get_column_header_as_transpiled_code(column_header: ColumnHeader, tab_level:
return f'({column_header_parts_joined})'
if isinstance(column_header, list):
column_header_parts = [get_column_header_as_transpiled_code(column_header_part, tab_level=tab_level+1) for column_header_part in column_header]
column_header_parts_joined = f',\n{TAB*(tab_level + 1)}'.join(column_header_parts)
return f'[\n{TAB*(tab_level + 1)}{column_header_parts_joined}\n{TAB*tab_level}]'

# Only add new lines in between entries if the full list would be too long
total_length_of_column_headers = sum([len(column_header_part) for column_header_part in column_header_parts])
if total_length_of_column_headers > 50:
column_header_parts_joined = f',\n{TAB*(tab_level + 1)}'.join(column_header_parts)
return f'[\n{TAB*(tab_level + 1)}{column_header_parts_joined}\n{TAB*tab_level}]'
else:
column_header_parts_joined = f','.join(column_header_parts)
return f'[{column_header_parts_joined}]'

# We must handle np.nan first because isinstance(np.nan, float) evaluates to True
if not is_prev_version(pd.__version__, '1.0.0') and column_header is np.nan:
Expand Down
2 changes: 1 addition & 1 deletion mitosheet/mitosheet/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class ColumnDefinintion(TypedDict):

DefaultEditingMode = Literal['cell', 'column']

UserDefinedFunctionParamType = Literal['any', 'str', 'int', 'float', 'bool', 'DataFrame', 'ColumnHeader']
UserDefinedFunctionParamType = Literal['any', 'str', 'int', 'float', 'bool', 'DataFrame', 'ColumnHeader', 'List[int]', 'Dict[str, str]']

class MitoTheme(TypedDict):
primaryColor: str
Expand Down
2 changes: 1 addition & 1 deletion mitosheet/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_data_files_from_data_files_spec(
# According to this documentation (https://github.com/snowflakedb/snowflake-connector-python),
# snowflake-connect-python requires at least Python 3.7
'snowflake-connector-python[pandas]; python_version>="3.7"',
'streamlit>=1.24',
'streamlit>=1.24,<1.32',
'dash>=2.9',
"flask"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Toggle from "../../elements/Toggle";
import Tooltip from "../../elements/Tooltip";
import Col from '../../layout/Col';
import Row from '../../layout/Row';
import TextButton from "../../elements/TextButton";


/**
Expand Down Expand Up @@ -119,6 +120,7 @@ const UserDefinedFunctionParamConfigSection = (props: {
newParams[paramName] = newValue;
props.setParams(newParams);
}}
placeholder={paramType === 'str' ? 'Add string' : paramType === 'int' ? 'Add an int' : paramType === 'float' ? 'Add a float' : undefined}
/>
)
} else if (paramType === 'bool') {
Expand All @@ -133,6 +135,98 @@ const UserDefinedFunctionParamConfigSection = (props: {
}}
/>
)
} else if (paramType === 'List[int]') {
const paramValues = paramValue.split(',').map(value => value.trim());
if (paramValues.length === 0) {
paramValues.push('')
}
inputElement = (
<React.Fragment>
{paramValues.map((value, index) => {
return (
<Input
key={index}
value={value}
onChange={(e) => {
const newValues = [...paramValues];
newValues[index] = e.target.value;
const newParams = window.structuredClone(params);
newParams[paramName] = newValues.join(',');
props.setParams(newParams);
}}
style={{marginBottom: '8px'}}
placeholder="Add int to list"
/>
)
})}
<TextButton
onClick={() => {
const newValues = [...paramValues, ''];
const newParams = window.structuredClone(params);
newParams[paramName] = newValues.join(',');
props.setParams(newParams);
}}
variant="dark"
>
Add new int
</TextButton>
</React.Fragment>
)
} else if (paramType === 'Dict[str, str]') {
const paramValues = paramValue.split(',');
if (paramValues.length === 0) {
paramValues.push(":")
}
inputElement = (
<React.Fragment>
{paramValues.map((keyAndValue, index) => {
const [dictKey, dictValue] = keyAndValue.split(':');
return (
<Row key={index}>
<Input
key={'key' + index}
value={dictKey}
onChange={(e) => {
const newValues = [...paramValues];
const newDictKey = e.target.value;
newValues[index] = `${newDictKey || ''}:${dictValue || ''}`;
const newParams = window.structuredClone(params);
newParams[paramName] = newValues.join(',');
props.setParams(newParams);
}}
style={{marginBottom: '8px', marginRight: '2px'}}
placeholder="Add key"
/>
<Input
key={'value' + index}
value={dictValue}
onChange={(e) => {
const newValues = [...paramValues];
const newDictValue = e.target.value;
newValues[index] = `${dictKey || ''}:${newDictValue || ''}`;
const newParams = window.structuredClone(params);
newParams[paramName] = newValues.join(',');
props.setParams(newParams);
}}
style={{marginBottom: '8px'}}
placeholder="Add value"
/>
</Row>
)
})}
<TextButton
onClick={() => {
const newValues = [...paramValues, ":"];
const newParams = window.structuredClone(params);
newParams[paramName] = newValues.join(',');
props.setParams(newParams);
}}
variant="dark"
>
Add entry
</TextButton>
</React.Fragment>
)
}

const paramTypeDisplay = getParamTypeDisplay(paramType) !== undefined
Expand All @@ -142,7 +236,7 @@ const UserDefinedFunctionParamConfigSection = (props: {
const tooltip = `${paramName}${paramTypeDisplay}`;

paramRowElements.push(
<Row key={paramName} justify='space-between' align='center' title={tooltip}>
<Row key={paramName} justify='space-between' title={tooltip}>
<Col span={14}>
<Row justify="start" align="center" suppressTopBottomMargin>
<Col>
Expand Down
2 changes: 1 addition & 1 deletion mitosheet/src/mito/types.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ export enum MitoEnterpriseConfigKey {
export type PublicInterfaceVersion = 1 | 2 | 3;

type UserDefinedFunctionParamName = string;
export type UserDefinedFunctionParamType = 'any' | 'str' | 'int' | 'float' | 'bool' | 'DataFrame' | 'ColumnHeader';
export type UserDefinedFunctionParamType = 'any' | 'str' | 'int' | 'float' | 'bool' | 'List[int]' | 'Dict[str, str]' | 'DataFrame' | 'ColumnHeader';
export type UserDefinedFunctionParamNameToType = Record<UserDefinedFunctionParamName, UserDefinedFunctionParamType>

export type UserDefinedFunction = {
Expand Down

0 comments on commit 065b348

Please sign in to comment.