Skip to content

Commit

Permalink
Make TidyImports + notebook integration idemponent
Browse files Browse the repository at this point in the history
Order of imports was not correct on the first click of the TidyImports
button. This PR fixes that issue.
  • Loading branch information
Divyansh Choudhary committed Oct 3, 2023
1 parent 288b28e commit 6d610b2
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions lib/python/pyflyby/_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,23 @@ def _reformat_helper(input_code, imports):

def extract_import_statements(text):
"""This is a util for notebook interactions and extracts import statements
from some python code.
from some python code. This function also re-orders imports.
Args:
code (str): The code from which import statements have to be extracted
Returns:
(list[str], str): The first returned value is a list of all import
statements. The second returned value is the remaining code after
(str, str): The first returned value contains all the import statements.
The second returned value is the remaining code after
extracting the import statements.
"""
transformer = SourceToSourceFileImportsTransformation(text)
imports = [str(im.pretty_print()) for im in transformer.import_blocks]
imports = '\n'.join([str(im.pretty_print()) for im in transformer.import_blocks])
remaining_code = "\n".join([str(st.pretty_print()) if not isinstance(st, SourceToSourceImportBlockTransformation) else "" for st in transformer.blocks])
return imports, remaining_code

def collect_code_with_imports_on_top(imports, cell_array):
return (
"\n".join(imports)
def collect_code_with_imports_on_top(imports: str, cell_array):
return (
imports
+ "\n"
+ "\n".join(
[
Expand Down Expand Up @@ -185,7 +184,7 @@ def _recv(msg):
elif data["type"] == TIDY_IMPORTS:
checksum = data.get("checksum", '')
cell_array = data.get("cellArray", [])
import_statements, processed_cell_array = [], []
import_statements, processed_cell_array = "", []
for cell in cell_array:
text = cell.get("text")
cell_type = cell.get("type")
Expand All @@ -201,6 +200,6 @@ def _recv(msg):
"checksum": checksum,
"type": TIDY_IMPORTS,
"cells": processed_cell_array,
"imports": [im.strip() for im in import_statements],
"imports": import_statements,
}
)

0 comments on commit 6d610b2

Please sign in to comment.