Skip to content

Commit

Permalink
Replace TaggedFilesHierarchy with os.walk and implement configure_dir…
Browse files Browse the repository at this point in the history
…ectory entrypoint (#695)

This PR adds a configure_directory entry point, as well as tests. It
also removes TaggedFilesHierarchy and replaces it with os.walk. Finally,
the Generator tests have been refactored.

[ reviewed by @MattToast @mellis13 @juliaputko ]
[ committed by @amandarichardsonn ]
  • Loading branch information
amandarichardsonn authored Sep 27, 2024
1 parent 4faf95c commit 4d9ab27
Show file tree
Hide file tree
Showing 9 changed files with 735 additions and 621 deletions.
6 changes: 4 additions & 2 deletions smartsim/_core/commands/command_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
class CommandList(MutableSequence[Command]):
"""Container for a Sequence of Command objects"""

def __init__(self, commands: t.Union[Command, t.List[Command]]):
def __init__(self, commands: t.Optional[t.Union[Command, t.List[Command]]] = None):
"""CommandList constructor"""
if isinstance(commands, Command):
if commands is None:
commands = []
elif isinstance(commands, Command):
commands = [commands]
self._commands: t.List[Command] = list(commands)

Expand Down
74 changes: 53 additions & 21 deletions smartsim/_core/entrypoints/file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _abspath(input_path: str) -> pathlib.Path:
"""Helper function to check that paths are absolute"""
path = pathlib.Path(input_path)
if not path.is_absolute():
raise ValueError(f"path `{path}` must be absolute")
raise ValueError(f"Path `{path}` must be absolute.")
return path


Expand All @@ -62,6 +62,22 @@ def _make_substitution(
)


def _prepare_param_dict(param_dict: str) -> dict[str, t.Any]:
"""Decode and deserialize a base64-encoded parameter dictionary.
This function takes a base64-encoded string representation of a dictionary,
decodes it, and then deserializes it using pickle. It performs validation
to ensure the resulting object is a non-empty dictionary.
"""
decoded_dict = base64.b64decode(param_dict)
deserialized_dict = pickle.loads(decoded_dict)
if not isinstance(deserialized_dict, dict):
raise TypeError("param dict is not a valid dictionary")
if not deserialized_dict:
raise ValueError("param dictionary is empty")
return deserialized_dict


def _replace_tags_in(
item: str,
substitutions: t.Sequence[Callable[[str], str]],
Expand All @@ -70,6 +86,23 @@ def _replace_tags_in(
return functools.reduce(lambda a, fn: fn(a), substitutions, item)


def _process_file(
substitutions: t.Sequence[Callable[[str], str]],
source: pathlib.Path,
destination: pathlib.Path,
) -> None:
"""
Process a source file by replacing tags with specified substitutions and
write the result to a destination file.
"""
# Set the lines to iterate over
with open(source, "r+", encoding="utf-8") as file_stream:
lines = [_replace_tags_in(line, substitutions) for line in file_stream]
# write configured file to destination specified
with open(destination, "w+", encoding="utf-8") as file_stream:
file_stream.writelines(lines)


def move(parsed_args: argparse.Namespace) -> None:
"""Move a source file or directory to another location. If dest is an
existing directory or a symlink to a directory, then the srouce will
Expand Down Expand Up @@ -155,9 +188,9 @@ def symlink(parsed_args: argparse.Namespace) -> None:

def configure(parsed_args: argparse.Namespace) -> None:
"""Set, search and replace the tagged parameters for the
configure operation within tagged files attached to an entity.
configure_file operation within tagged files attached to an entity.
User-formatted files can be attached using the `configure` argument.
User-formatted files can be attached using the `configure_file` argument.
These files will be modified during ``Application`` generation to replace
tagged sections in the user-formatted files with values from the `params`
initializer argument used during ``Application`` creation:
Expand All @@ -166,39 +199,38 @@ def configure(parsed_args: argparse.Namespace) -> None:
.. highlight:: bash
.. code-block:: bash
python -m smartsim._core.entrypoints.file_operations \
configure /absolute/file/source/pat /absolute/file/dest/path \
configure_file /absolute/file/source/path /absolute/file/dest/path \
tag_deliminator param_dict
/absolute/file/source/path: The tagged files the search and replace operations
to be performed upon
/absolute/file/dest/path: The destination for configured files to be
written to.
tag_delimiter: tag for the configure operation to search for, defaults to
tag_delimiter: tag for the configure_file operation to search for, defaults to
semi-colon e.g. ";"
param_dict: A dict of parameter names and values set for the file
"""
tag_delimiter = parsed_args.tag_delimiter

decoded_dict = base64.b64decode(parsed_args.param_dict)
param_dict = pickle.loads(decoded_dict)

if not param_dict:
raise ValueError("param dictionary is empty")
if not isinstance(param_dict, dict):
raise TypeError("param dict is not a valid dictionary")
param_dict = _prepare_param_dict(parsed_args.param_dict)

substitutions = tuple(
_make_substitution(k, v, tag_delimiter) for k, v in param_dict.items()
)

# Set the lines to iterate over
with open(parsed_args.source, "r+", encoding="utf-8") as file_stream:
lines = [_replace_tags_in(line, substitutions) for line in file_stream]

# write configured file to destination specified
with open(parsed_args.dest, "w+", encoding="utf-8") as file_stream:
file_stream.writelines(lines)
if parsed_args.source.is_dir():
for dirpath, _, filenames in os.walk(parsed_args.source):
new_dir_dest = dirpath.replace(
str(parsed_args.source), str(parsed_args.dest), 1
)
os.makedirs(new_dir_dest, exist_ok=True)
for file_name in filenames:
src_file = os.path.join(dirpath, file_name)
dst_file = os.path.join(new_dir_dest, file_name)
print(type(substitutions))
_process_file(substitutions, src_file, dst_file)
else:
dst_file = parsed_args.dest / os.path.basename(parsed_args.source)
_process_file(substitutions, parsed_args.source, dst_file)


def get_parser() -> argparse.ArgumentParser:
Expand Down
Loading

0 comments on commit 4d9ab27

Please sign in to comment.