diff --git a/lib/python/pyflyby/_comms.py b/lib/python/pyflyby/_comms.py index 34d64cbe..646fe2a2 100644 --- a/lib/python/pyflyby/_comms.py +++ b/lib/python/pyflyby/_comms.py @@ -121,24 +121,23 @@ def _reformat_helper(input_code, imports): def extract_import_statements(text): """This is a util for notebook interactions and extracts import statements - from some python code. - + from some python code. This function also re-orders imports. Args: code (str): The code from which import statements have to be extracted Returns: - (list[str], str): The first returned value is a list of all import - statements. The second returned value is the remaining code after + (str, str): The first returned value contains all the import statements. + The second returned value is the remaining code after extracting the import statements. """ transformer = SourceToSourceFileImportsTransformation(text) - imports = [str(im.pretty_print()) for im in transformer.import_blocks] + imports = '\n'.join([str(im.pretty_print()) for im in transformer.import_blocks]) remaining_code = "\n".join([str(st.pretty_print()) if not isinstance(st, SourceToSourceImportBlockTransformation) else "" for st in transformer.blocks]) return imports, remaining_code -def collect_code_with_imports_on_top(imports, cell_array): - return ( - "\n".join(imports) +def collect_code_with_imports_on_top(imports: str, cell_array): + return ( + imports + "\n" + "\n".join( [ @@ -185,7 +184,13 @@ def _recv(msg): elif data["type"] == TIDY_IMPORTS: checksum = data.get("checksum", '') cell_array = data.get("cellArray", []) - import_statements, processed_cell_array = [], [] + # import_statements is a string because when + # SourceToSourceFileImportsTransformation is run on a piece of code + # it will club similar imports together and re-order the imports + # by making the imports a string, all the imports are processed + # together making sure tidy-imports has context on all the imports + # while clubbing similar imports and re-ordering them. + import_statements, processed_cell_array = "", [] for cell in cell_array: text = cell.get("text") cell_type = cell.get("type") @@ -201,6 +206,6 @@ def _recv(msg): "checksum": checksum, "type": TIDY_IMPORTS, "cells": processed_cell_array, - "imports": [im.strip() for im in import_statements], + "imports": import_statements, } )