diff --git a/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala b/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala index e4766d1cff4..2b7dcb9ad21 100644 --- a/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala +++ b/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala @@ -70,7 +70,15 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin synchronized { tmpdir = tmp } def pySetLocalTmp(tmp: String): Unit = - synchronized { localTmpdir = tmp } + synchronized { + localTmpdir = tmp + backend match { + case s: SparkBackend => + s.sc.getConf.set("spark.local.dir", tmp) + case _ => + () + } + } def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit = synchronized { diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index e7e14bbe90a..8d8e0092629 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -392,3 +392,23 @@ def get_flags(self, *flags) -> Mapping[str, str]: @abc.abstractmethod def requires_lowering(self): pass + + @property + @abc.abstractmethod + def local_tmpdir(self) -> str: + pass + + @local_tmpdir.setter + @abc.abstractmethod + def local_tmpdir(self, tmpdir: str) -> None: + pass + + @property + @abc.abstractmethod + def remote_tmpdir(self) -> str: + pass + + @remote_tmpdir.setter + @abc.abstractmethod + def remote_tmpdir(self, tmpdir: str) -> None: + pass diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index e3d33e97c17..c0fa7055f96 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -197,8 +197,8 @@ def decode_bytearray(encoded): self._jhc = jhc self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) - self._jbackend.pySetLocalTmp(tmpdir) - self._jbackend.pySetRemoteTmp(remote_tmpdir) + self.local_tmpdir = tmpdir + self.remote_tmpdir = remote_tmpdir self._jhttp_server = self._jbackend.pyHttpServer() 
self._backend_server_port: int = self._jhttp_server.port() @@ -325,3 +325,21 @@ def stop(self): self._jhc = None uninstall_exception_handler() super().stop() + + @property + def local_tmpdir(self) -> str: + return self._local_tmpdir + + @local_tmpdir.setter + def local_tmpdir(self, tmpdir: str) -> None: + self._local_tmpdir = tmpdir + self._jbackend.pySetLocalTmp(tmpdir) + + @property + def remote_tmpdir(self) -> str: + return self._remote_tmpdir + + @remote_tmpdir.setter + def remote_tmpdir(self, tmpdir: str) -> None: + self._remote_tmpdir = tmpdir + self._jbackend.pySetRemoteTmp(tmpdir) diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 595811d67e5..da6da9c3b4a 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -5,12 +5,12 @@ import warnings from contextlib import AsyncExitStack from dataclasses import dataclass -from typing import Any, Awaitable, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union +from typing import Any, Awaitable, Dict, List, Mapping, NoReturn, Optional, Set, Tuple, TypeVar, Union import orjson import hailtop.aiotools.fs as afs -from hail.context import TemporaryDirectory, TemporaryFilename, tmp_dir +from hail.context import TemporaryDirectory, TemporaryFilename from hail.experimental import read_expression, write_expression from hail.utils import FatalError from hail.version import __revision__, __version__ @@ -240,7 +240,7 @@ def __init__( self._batch_was_submitted: bool = False self.disable_progress_bar = disable_progress_bar self.batch_attributes = batch_attributes - self.remote_tmpdir = remote_tmpdir + self._remote_tmpdir = remote_tmpdir self.flags: Dict[str, str] = {} self._registered_ir_function_names: Set[str] = set() self.driver_cores = driver_cores @@ -520,3 +520,19 @@ def get_flags(self, *flags: str) -> Mapping[str, str]: @property def requires_lowering(self): return True + + @property + def local_tmpdir(self) 
-> NoReturn: + raise AttributeError('local tmp folders are not supported on the batch backend') + + @local_tmpdir.setter + def local_tmpdir(self, tmpdir: str) -> NoReturn: + raise AttributeError('local tmp folders are not supported on the batch backend') + + @property + def remote_tmpdir(self) -> str: + return self._remote_tmpdir + + @remote_tmpdir.setter + def remote_tmpdir(self, tmpdir: str) -> None: + self._remote_tmpdir = tmpdir diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index b3d72e2cef4..fdf308351f9 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -66,8 +66,6 @@ def create( log: str, quiet: bool, append: bool, - tmpdir: str, - local_tmpdir: str, default_reference: str, global_seed: Optional[int], backend: Backend, @@ -76,25 +74,17 @@ def create( log=log, quiet=quiet, append=append, - tmpdir=tmpdir, - local_tmpdir=local_tmpdir, global_seed=global_seed, backend=backend, ) hc.initialize_references(default_reference) return hc - @typecheck_method( - log=str, quiet=bool, append=bool, tmpdir=str, local_tmpdir=str, global_seed=nullable(int), backend=Backend - ) - def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backend): + @typecheck_method(log=str, quiet=bool, append=bool, global_seed=nullable(int), backend=Backend) + def __init__(self, log, quiet, append, global_seed, backend: Backend): assert not Env._hc self._log = log - - self._tmpdir = tmpdir - self._local_tmpdir = local_tmpdir - self._backend = backend self._warn_cols_order = True @@ -136,6 +126,14 @@ def initialize_references(self, default_reference): else: self._default_ref = ReferenceGenome.read(default_reference) + @property + def _tmpdir(self) -> str: + return self._backend.remote_tmpdir + + @property + def _local_tmpdir(self) -> str: + return self._backend.local_tmpdir + @property def default_reference(self) -> ReferenceGenome: assert self._default_ref is not None, '_default_ref should have been initialized in 
HailContext.create' @@ -498,7 +496,7 @@ def init_spark( if not backend.fs.exists(tmpdir): backend.fs.mkdir(tmpdir) - HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend) + HailContext.create(log, quiet, append, default_reference, global_seed, backend) if not quiet: connect_logger(backend._utils_package_object, 'localhost', 12888) @@ -569,7 +567,7 @@ async def init_batch( tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string()) local_tmpdir = _get_local_tmpdir(local_tmpdir) - HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend) + HailContext.create(log, quiet, append, default_reference, global_seed, backend) @typecheck( @@ -621,7 +619,7 @@ def init_local( if not backend.fs.exists(tmpdir): backend.fs.mkdir(tmpdir) - HailContext.create(log, quiet, append, tmpdir, tmpdir, default_reference, global_seed, backend) + HailContext.create(log, quiet, append, default_reference, global_seed, backend) if not quiet: connect_logger(backend._utils_package_object, 'localhost', 12888)