diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce2f804..284a89c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,13 @@ repos: .*\.html )$ + # Format docstrings + - repo: https://github.com/DanielNoord/pydocstringformatter + rev: v0.7.3 + hooks: + - id: pydocstringformatter + args: ["--style=numpydoc"] + # Ruff, the Python auto-correcting linter/formatter written in Rust - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.4.7 diff --git a/acro/acro.py b/acro/acro.py index 8750a11..8bc9e3f 100644 --- a/acro/acro.py +++ b/acro/acro.py @@ -47,7 +47,7 @@ class ACRO(Tables, Regression): """ def __init__(self, config: str = "default", suppress: bool = False) -> None: - """Constructs a new ACRO object and reads parameters from config. + """Construct a new ACRO object and reads parameters from config. Parameters ---------- @@ -79,7 +79,7 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None: acro_tables.SURVIVAL_THRESHOLD = self.config["survival_safe_threshold"] def finalise(self, path: str = "outputs", ext="json") -> Records | None: - """Creates a results file for checking. + """Create a results file for checking. Parameters ---------- @@ -114,7 +114,7 @@ def finalise(self, path: str = "outputs", ext="json") -> Records | None: return self.results def remove_output(self, key: str) -> None: - """Removes an output from the results. + """Remove an output from the results. Parameters ---------- @@ -124,7 +124,7 @@ def remove_output(self, key: str) -> None: self.results.remove(key) def print_outputs(self) -> str: - """Prints the current results dictionary. + """Print the current results dictionary. Returns ------- @@ -134,7 +134,7 @@ def print_outputs(self) -> str: return self.results.print() def custom_output(self, filename: str, comment: str = "") -> None: - """Adds an unsupported output to the results dictionary. + """Add an unsupported output to the results dictionary. 
Parameters ---------- @@ -158,7 +158,7 @@ def rename_output(self, old: str, new: str) -> None: self.results.rename(old, new) def add_comments(self, output: str, comment: str) -> None: - """Adds a comment to an output. + """Add a comment to an output. Parameters ---------- @@ -170,7 +170,7 @@ def add_comments(self, output: str, comment: str) -> None: self.results.add_comments(output, comment) def add_exception(self, output: str, reason: str) -> None: - """Adds an exception request to an output. + """Add an exception request to an output. Parameters ---------- @@ -183,7 +183,7 @@ def add_exception(self, output: str, reason: str) -> None: def add_to_acro(src_path: str, dest_path: str = "sdc_results") -> None: - """Adds outputs to an acro object and creates a results file for checking. + """Add outputs to an acro object and creates a results file for checking. Parameters ---------- diff --git a/acro/acro_regression.py b/acro/acro_regression.py index db350f9..3abfec0 100644 --- a/acro/acro_regression.py +++ b/acro/acro_regression.py @@ -402,7 +402,7 @@ def __check_model_dof(self, name: str, model) -> tuple[str, str, float]: def get_summary_dataframes(results: list[SimpleTable]) -> list[DataFrame]: - """Converts a list of SimpleTable objects to a list of DataFrame objects. + """Convert a list of SimpleTable objects to a list of DataFrame objects. Parameters ---------- diff --git a/acro/acro_stata_parser.py b/acro/acro_stata_parser.py index d263413..4712267 100644 --- a/acro/acro_stata_parser.py +++ b/acro/acro_stata_parser.py @@ -1,5 +1,6 @@ """ -File with commands to manage the stata-acro interface +File with commands to manage the stata-acro interface. + Jim Smith 2023 @james.smith@uwe.ac.uk MIT licenses apply. """ @@ -14,10 +15,7 @@ def apply_stata_ifstmt(raw: str, all_data: pd.DataFrame) -> pd.DataFrame: - """ - Parses an if statement from stata format - then uses it to subset a dataframe by contents. 
- """ + """Parse an if statement from stata format then use it to subset a dataframe by contents.""" if len(raw) == 0: return all_data @@ -36,8 +34,9 @@ def apply_stata_ifstmt(raw: str, all_data: pd.DataFrame) -> pd.DataFrame: def parse_location_token(token: str, last: int) -> int: """ - Parses index position tokens from stata syntax - stata allows f and F for first item and l/L for last. + Parse index position tokens from stata syntax. + + Stata allows f and F for first item and l/L for last. """ lookup: dict = {"f": 0, "F": 0, "l": last, "L": last} if token in ["f", "F", "l", "L"]: @@ -54,10 +53,7 @@ def parse_location_token(token: str, last: int) -> int: def apply_stata_expstmt(raw: str, all_data: pd.DataFrame) -> pd.DataFrame: - """ - Parses an in exp statement from stata and uses it - to subset a dataframe by set of row indices. - """ + """Parse an in exp statement from stata and use it to subset a dataframe by row indices.""" last = len(all_data) - 1 if "/" not in raw: pos = parse_location_token(raw, last) @@ -86,11 +82,9 @@ def apply_stata_expstmt(raw: str, all_data: pd.DataFrame) -> pd.DataFrame: def find_brace_word(word: str, raw: str): - """ - Given a word followed by a ( - finds and returns as a list of strings - the rest of the contents up to the closing ). - first returned value is True/False depending on parsing ok. + """Return contents as a list of strings between '(' following a word and the closing ')'. + + First returned value is True/False depending on parsing ok. 
""" result = [] idx = raw.find(word) @@ -113,7 +107,7 @@ def find_brace_word(word: str, raw: str): def extract_aggfun_values_from_options(details, contents_found, content, varnames): - """Extracts the aggfunc and the values from the content.""" + """Extract the aggfunc and the values from the content.""" # contents can be variable names or aggregation functions details["aggfuncs"], details["values"] = list([]), list([]) if contents_found and len(content) > 0: @@ -132,7 +126,8 @@ def extract_aggfun_values_from_options(details, contents_found, content, varname def parse_table_details( varlist: list, varnames: list, options: str, stata_version: str ) -> dict: - """Function to parse stata-16 style table calls + """Parse stata-16 style table calls. + Note this is not for latest version of stata, syntax here: https://www.stata.com/manuals16/rtable.pdf >> table rowvar [colvar [supercolvar] [if] [in] [weight] [, options]. @@ -202,8 +197,9 @@ def parse_and_run( # pylint: disable=too-many-arguments,too-many-locals stata_version: str, ) -> pd.DataFrame: """ + Run the appropriate command on a pre-existing ACRO object stata_acro. + Takes a dataframe and the parsed stata command line. - Runs the appropriate command on a pre-existing ACRO object stata_acro Returns the result as a formatted string. """ # sanity checking @@ -248,7 +244,7 @@ def parse_and_run( # pylint: disable=too-many-arguments,too-many-locals def run_session_command(command: str, varlist: list) -> str: - """Runs session commands that are data-independent.""" + """Run session commands that are data-independent.""" outcome = "" if command == "init": @@ -285,8 +281,9 @@ def run_session_command(command: str, varlist: list) -> str: def run_output_command(command: str, varlist: list) -> str: - """Runs outcome-level commands - first element of varlist is output affected + """Run outcome-level commands. + + First element of varlist is output affected rest (if relevant) is string passed to command. 
""" outcome = "" @@ -324,9 +321,7 @@ def run_output_command(command: str, varlist: list) -> str: def extract_var_within_parentheses(input_string): - """Given a string, this function extracts the words within the first parentheses - from a string. - """ + """Extract the words within the first parentheses from a string.""" string = "" string_match = re.match(r"\((.*?)\)", input_string) if string_match: @@ -336,7 +331,7 @@ def extract_var_within_parentheses(input_string): def extract_var_before_parentheses(input_string): - """Given a string, this function extracts the words before the first parentheses.""" + """Extract the words before the first parentheses.""" string = "" string_match = re.match(r"^(.*?)\(", input_string) if string_match: @@ -346,7 +341,8 @@ def extract_var_before_parentheses(input_string): def extract_table_var(input_string): - """Given a string, this function extracts the words within the parentheses. + """Extract the words within the parentheses. + If there are no parentheses the string is returned. """ string = "" @@ -359,9 +355,9 @@ def extract_table_var(input_string): def extract_colstring_tablestring(input_string): - """Given a string, this function extracts the column and the tables - variables as a string. It goes through different options eg. whether - the column string is between paranthese or not. + """Extract the column and the tables variables as a string. + + It goes through different options eg. whether the column string is between paranthese or not. """ colstring = "" tablestring = "" @@ -382,9 +378,9 @@ def extract_colstring_tablestring(input_string): def extract_strings(input_string): - """Given a string, this function extracts the index, column and the tables - variables as a string. It goes through different options eg. whether - the index string is between paranthese or not. + """Extract the index, column and the tables variables as a string. + + It goes through different options eg. 
whether the index string is between parentheses or not. """ rowstring = "" colstring = "" @@ -412,11 +408,11 @@ def extract_strings(input_string): def creates_datasets(data, details): - """This function returns the full dataset if the tables parameter is empty. + """Return the full dataset if the tables parameter is empty. + Otherwise, it divides the dataset to small dataset each one is the dataset when the tables parameter is equal to one of it is unique values. """ - set_of_data = {"Total": data} msg = "" # if tables var parameter was assigned, each table will @@ -449,10 +445,7 @@ def run_table_command( # pylint: disable=too-many-arguments,too-many-locals options: str, stata_version: str, ) -> str: - """ - Converts a stata table command into an acro.crosstab - then returns a prettified versaion of the cross_tab dataframe. - """ + """Convert a stata table command into an acro.crosstab and return a prettified dataframe.""" weights_empty = len(weights) == 0 if not weights_empty: # pragma return f"weights not currently implemented for _{weights}_\n" @@ -534,7 +527,7 @@ def run_table_command( # pylint: disable=too-many-arguments,too-many-locals def run_regression(command: str, data: pd.DataFrame, varlist: list) -> str: - """Interprets and runs appropriate regression command.""" + """Interpret and run appropriate regression command.""" # get components of formula depvar = varlist[0] indep_vars = varlist[1:] @@ -562,7 +555,7 @@ def run_regression(command: str, data: pd.DataFrame, varlist: list) -> str: def get_regr_results(results: sm_iolib_summary.Summary, title: str) -> str: - """Translates statsmodels.io.summary object into prettified table.""" + """Translate statsmodels.io.summary object into prettified table.""" res_str = title + "\n" for table in acro_regression.get_summary_dataframes(results.summary().tables): res_str += prettify_table_string(table, separator=",") + "\n" diff --git a/acro/acro_tables.py b/acro/acro_tables.py index 9c67072..37204e4 100644 --- 
a/acro/acro_tables.py +++ b/acro/acro_tables.py @@ -32,7 +32,7 @@ def mode_aggfunc(values) -> Series: Returns ------- Series - The mode. If there are multiple modes, randomly selects and returns one of the modes. + The mode. If multiple modes, randomly selects and returns one of the modes. """ modes = values.mode() return secrets.choice(modes) @@ -86,8 +86,9 @@ def crosstab( # pylint: disable=too-many-arguments,too-many-locals normalize=False, show_suppressed=False, ) -> DataFrame: - """Compute a simple cross tabulation of two (or more) factors. By - default, computes a frequency table of the factors unless an array of + """Compute a simple cross tabulation of two (or more) factors. + + By default, computes a frequency table of the factors unless an array of values and an aggregation function are passed. To provide consistent behaviour with different aggregation functions, @@ -411,7 +412,7 @@ def surv_func( # pylint: disable=too-many-arguments,too-many-locals bw_factor=1.0, filename="kaplan-meier.png", ) -> DataFrame: - """Estimates the survival function. + """Estimate the survival function. 
Parameters ---------- @@ -490,7 +491,7 @@ def surv_func( # pylint: disable=too-many-arguments,too-many-locals def survival_table( # pylint: disable=too-many-arguments,too-many-locals self, survival_table, safe_table, status, sdc, command, summary, outcome ): - """Creates the survival table according to the status of suppressing.""" + """Create the survival table according to the status of suppressing.""" if self.suppress: survival_table = safe_table self.results.add( @@ -508,7 +509,7 @@ def survival_table( # pylint: disable=too-many-arguments,too-many-locals def survival_plot( # pylint: disable=too-many-arguments,too-many-locals self, survival_table, survival_func, filename, status, sdc, command, summary ): - """Creates the survival plot according to the status of suppressing.""" + """Create the survival plot according to the status of suppressing.""" if self.suppress: survival_table = rounded_survival_table(survival_table) plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0) @@ -570,7 +571,8 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals filename="histogram.png", **kwargs, ): - """Creates a histogram from a single column. + """Create a histogram from a single column. + The dataset and the column's name should be passed to the function as parameters. If more than one column is used the histogram will not be calculated. @@ -755,7 +757,7 @@ def create_crosstab_masks( # pylint: disable=too-many-arguments,too-many-locals dropna, normalize, ): - """Creates masks to specify the cells to suppress.""" + """Create masks to specify the cells to suppress.""" # suppression masks to apply based on the following checks masks: dict[str, DataFrame] = {} @@ -863,7 +865,7 @@ def create_crosstab_masks( # pylint: disable=too-many-arguments,too-many-locals def delete_empty_rows_columns(table: DataFrame) -> tuple[DataFrame, list[str]]: - """Deletes empty rows and columns from table. + """Delete empty rows and columns from table. 
Parameters ---------- @@ -902,7 +904,7 @@ def delete_empty_rows_columns(table: DataFrame) -> tuple[DataFrame, list[str]]: def rounded_survival_table(survival_table): - """Calculates the rounded surival function.""" + """Calculate the rounded survival function.""" death_censored = ( survival_table["num at risk"].shift(periods=1) - survival_table["num at risk"] ) @@ -946,8 +948,7 @@ def rounded_survival_table(survival_table): def get_aggfunc(aggfunc: str | None) -> str | Callable | None: - """Checks whether an aggregation function is allowed and returns the - appropriate function. + """Check whether an aggregation function is allowed and return the appropriate function. Parameters ---------- @@ -978,8 +979,7 @@ def get_aggfuncs( aggfuncs: str | list[str] | None, ) -> str | Callable | list[str | Callable] | None: - """Checks whether a list of aggregation functions is allowed and returns - the appropriate functions. + """Check whether aggregation functions are allowed and return appropriate functions. Parameters ---------- @@ -1013,7 +1013,7 @@ def get_aggfuncs( def agg_negative(vals: Series) -> bool: - """Aggregation function that returns whether any values are negative. + """Return whether any values are negative. Parameters ---------- @@ -1029,7 +1029,7 @@ def agg_negative(vals: Series) -> bool: def agg_missing(vals: Series) -> bool: - """Aggregation function that returns whether any values are missing. + """Return whether any values are missing. Parameters ---------- @@ -1045,7 +1045,7 @@ def agg_missing(vals: Series) -> bool: def agg_p_percent(vals: Series) -> bool: - """Aggregation function that returns whether the p percent rule is violated. + """Return whether the p percent rule is violated. That is, the uncertainty (as a fraction) of the estimate that the second highest respondent can make of the highest value. 
Assuming there are n @@ -1075,8 +1075,7 @@ def agg_p_percent(vals: Series) -> bool: def agg_nk(vals: Series) -> bool: - """Aggregation function that returns whether the top n items account for - more than k percent of the total. + """Return whether the top n items account for more than k percent of the total. Parameters ---------- @@ -1097,8 +1096,7 @@ def agg_nk(vals: Series) -> bool: def agg_threshold(vals: Series) -> bool: - """Aggregation function that returns whether the number of contributors is - below a threshold. + """Return whether the number of contributors is below a threshold. Parameters ---------- @@ -1114,8 +1112,7 @@ def agg_threshold(vals: Series) -> bool: def agg_values_are_same(vals: Series) -> bool: - """Aggregation function that returns whether all observations having - the same value. + """Return whether all observations have the same value. Parameters ---------- @@ -1134,7 +1131,7 @@ def agg_values_are_same(vals: Series) -> bool: def apply_suppression( table: DataFrame, masks: dict[str, DataFrame] ) -> tuple[DataFrame, DataFrame]: - """Applies suppression to a table. + """Apply suppression to a table. Parameters ---------- @@ -1186,7 +1183,7 @@ def apply_suppression( def get_table_sdc(masks: dict[str, DataFrame], suppress: bool) -> dict: - """Returns the SDC dictionary using the suppression masks. + """Return the SDC dictionary using the suppression masks. Parameters ---------- @@ -1221,7 +1218,7 @@ def get_table_sdc(masks: dict[str, DataFrame], suppress: bool) -> dict: def get_summary(sdc: dict) -> tuple[str, str]: - """Returns the status and summary of the suppression masks. + """Return the status and summary of the suppression masks. Parameters ---------- @@ -1270,7 +1267,7 @@ def get_summary(sdc: dict) -> tuple[str, str]: def get_queries(masks, aggfunc) -> list[str]: - """Returns a list of the boolean conditions for each true cell in the suppression masks. 
+ """Return a list of the boolean conditions for each true cell in the suppression masks. Parameters ---------- @@ -1348,7 +1345,7 @@ def get_queries(masks, aggfunc) -> list[str]: def create_dataframe(index, columns) -> DataFrame: - """Combining the index and columns in a dataframe and return the datframe. + """Combine the index and columns in a dataframe and return the dataframe. Parameters ---------- diff --git a/acro/record.py b/acro/record.py index c39a54b..b6732e2 100644 --- a/acro/record.py +++ b/acro/record.py @@ -20,7 +20,7 @@ def load_outcome(outcome: dict) -> DataFrame: - """Returns a DataFrame from an outcome dictionary. + """Return a DataFrame from an outcome dictionary. Parameters ---------- @@ -31,7 +31,7 @@ def load_outcome(outcome: dict) -> DataFrame: def load_output(path: str, output: list[str]) -> list[str] | list[DataFrame]: - """Returns a loaded output. + """Return a loaded output. Parameters ---------- @@ -102,7 +102,7 @@ def __init__( # pylint: disable=too-many-arguments output: list[str] | list[DataFrame], comments: list[str] | None = None, ) -> None: - """Constructs a new output record. + """Construct a new output record. Parameters ---------- @@ -142,7 +142,7 @@ def __init__( # pylint: disable=too-many-arguments self.timestamp: str = now.isoformat() def serialize_output(self, path: str = "outputs") -> list[str]: - """Serializes outputs. + """Serialize outputs. Parameters ---------- @@ -183,7 +183,7 @@ def serialize_output(self, path: str = "outputs") -> list[str]: return output def __str__(self) -> str: - """Returns a string representation of a record. + """Return a string representation of a record. 
Returns ------- @@ -210,7 +210,7 @@ class Records: """Stores data related to a collection of output records.""" def __init__(self) -> None: - """Constructs a new object for storing multiple records.""" + """Construct a new object for storing multiple records.""" self.results: dict[str, Record] = {} self.output_id: int = 0 @@ -226,7 +226,7 @@ def add( # pylint: disable=too-many-arguments output: list[str] | list[DataFrame], comments: list[str] | None = None, ) -> None: - """Adds an output to the results. + """Add an output to the results. Parameters ---------- @@ -266,7 +266,7 @@ def add( # pylint: disable=too-many-arguments logger.info("add(): %s", new.uid) def remove(self, key: str) -> None: - """Removes an output from the results. + """Remove an output from the results. Parameters ---------- @@ -279,7 +279,7 @@ def remove(self, key: str) -> None: logger.info("remove(): %s removed", key) def get(self, key: str) -> Record: - """Returns a specified output from the results. + """Return a specified output from the results. Parameters ---------- @@ -295,7 +295,7 @@ def get(self, key: str) -> Record: return self.results[key] def get_keys(self) -> list[str]: - """Returns the list of available output keys. + """Return the list of available output keys. Returns ------- @@ -306,7 +306,7 @@ def get_keys(self) -> list[str]: return list(self.results.keys()) def get_index(self, index: int) -> Record: - """Returns the output at the specified position. + """Return the output at the specified position. Parameters ---------- @@ -323,7 +323,7 @@ def get_index(self, index: int) -> Record: return self.results[key] def add_custom(self, filename: str, comment: str | None = None) -> None: - """Adds an unsupported output to the results dictionary. + """Add an unsupported output to the results dictionary. 
Parameters ---------- @@ -373,7 +373,7 @@ def rename(self, old: str, new: str) -> None: logger.info("rename_output(): %s renamed to %s", old, new) def add_comments(self, output: str, comment: str) -> None: - """Adds a comment to an output. + """Add a comment to an output. Parameters ---------- @@ -388,7 +388,7 @@ def add_comments(self, output: str, comment: str) -> None: logger.info("a comment was added to %s", output) def add_exception(self, output: str, reason: str) -> None: - """Adds an exception request to an output. + """Add an exception request to an output. Parameters ---------- @@ -403,7 +403,7 @@ def add_exception(self, output: str, reason: str) -> None: logger.info("exception request was added to %s", output) def print(self) -> str: - """Prints the current results. + """Print the current results. Returns ------- @@ -418,7 +418,7 @@ def print(self) -> str: return outputs def validate_outputs(self) -> None: - """Prompts researcher to complete any required fields.""" + """Prompt researcher to complete any required fields.""" for _, record in self.results.items(): if record.status != "pass" and record.exception == "": logger.info( @@ -431,7 +431,7 @@ def validate_outputs(self) -> None: record.exception = input("") def finalise(self, path: str, ext: str) -> None: - """Creates a results file for checking. + """Create a results file for checking. Parameters ---------- @@ -455,7 +455,7 @@ def finalise(self, path: str, ext: str) -> None: logger.info("outputs written to: %s", path) def finalise_json(self, path: str) -> None: - """Writes outputs to a JSON file. + """Write outputs to a JSON file. Parameters ---------- @@ -494,7 +494,7 @@ def finalise_json(self, path: str) -> None: ) def finalise_excel(self, path: str) -> None: - """Writes outputs to an excel spreadsheet. + """Write outputs to an excel spreadsheet. 
Parameters ---------- @@ -543,7 +543,7 @@ def finalise_excel(self, path: str) -> None: table.to_excel(writer, sheet_name=output_id, startrow=start) def write_checksums(self, path: str) -> None: - """Writes checksums for each file to checksums folder. + """Write checksums for each file to checksums folder. Parameters ---------- @@ -569,7 +569,7 @@ def write_checksums(self, path: str) -> None: def load_records(path: str) -> Records: - """Loads outputs from a JSON file. + """Load outputs from a JSON file. Parameters ---------- diff --git a/acro/stata_config.py b/acro/stata_config.py index 468307d..479b2b1 100644 --- a/acro/stata_config.py +++ b/acro/stata_config.py @@ -1,7 +1,9 @@ """ -Config file to hold global variable for acro object -accessible from acro files and stata -mutable hence use of lower case naming +Stata config file. + +Holds global variable for acro object accessible from acro files and stata +mutable hence use of lower case naming. + Jim Smith 2023. """ diff --git a/acro/utils.py b/acro/utils.py index 47b916c..d576eaa 100644 --- a/acro/utils.py +++ b/acro/utils.py @@ -11,7 +11,7 @@ def get_command(default: str, stack_list: list[FrameInfo]) -> str: - """Returns the calling source line as a string. + """Return the calling source line as a string. Parameters ---------- @@ -38,10 +38,9 @@ def get_command(default: str, stack_list: list[FrameInfo]) -> str: def prettify_table_string(table: pd.DataFrame, separator: str | None = None) -> str: """ - Adds delimiters to table.to_string() - to improve readability for onscreen display. - Splits fields on whitespace unless an optional separator is provided - e.g. ',' for csv. + Add delimiters to table.to_string() to improve readability for onscreen display. + + Splits fields on whitespace unless an optional separator is provided e.g. ',' for csv. 
""" hdelim = "-" vdelim = "|" diff --git a/docs/source/conf.py b/docs/source/conf.py index 2deb6f9..77bb166 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,5 +1,5 @@ -# Configuration file for the Sphinx documentation builder. -# +"""Configuration file for the Sphinx documentation builder.""" + # -- Path setup -------------------------------------------------------------- import os diff --git a/notebooks/acro_demo.py b/notebooks/acro_demo.py index 72e28d5..0c8630f 100644 --- a/notebooks/acro_demo.py +++ b/notebooks/acro_demo.py @@ -1,5 +1,6 @@ """ -ACRO Tests +ACRO Tests. + Copyright : Maha Albashir, Richard Preen, Jim Smith 2023. """ diff --git a/notebooks/test-nursery.py b/notebooks/test-nursery.py index bc29402..2bb18fd 100644 --- a/notebooks/test-nursery.py +++ b/notebooks/test-nursery.py @@ -1,5 +1,6 @@ """ -ACRO Tests +ACRO Tests. + Copyright : Maha Albashir, Richard Preen, Jim Smith 2023. """ diff --git a/pyproject.toml b/pyproject.toml index 5443420..67e659c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ lint.select = [ "B", # flake8-bugbear # "C4", # flake8-comprehensions "C90", # mccabe -# "D", # pydocstyle + "D", # pydocstyle # "DTZ", # flake8-datetimez # "E", # pycodestyle # "ERA", # eradicate @@ -42,8 +42,6 @@ exclude = [ lint.ignore = [ # "ANN101", # missing-type-self -# "D203", # blank line required before class docstring -# "D213", # multi-line-summary-second-line # "S301", # unsafe pickle ] diff --git a/test/test_initial.py b/test/test_initial.py index 058d378..28df518 100644 --- a/test/test_initial.py +++ b/test/test_initial.py @@ -1,4 +1,4 @@ -"""This module contains unit tests.""" +"""Unit tests.""" import json import os @@ -12,7 +12,7 @@ from acro import ACRO, acro_tables, add_constant, add_to_acro, record, utils from acro.record import Records, load_records -# pylint: disable=redefined-outer-name +# pylint: disable=redefined-outer-name,too-many-lines PATH: str = "RES_PYTEST" @@ -485,9 +485,7 @@ def 
test_adding_exception(acro): def test_add_to_acro(data, monkeypatch): - """Adding an output that was generated without using acro to an acro object and - creates a results file for checking. - """ + """Add an output generated without acro to an acro object and create results file.""" # create a cross tabulation using pandas table = pd.crosstab(data.year, data.grant_type) # save the output to a file and add this file to a directory @@ -599,7 +597,6 @@ def test_hierachical_aggregation(data, acro): def test_single_values_column(data, acro): """Pandas does not allows multiple arrays for values.""" - with pytest.raises(ValueError): _ = acro.crosstab( data.year, @@ -634,9 +631,7 @@ def test_surv_func(acro): def test_zeros_are_not_disclosive(data, acro): - """Test that zeros are handled as not disclosive when - the parameter (zeros_are_disclosive) is False. - """ + """Test that zeros are handled as not disclosive when `zeros_are_disclosive=False`.""" acro_tables.ZEROS_ARE_DISCLOSIVE = False _ = acro.pivot_table( data, @@ -714,9 +709,7 @@ def test_crosstab_with_totals_with_suppression_with_mean(data, acro): def test_crosstab_with_totals_and_empty_data(data, acro, caplog): - """Test the crosstab with both margins and suppression are true - and with a dataset that all its data violate one or more rules. - """ + """Test crosstab when both margins and suppression are true with a disclosive dataset.""" data = data[ (data.year == 2010) & (data.grant_type == "G") @@ -736,9 +729,7 @@ def test_crosstab_with_totals_and_empty_data(data, acro, caplog): def test_crosstab_with_manual_totals_with_suppression(data, acro): - """Test the crosstab with both margins and - suppression are true while using the total manual function. 
- """ + """Test crosstab when margins and suppression are true with the total manual function.""" _ = acro.crosstab(data.year, data.grant_type, margins=True, show_suppressed=True) output = acro.results.get_index(0) assert 145 == output.output[0]["All"].iat[0] @@ -751,8 +742,9 @@ def test_crosstab_with_manual_totals_with_suppression(data, acro): def test_crosstab_with_manual_totals_with_suppression_hierarchical(data, acro): - """Test the crosstab with both margins and suppression - are true with multilevel indexes and columns while using the total manual function. + """Test crosstab when margins and suppression are true with hierarchical data. + + Tests with multilevel indexes and columns while using the total manual function. """ _ = acro.crosstab( [data.year, data.survivor], @@ -775,8 +767,10 @@ def test_crosstab_with_manual_totals_with_suppression_hierarchical(data, acro): def test_crosstab_with_manual_totals_with_suppression_with_aggfunc_mean(data, acro): - """Test the crosstab with both margins and suppression are true - and with aggfunc mean while using the total manual function. + """Test crosstab. + + Tests the crosstab with both margins and suppression are true and with + aggfunc mean while using the total manual function. """ _ = acro.crosstab( data.year, @@ -794,8 +788,11 @@ def test_crosstab_with_manual_totals_with_suppression_with_aggfunc_mean(data, ac def test_hierarchical_crosstab_with_manual_totals_with_mean(data, acro): - """Test the crosstab with both margins and suppression are true, with - aggfunc mean and with multilevel columns and rows while using the total manual function. + """Test crosstab. + + Test the crosstab with both margins and suppression are true, with aggfunc + mean and with multilevel columns and rows while using the total manual + function. 
""" _ = acro.crosstab( [data.year, data.survivor], @@ -815,7 +812,9 @@ def test_hierarchical_crosstab_with_manual_totals_with_mean(data, acro): def test_crosstab_with_manual_totals_with_suppression_with_aggfunc_std( data, acro, caplog ): - """Test the crosstab with both margins and suppression are true and with + """Test crosstab. + + Test the crosstab with both margins and suppression are true and with aggfunc std while using the total manual function. """ _ = acro.crosstab( @@ -882,8 +881,10 @@ def test_crosstab_multiple_aggregate_function(data, acro): def test_crosstab_with_totals_with_suppression_with_two_aggfuncs(data, acro): - """Test the crosstab with both margins and suppression are true - and with a list of aggfuncs while using the total manual function. + """Test crosstab. + + Test the crosstab with both margins and suppression are true and with a + list of aggfuncs while using the total manual function. """ _ = acro.crosstab( data.year, @@ -918,9 +919,11 @@ def test_crosstab_with_totals_with_suppression_with_two_aggfuncs(data, acro): def test_crosstab_with_totals_with_suppression_with_two_aggfuncs_hierarchical( data, acro ): - """Test the crosstab with both margins and suppression are true - and with a list of aggfuncs and a list of columns while using - the total manual function. + """Test crosstab. + + Test the crosstab with both margins and suppression are true and with a + list of aggfuncs and a list of columns while using the total manual + function. """ _ = acro.crosstab( data.year, @@ -937,8 +940,10 @@ def test_crosstab_with_totals_with_suppression_with_two_aggfuncs_hierarchical( def test_crosstab_with_manual_totals_with_suppression_with_two_aggfunc( data, acro, caplog ): - """Test the crosstab with both margins and suppression are true - and with a list of aggfuncs while using the total manual function. + """Test crosstab. 
+ + Test the crosstab with both margins and suppression are true and with a + list of aggfuncs while using the total manual function. """ _ = acro.crosstab( data.year, @@ -985,7 +990,7 @@ def test_histogram_non_disclosive(data, acro): def test_finalise_with_existing_path(data, acro, caplog): - """Using a path that already exists when finalising.""" + """Test using a path that already exists when finalising.""" _ = acro.crosstab(data.year, data.grant_type) acro.add_exception("output_0", "Let me have it") acro.finalise(PATH) diff --git a/test/test_stata17_interface.py b/test/test_stata17_interface.py index 91b910c..729b196 100644 --- a/test/test_stata17_interface.py +++ b/test/test_stata17_interface.py @@ -1,4 +1,4 @@ -"""This module contains unit tests for the stata 17 interface.""" +"""Unit tests for the stata 17 interface.""" # The pylint skip file is to skip the error of R0801: Similar lines in 2 files. As the # file this file and the file test_stata_interface.py have a lot of similarities. @@ -34,7 +34,7 @@ def data() -> pd.DataFrame: def clean_up(name): - """Removes unwanted files or directory.""" + """Remove unwanted files or directory.""" if os.path.exists(name): if os.path.isfile(name): os.remove(name) @@ -46,7 +46,8 @@ def dummy_acrohandler( data, command, varlist, exclusion, exp, weights, options, stata_version ): # pylint:disable=too-many-arguments """ - Provides an alternative interface that mimics the code in acro.ado + Provide an alternative interface that mimics the code in acro.ado. + Most notably the presence of a global variable called stata_acro. """ acro_outstr = parse_and_run( @@ -58,11 +59,7 @@ def dummy_acrohandler( # --- Helper functions----------------------------------------------------- def test_find_brace_word(): - """Tests helper function - that extracts contents 'A B C' - of something specified via X(A B C) - on the stata command line. 
- """ + """Test helper function that extracts contents 'A B C' of something specified via X(A B C) on the stata command line.""" options = "statistic(mean inc_activity) suppress nototals" res, substr = find_brace_word("statistic", options) assert res @@ -78,11 +75,7 @@ def test_find_brace_word(): def test_parse_table_details(data): - """ - Series of checks that the varlist and options are parsed correctly - by the helper function. - """ - + """Check that the varlist and options are parsed correctly by the helper function.""" varlist = ["survivor", "grant_type", "year"] varnames = data.columns options = "statistic(mean inc_activity) suppress nototals" @@ -121,7 +114,8 @@ def test_parse_table_details(data): def test_stata_acro_init(): """ - Tests creation of an acro object at the start of a session + Test creation of an acro object at the start of a session. + For stata this gets held in a variable stata_acro Which is initialsied to the string "empty" in the acro.ado file Then should be pointed at a new acro instance. @@ -145,7 +139,7 @@ def test_stata_acro_init(): def test_stata_print_outputs(data): - """Checks print_outputs gets called.""" + """Check print_outputs gets called.""" ret = dummy_acrohandler( data, command="print_outputs", @@ -162,7 +156,8 @@ def test_stata_print_outputs(data): # ----main SDC functionality------------------------------------- def test_simple_table(data): """ - Checks that the simple table command works as expected + Check that the simple table command works as expected. + Does via reference to direct call to pd.crosstab() To make sure table specification is parsed correctly acro SDC analysis is tested elsewhere. @@ -193,8 +188,9 @@ def test_simple_table(data): def test_stata_rename_outputs(): - """Tests renaming outputs - assumes simple table has been created by earlier tests. + """Test renaming outputs. + + Assumes simple table has been created by earlier tests. 
""" the_str = "renamed_output" the_output = "output_0" @@ -215,8 +211,9 @@ def test_stata_rename_outputs(): def test_stata_incomplete_output_commands(): - """Tests handling incomplete or wrong output commands - assumes simple table has been created by earlier tests. + """Test handling incomplete or wrong output commands. + + Assumes simple table has been created by earlier tests. """ # output to change not provided the_str = "renamed_output" @@ -272,9 +269,9 @@ def test_stata_incomplete_output_commands(): def test_stata_add_comments(): """ - Tests adding comments to outputs - assumes simple table has been created by earlier tests - then renamed. + Test adding comments to outputs. + + Assumes simple table has been created by earlier tests then renamed. """ the_str = "some comments" the_output = "renamed_output" @@ -296,9 +293,9 @@ def test_stata_add_comments(): def test_stata_add_exception(): """ - Tests adding exception to outputs - assumes simple table has been created by earlier tests - then renamed. + Test adding exception to outputs. + + Assumes simple table has been created by earlier tests then renamed. """ the_str = "a reason" the_output = "renamed_output" @@ -320,8 +317,9 @@ def test_stata_add_exception(): def test_stata_remove_output(): """ - Tests removing outputs - assumes simple table has been created and renamed by earlier tests. + Tests removing outputs. + + Assumes simple table has been created and renamed by earlier tests. 
""" the_output = "renamed_output" ret = dummy_acrohandler( @@ -344,7 +342,7 @@ def test_stata_remove_output(): def test_stata_exclusion_in_context(data): - """Tests that the subsetting code gets called properly from table handler.""" + """Test that the subsetting code gets called properly from table handler.""" # if condition correct1 = ( "Total\n" @@ -422,7 +420,7 @@ def test_stata_exclusion_in_context(data): def test_table_weights(data): - """Weights are not currently supported.""" + """Test that weights are not currently supported.""" weights = [0, 0, 0] correct = f"weights not currently implemented for _{weights}_\n" ret = dummy_acrohandler( @@ -439,7 +437,7 @@ def test_table_weights(data): def test_table_aggcfn(data): - """Testing behaviour with aggregation function.""" + """Test behaviour with aggregation function.""" # ok correct = ( "Total\n" @@ -517,7 +515,7 @@ def test_table_aggcfn(data): def test_table_aggcfns(data): - """Testing behaviour with two aggregation functions.""" + """Test behaviour with two aggregation functions.""" correct = ( "Total\n" "------------------------------------------|\n" @@ -544,7 +542,7 @@ def test_table_aggcfns(data): def test_stata_probit(data): - """Checks probit gets called correctly.""" + """Check probit gets called correctly.""" ret = dummy_acrohandler( data, command="probit", @@ -570,7 +568,7 @@ def test_stata_probit(data): def test_stata_linregress(data): - """Checks linear regression called correctly.""" + """Check linear regression called correctly.""" ret = dummy_acrohandler( data, command="regress", @@ -591,7 +589,7 @@ def test_stata_linregress(data): def test_stata_logit(data): - """Tests stata logit function.""" + """Test stata logit function.""" ret = dummy_acrohandler( data, command="logit", @@ -618,7 +616,7 @@ def test_stata_logit(data): def test_unsupported_formatting_options(data): - """Checks that user gets warning if they try to format table.""" + """Check that user gets warning if they try to format 
table.""" format_string = "acro does not currently support table formatting commands." correct = ( "Total\n" @@ -661,7 +659,7 @@ def test_unsupported_formatting_options(data): def test_stata_finalise(monkeypatch): - """Checks finalise gets called correctly.""" + """Check finalise gets called correctly.""" monkeypatch.setattr("builtins.input", lambda _: "Let me have it") ret = dummy_acrohandler( data, @@ -678,7 +676,7 @@ def test_stata_finalise(monkeypatch): def test_stata_finalise_default_filetype(monkeypatch): - """Checks finalise gets called correctly.""" + """Check finalise gets called correctly.""" monkeypatch.setattr("builtins.input", lambda _: "Let me have it") ret = dummy_acrohandler( data, @@ -695,7 +693,7 @@ def test_stata_finalise_default_filetype(monkeypatch): def test_stata_unknown(data): - """Unknown acro command.""" + """Test unknown acro command.""" ret = dummy_acrohandler( data, command="foo", @@ -712,7 +710,7 @@ def test_stata_unknown(data): # ----Test stata 17 new table command syntax------------------------------------- def test_table_stata17(data): - """Checks that the simple table command works as expected.""" + """Check that the simple table command works as expected.""" correct = ( "Total\n" "------------------------------------|\n" @@ -737,7 +735,7 @@ def test_table_stata17(data): def test_table_stata17_1(data): - """Checks that the table command works as expected, with more than one index.""" + """Check that the table command works as expected, with more than one index.""" correct = ( "Total\n" "---------------------------------------|\n" @@ -772,7 +770,7 @@ def test_table_stata17_1(data): def test_table_stata17_2(data): - """Checks that the table command works as expected, with more than one column.""" + """Check that the table command works as expected, with more than one column.""" correct = ( "Total\n" "--------------------------------------------|\n" @@ -816,7 +814,7 @@ def test_table_stata17_2(data): def test_table_stata17_3(data): - 
"""Checks that the table command works as expected, with herichical tables.""" + """Check that the table command works as expected, with herichical tables.""" correct = ( "Total\n" "----------------------------------------------------|\n" @@ -853,7 +851,7 @@ def test_table_stata17_3(data): def test_table_stata17_4(data): - """Checks that the table command works as expected, with the table variable.""" + """Check that the table command works as expected, with the table variable.""" correct = ( "You need to manually check all the outputs for the risk of differencing.\n" "Total\n" @@ -958,7 +956,7 @@ def test_table_stata17_4(data): def test_one_dimensional_table(data): - """Checks that one dimensional table is not supported at the moment.""" + """Check that one dimensional table is not supported at the moment.""" correct = ( "acro does not currently support one dimensioanl tables. " "To calculate cross tabulation, you need to provide at " @@ -978,7 +976,7 @@ def test_one_dimensional_table(data): def test_cleanup(): - """Gets rid of files created during tests.""" + """Remove files created during tests.""" names = ["test_outputs", "test_add_to_acro", "sdc_results", "RES_PYTEST"] for name in names: clean_up(name) diff --git a/test/test_stata_interface.py b/test/test_stata_interface.py index 99f66d2..2cb80d4 100644 --- a/test/test_stata_interface.py +++ b/test/test_stata_interface.py @@ -1,4 +1,4 @@ -"""This module contains unit tests for the stata interface.""" +"""Unit tests for the stata interface.""" import os import shutil @@ -35,7 +35,7 @@ def data() -> pd.DataFrame: def clean_up(name): - """Removes unwanted files or directory.""" + """Remove unwanted files or directory.""" if os.path.exists(name): if os.path.isfile(name): os.remove(name) @@ -47,7 +47,8 @@ def dummy_acrohandler( data, command, varlist, exclusion, exp, weights, options, stata_version ): # pylint:disable=too-many-arguments """ - Provides an alternative interface that mimics the code in acro.ado + 
Provide an alternative interface that mimics the code in acro.ado. + Most notably the presence of a global variable called stata_acro. """ acro_outstr = parse_and_run( @@ -59,11 +60,7 @@ def dummy_acrohandler( # --- Helper functions----------------------------------------------------- def test_find_brace_word(): - """Tests helper function - that extracts contents 'A B C' - of something specified via X(A B C) - on the stata command line. - """ + """Extract contents 'A B C' specified as X(A B C) on the stata command line.""" options = "by(grant_type) contents(mean sd inc_activity) suppress nototals" res, substr = find_brace_word("by", options) assert res @@ -82,7 +79,7 @@ def test_find_brace_word(): def test_apply_stata_ifstmt(data): - """Tests that if statements work for selection.""" + """Test that if statements work for selection.""" # empty ifstring ifstring = "" smaller = apply_stata_ifstmt(ifstring, data) @@ -101,7 +98,7 @@ def test_apply_stata_ifstmt(data): def test_apply_stata_expstmt(): - """Tests that in statements work for row selection.""" + """Test that in statements work for row selection.""" data = np.zeros(100) for i in range(100): data[i] = i @@ -160,11 +157,7 @@ def test_apply_stata_expstmt(): def test_parse_table_details(data): - """ - Series of checks that the varlist and options are parsed correctly - by the helper function. - """ - + """Check that the varlist and options are parsed correctly by the helper function.""" varlist = ["survivor", "grant_type", "year"] varnames = data.columns options = "by(grant_type) contents(mean sd inc_activity) suppress nototals" @@ -207,7 +200,8 @@ def test_parse_table_details(data): def test_stata_acro_init(): """ - Tests creation of an acro object at the start of a session + Test creation of an acro object at the start of a session. + For stata this gets held in a variable stata_acro Which is initialsied to the string "empty" in the acro.ado file Then should be pointed at a new acro instance. 
@@ -231,7 +225,7 @@ def test_stata_acro_init(): def test_stata_print_outputs(data): - """Checks print_outputs gets called.""" + """Check print_outputs gets called.""" ret = dummy_acrohandler( data, command="print_outputs", @@ -248,7 +242,8 @@ def test_stata_print_outputs(data): # ----main SDC functionality------------------------------------- def test_simple_table(data): """ - Checks that the simple table command works as expected + Check that the simple table command works as expected. + Does via reference to direct call to pd.crosstab() To make sure table specification is parsed correctly acro SDC analysis is tested elsewhere. @@ -279,8 +274,9 @@ def test_simple_table(data): def test_stata_rename_outputs(): - """Tests renaming outputs - assumes simple table has been created by earlier tests. + """Test renaming outputs. + + Assumes simple table has been created by earlier tests. """ the_str = "renamed_output" the_output = "output_0" @@ -301,8 +297,9 @@ def test_stata_rename_outputs(): def test_stata_incomplete_output_commands(): - """Tests handling incomplete or wrong output commands - assumes simple table has been created by earlier tests. + """Test handling incomplete or wrong output commands. + + Assumes simple table has been created by earlier tests. """ # output to change not provided the_str = "renamed_output" @@ -358,9 +355,9 @@ def test_stata_incomplete_output_commands(): def test_stata_add_comments(): """ - Tests adding comments to outputs - assumes simple table has been created by earlier tests - then renamed. + Test adding comments to outputs. + + Assumes simple table has been created by earlier tests then renamed. """ the_str = "some comments" the_output = "renamed_output" @@ -382,9 +379,9 @@ def test_stata_add_comments(): def test_stata_add_exception(): """ - Tests adding exception to outputs - assumes simple table has been created by earlier tests - then renamed. + Test adding exception to outputs. 
+ + Assumes simple table has been created by earlier tests then renamed. """ the_str = "a reason" the_output = "renamed_output" @@ -406,8 +403,9 @@ def test_stata_add_exception(): def test_stata_remove_output(): """ - Tests removing outputs - assumes simple table has been created and renamed by earlier tests. + Test removing outputs. + + Assumes simple table has been created and renamed by earlier tests. """ the_output = "renamed_output" ret = dummy_acrohandler( @@ -430,7 +428,7 @@ def test_stata_remove_output(): def test_stata_exclusion_in_context(data): - """Tests that the subsetting code gets called properly from table handler.""" + """Test that the subsetting code gets called properly from table handler.""" # if condition correct1 = ( "Total\n" @@ -508,7 +506,7 @@ def test_stata_exclusion_in_context(data): def test_table_weights(data): - """Weights are not currently supported.""" + """Test weights are not currently supported.""" weights = [0, 0, 0] correct = f"weights not currently implemented for _{weights}_\n" ret = dummy_acrohandler( @@ -525,7 +523,7 @@ def test_table_weights(data): def test_table_aggcfn(data): - """Testing behaviour with aggregation function.""" + """Test behaviour with aggregation function.""" # ok correct = ( "Total\n" @@ -603,8 +601,7 @@ def test_table_aggcfn(data): def test_table_invalidvar(data): - """Checking table details are valid.""" - + """Check table details are valid.""" correct = "Error: word foo in by-list is not a variables name" ret = dummy_acrohandler( data, @@ -620,7 +617,7 @@ def test_table_invalidvar(data): def test_stata_probit(data): - """Checks probit gets called correctly.""" + """Check probit gets called correctly.""" ret = dummy_acrohandler( data, command="probit", @@ -646,7 +643,7 @@ def test_stata_probit(data): def test_stata_linregress(data): - """Checks linear regression called correctly.""" + """Check linear regression called correctly.""" ret = dummy_acrohandler( data, command="regress", @@ -667,7 +664,7 @@ 
def test_stata_linregress(data): def test_stata_logit(data): - """Tests stata logit function.""" + """Test stata logit function.""" ret = dummy_acrohandler( data, command="logit", @@ -694,7 +691,7 @@ def test_stata_logit(data): def test_unsupported_formatting_options(data): - """Checks that user gets warning if they try to format table.""" + """Check that user gets warning if they try to format table.""" format_string = "acro does not currently support table formatting commands." correct = ( "Total\n" @@ -737,7 +734,7 @@ def test_unsupported_formatting_options(data): def test_stata_finalise(monkeypatch): - """Checks finalise gets called correctly.""" + """Check finalise gets called correctly.""" monkeypatch.setattr("builtins.input", lambda _: "Let me have it") ret = dummy_acrohandler( data, @@ -754,7 +751,7 @@ def test_stata_finalise(monkeypatch): def test_stata_finalise_default_filetype(monkeypatch): - """Checks finalise gets called correctly.""" + """Check finalise gets called correctly.""" monkeypatch.setattr("builtins.input", lambda _: "Let me have it") ret = dummy_acrohandler( data, @@ -771,7 +768,7 @@ def test_stata_finalise_default_filetype(monkeypatch): def test_stata_unknown(data): - """Unknown acro command.""" + """Test unknown acro command.""" ret = dummy_acrohandler( data, command="foo", @@ -787,7 +784,7 @@ def test_stata_unknown(data): def test_cleanup(): - """Gets rid of files created during tests.""" + """Remove files created during tests.""" names = ["test_outputs", "test_add_to_acro", "sdc_results", "RES_PYTEST"] for name in names: clean_up(name)