Skip to content

Commit

Permalink
mypy errors partially corrected
Browse files Browse the repository at this point in the history
  • Loading branch information
amandarichardsonn committed Jul 29, 2024
1 parent 91f3af8 commit c1ec227
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
14 changes: 10 additions & 4 deletions smartsim/_core/generation/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from ...database import FeatureStore
from ...entity import Application, TaggedFilesHierarchy
from ...entity.files import EntityFiles
from ...launchable import Job, JobGroup
from ...log import get_logger
from ..entrypoints import file_operations
Expand Down Expand Up @@ -191,7 +192,7 @@ def generate_experiment(self) -> str:
log_file.write(f"Generation start date and time: {dt_string}\n")

# Prevent access to type FeatureStore entities
if isinstance(self.job.entity, Application) and self.job.entity.files:
if isinstance(self.job.entity, Application):
# Perform file system operations on attached files
self._build_operations()

Expand All @@ -210,8 +211,7 @@ def _build_operations(self) -> None:
app = t.cast(Application, self.job.entity)
self._get_symlink_file_system_operation(app, self.path)
self._write_tagged_entity_files(app, self.path)
if app.files:
self._get_copy_file_system_operation(app, self.path)
self._get_copy_file_system_operation(app, self.path)

@staticmethod
def _get_copy_file_system_operation(app: Application, dest: str) -> None:
Expand All @@ -220,9 +220,11 @@ def _get_copy_file_system_operation(app: Application, dest: str) -> None:
:param linked_file: The file to be copied.
:return: A list of copy file system operations.
"""
if app.files is None:
return
parser = get_parser()
for src in app.files.copy:
if Path(src).is_dir:
if Path(src).is_dir: # TODO figure this out, or how to replace
cmd = f"copy {src} {dest} --dirs_exist_ok"
else:
cmd = f"copy {src} {dest}"
Expand All @@ -237,6 +239,8 @@ def _get_symlink_file_system_operation(app: Application, dest: str) -> None:
:param linked_file: The file to be symlinked.
:return: A list of symlink file system operations.
"""
if app.files is None:
return
parser = get_parser()
for sym in app.files.link:
# Normalize the path to remove trailing slashes
Expand All @@ -259,6 +263,8 @@ def _write_tagged_entity_files(app: Application, dest: str) -> None:
:param entity: a Application instance
"""
if app.files is None:
return
if app.files.tagged:
to_write = []

Expand Down
11 changes: 6 additions & 5 deletions smartsim/settings/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@
a job
"""
_FormatterType: TypeAlias = t.Callable[
[_DispatchableT, "ExecutableProtocol", _EnvironMappingType], _LaunchableT
[_DispatchableT, "ExecutableProtocol", str, _EnvironMappingType], _LaunchableT
]
"""A callable that is capable of formatting the components of a job into a type
capable of being launched by a launcher.
"""
_LaunchConfigType: TypeAlias = (
"_LauncherAdapter[ExecutableProtocol, _EnvironMappingType]"
"_LauncherAdapter[ExecutableProtocol, _EnvironMappingType, str]"
)
"""A launcher adapater that has configured a launcher to launch the components
of a job with some pre-determined launch settings
Expand Down Expand Up @@ -388,7 +388,7 @@ def create(cls, exp: Experiment, /) -> Self: ...

def make_shell_format_fn(
run_command: str | None,
) -> _FormatterType[LaunchArguments, t.Sequence[str]]:
) -> _FormatterType[LaunchArguments, tuple[t.Sequence[str], str]]:
"""A function that builds a function that formats a `LaunchArguments` as a
shell executable sequence of strings for a given launching utility.
Expand Down Expand Up @@ -423,7 +423,7 @@ def impl(
exe: ExecutableProtocol,
path: str,
_env: _EnvironMappingType,
) -> t.Sequence[str]:
) -> t.Tuple[t.Sequence[str], str]:
return (
(
run_command,
Expand All @@ -444,7 +444,8 @@ class ShellLauncher:
def __init__(self) -> None:
self._launched: dict[LaunchedJobID, sp.Popen[bytes]] = {}

def start(self, command: t.Sequence[str]) -> LaunchedJobID:
# TODO inject path here
def start(self, command: tuple[t.Sequence[str], str]) -> LaunchedJobID:
id_ = create_job_id()
exe, *rest = command
print(f"here is the path: {rest}")
Expand Down

0 comments on commit c1ec227

Please sign in to comment.