diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst index 206d3fab..afad9507 100644 --- a/doc/source/changelog.rst +++ b/doc/source/changelog.rst @@ -1,6 +1,13 @@ MontePy Changelog ================= +#Next Version# +--------------------- + +**Features Added** + +* ``overwrite`` argument added to `MCNP_Problem.write_to_file` to ensure files are only overwritten if the user really wants to do so (:pull:`443`). + 0.2.10 ---------------------- diff --git a/doc/source/starting.rst b/doc/source/starting.rst index 2a93aec5..a521d008 100644 --- a/doc/source/starting.rst +++ b/doc/source/starting.rst @@ -79,8 +79,12 @@ state as a valid MCNP input file. >>> problem.write_to_file("bar.imcnp") +The :func:`~montepy.mcnp_problem.MCNP_Problem.write_to_file` method does take an optional argument: ``overwrite``. +By default if the file exists, it will not be overwritten and an error will be raised. +This can be changed by ``overwrite=True``. + .. warning:: - Be careful with overwriting the original file when writing a modified file out. + Be careful with overwriting the original file (and ``overwrite=True`` in general) when writing a modified file out. This will wipe out the original version, and if you have no version control, may lead to losing information. diff --git a/montepy/input_parser/input_file.py b/montepy/input_parser/input_file.py index 0ef736ce..52240260 100644 --- a/montepy/input_parser/input_file.py +++ b/montepy/input_parser/input_file.py @@ -2,6 +2,7 @@ import itertools as it from montepy.constants import ASCII_CEILING from montepy.utilities import * +import os class MCNP_InputFile: @@ -11,17 +12,23 @@ class MCNP_InputFile: .. Note:: this is a bare bones implementation to be fleshed out in the future. + .. versionchanged:: 0.3.0 + Added the overwrite attribute. + :param path: the path to the input file :type path: str :param parent_file: the parent file for this file if any. This occurs when a "read" input is used. :type parent_file: str + :param overwrite: Whether to overwrite the file 'path' if it exists + :type overwrite: bool """ - def __init__(self, path, parent_file=None): + def __init__(self, path, parent_file=None, overwrite=False): self._path = path self._parent_file = parent_file self._lineno = 1 self._replace_with_space = False + self._overwrite = overwrite self._mode = None self._fh = None @@ -76,6 +83,9 @@ def open(self, mode, encoding="ascii", replace=True): CP1252 is commonly referred to as "extended-ASCII". You may have success with this encoding for working with special characters. + .. versionchanged:: 0.2.11 + Added guardrails to raise FileExistsError and IsADirectoryError. + :param mode: the mode to open the file in :type mode: str :param encoding: The encoding scheme to use. If replace is true, this is ignored, and changed to ASCII @@ -83,6 +93,8 @@ def open(self, mode, encoding="ascii", replace=True): :param replace: replace all non-ASCII characters with a space (0x20) :type replace: bool :returns: self + :raises FileExistsError: if a file already exists with the same path while writing. + :raises IsADirectoryError: if the path given is actually a directory while writing. """ if "r" in mode: if replace: @@ -90,6 +102,15 @@ def open(self, mode, encoding="ascii", replace=True): mode = "rb" encoding = None self._mode = mode + if "w" in mode: + if os.path.isfile(self.path) and self._overwrite is not True: + raise FileExistsError( + f"{self.path} already exists, and overwrite is not set." + ) + if os.path.isdir(self.path): + raise IsADirectoryError( + f"{self.path} is a directory, and cannot be overwritten." + ) self._fh = open(self.path, mode, encoding=encoding) return self diff --git a/montepy/mcnp_problem.py b/montepy/mcnp_problem.py index 16b8426e..403e501b 100644 --- a/montepy/mcnp_problem.py +++ b/montepy/mcnp_problem.py @@ -390,15 +390,22 @@ def add_cell_children_to_problem(self): self._transforms = Transforms(transforms) self._data_inputs = sorted(set(self._data_inputs + materials + transforms)) - def write_to_file(self, new_problem): + def write_to_file(self, new_problem, overwrite=False): """ Writes the problem to a file. + .. versionchanged:: 0.3.0 + The overwrite parameter was added. + :param new_problem: the file name to write this problem to :type new_problem: str + :param overwrite: Whether to overwrite the file at 'new_problem' if it exists + :type overwrite: bool :raises IllegalState: if an object in the problem has not been fully initialized. + :raises FileExistsError: if a file already exists with the same path. + :raises IsADirectoryError: if the path given is actually a directory. """ - new_file = MCNP_InputFile(new_problem) + new_file = MCNP_InputFile(new_problem, overwrite=overwrite) with new_file.open("w") as fh, warnings.catch_warnings( record=True ) as warning_catch: diff --git a/tests/test_input_file.py b/tests/test_input_file.py index 8a6fa1d1..3e0ebdda 100644 --- a/tests/test_input_file.py +++ b/tests/test_input_file.py @@ -1,6 +1,7 @@ # Copyright 2024, Battelle Energy Alliance, LLC All Rights Reserved. import os import unittest +import pytest import montepy from montepy.input_parser.input_file import MCNP_InputFile @@ -69,3 +70,34 @@ def test_write(self): finally: if os.path.exists(out): os.remove(out) + + +def _write_file(out_file): + with open(out_file, "w") as fh: + fh.write("") + + +@pytest.mark.parametrize( + "writer,exception, clearer", + ( + (_write_file, FileExistsError, lambda out_file: os.remove(out_file)), + ( + lambda out_file: os.makedirs(out_file), + IsADirectoryError, + lambda out_file: os.rmdir(out_file), + ), + ), +) +def test_write_guardrails(writer, exception, clearer): + out_file = "foo_bar.imcnp" + try: + writer(out_file) + with pytest.raises(exception): + test = MCNP_InputFile(out_file) + with test.open("w") as _: + pass + finally: + try: + clearer(out_file) + except FileNotFoundError: + pass diff --git a/tests/test_integration.py b/tests/test_integration.py index c4ba5afb..ba79da99 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -545,7 +545,7 @@ def test_importance_write_cell(self): pass def test_importance_write_data(self): - out_file = "test_import_data" + out_file = "test_import_data_2" problem = copy.deepcopy(self.simple_problem) problem.print_in_data_block["imp"] = True try: @@ -968,12 +968,13 @@ def test_cell_validator(self): cell.validate() def test_importance_rewrite(self): - out_file = "test_import_data" + out_file = "test_import_data_1" problem = copy.deepcopy(self.simple_problem) problem.print_in_data_block["imp"] = True try: problem.write_to_file(out_file) problem = montepy.read_input(out_file) + os.remove(out_file) problem.print_in_data_block["imp"] = False problem.write_to_file(out_file) found_n = False