diff --git a/control/grpc.py b/control/grpc.py index f85c39b8..d57a8c04 100644 --- a/control/grpc.py +++ b/control/grpc.py @@ -36,7 +36,7 @@ class GatewayService(pb2_grpc.GatewayServicer): spdk_rpc_client: Client of SPDK RPC server """ - def __init__(self, config, gateway_state, spdk_rpc_client) -> None: + def __init__(self, config, gateway_state, spdk_rpc_client, rpc_lock) -> None: """Constructor""" self.logger = logging.getLogger(__name__) ver = os.getenv("NVMEOF_VERSION") @@ -44,6 +44,7 @@ def __init__(self, config, gateway_state, spdk_rpc_client) -> None: self.logger.info(f"Using NVMeoF gateway version {ver}") self.config = config self.logger.info(f"Using configuration file {config.filepath}") + self.rpc_lock = rpc_lock self.gateway_state = gateway_state self.spdk_rpc_client = spdk_rpc_client self.gateway_name = self.config.get("gateway", "name") @@ -91,7 +92,7 @@ def _alloc_cluster(self) -> str: ) return name - def create_bdev(self, request, context=None): + def create_bdev_safe(self, request, context=None): """Creates a bdev from an RBD image.""" if not request.uuid: @@ -132,7 +133,15 @@ def create_bdev(self, request, context=None): return pb2.bdev(bdev_name=bdev_name, status=True) - def delete_bdev(self, request, context=None): + def create_bdev(self, request, context=None): + if context: + with self.rpc_lock: + return self.create_bdev_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.create_bdev_safe(request, context) + + def delete_bdev_safe(self, request, context=None): """Deletes a bdev.""" self.logger.info(f"Received request to delete bdev {request.bdev_name}") @@ -191,7 +200,15 @@ def delete_bdev(self, request, context=None): return pb2.req_status(status=ret) - def create_subsystem(self, request, context=None): + def delete_bdev(self, request, context=None): + if context: + with self.rpc_lock: + return self.delete_bdev_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.delete_bdev_safe(request, context) + + def create_subsystem_safe(self, request, context=None): """Creates a subsystem.""" self.logger.info( @@ -233,7 +250,15 @@ def create_subsystem(self, request, context=None): return pb2.req_status(status=ret) - def delete_subsystem(self, request, context=None): + def create_subsystem(self, request, context=None): + if context: + with self.rpc_lock: + return self.create_subsystem_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.create_subsystem_safe(request, context) + + def delete_subsystem_safe(self, request, context=None): """Deletes a subsystem.""" self.logger.info( @@ -262,7 +287,15 @@ def delete_subsystem(self, request, context=None): return pb2.req_status(status=ret) - def add_namespace(self, request, context=None): + def delete_subsystem(self, request, context=None): + if context: + with self.rpc_lock: + return self.delete_subsystem_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.delete_subsystem_safe(request, context) + + def add_namespace_safe(self, request, context=None): """Adds a namespace to a subsystem.""" self.logger.info(f"Received request to add {request.bdev_name} to" @@ -298,7 +331,15 @@ def add_namespace(self, request, context=None): return pb2.nsid(nsid=nsid, status=True) - def remove_namespace(self, request, context=None): + def add_namespace(self, request, context=None): + if context: + with self.rpc_lock: + return self.add_namespace_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.add_namespace_safe(request, context) + + def remove_namespace_safe(self, request, context=None): """Removes a namespace from a subsystem.""" self.logger.info(f"Received request to remove {request.nsid} from" @@ -329,7 +370,15 @@ def remove_namespace(self, request, context=None): return pb2.req_status(status=ret) - def add_host(self, request, context=None): + def remove_namespace(self, request, context=None): + if context: + with self.rpc_lock: + return self.remove_namespace_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.remove_namespace_safe(request, context) + + def add_host_safe(self, request, context=None): """Adds a host to a subsystem.""" try: @@ -373,7 +422,15 @@ def add_host(self, request, context=None): return pb2.req_status(status=ret) - def remove_host(self, request, context=None): + def add_host(self, request, context=None): + if context: + with self.rpc_lock: + return self.add_host_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.add_host_safe(request, context) + + def remove_host_safe(self, request, context=None): """Removes a host from a subsystem.""" try: @@ -415,7 +472,15 @@ def remove_host(self, request, context=None): return pb2.req_status(status=ret) - def create_listener(self, request, context=None): + def remove_host(self, request, context=None): + if context: + with self.rpc_lock: + return self.remove_host_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.remove_host_safe(request, context) + + def create_listener_safe(self, request, context=None): """Creates a listener for a subsystem at a given IP/Port.""" ret = True @@ -459,7 +524,15 @@ def create_listener(self, request, context=None): return pb2.req_status(status=ret) - def delete_listener(self, request, context=None): + def create_listener(self, request, context=None): + if context: + with self.rpc_lock: + return self.create_listener_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.create_listener_safe(request, context) + + def delete_listener_safe(self, request, context=None): """Deletes a listener from a subsystem at a given IP/Port.""" ret = True @@ -502,7 +575,15 @@ def delete_listener(self, request, context=None): return pb2.req_status(status=ret) - def get_subsystems(self, request, context): + def delete_listener(self, request, context=None): + if context: + with self.rpc_lock: + return self.delete_listener_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.delete_listener_safe(request, context) + + def get_subsystems_safe(self, request, context): """Gets subsystems.""" self.logger.info(f"Received request to get subsystems") @@ -516,3 +597,7 @@ def get_subsystems(self, request, context): return pb2.subsystems_info() return pb2.subsystems_info(subsystems=json.dumps(ret)) + + def get_subsystems(self, request, context): + with self.rpc_lock: + return self.get_subsystems_safe(request, context) diff --git a/control/server.py b/control/server.py index 7057c28e..0668980f 100644 --- a/control/server.py +++ b/control/server.py @@ -16,6 +16,8 @@ import json import logging import signal +import threading +import copy from concurrent import futures from google.protobuf import json_format @@ -45,6 +47,7 @@ def sigchld_handler(signum, frame): # GW process should exit now raise SystemExit(f"Gateway subprocess terminated {pid=} {exit_code=}") + class GatewayServer: """Runs SPDK and receives client requests for the gateway service. @@ -59,9 +62,37 @@ class GatewayServer: discovery_pid: Subprocess running Ceph nvmeof discovery service """ + class RPCGuard: + RPC_GUARD_LOCK_TIMEOUT = 300 + + def __init__(self, logger, timeout = None) -> None: + self.rpc_lock = threading.Lock() + self.lock_timeout = timeout + self.logger = logger + + def __enter__(self): + rc = self.rpc_lock.acquire(True, self.lock_timeout) + if not rc: + self.logger.warning(f"Couldn't acquire lock after {self.lock_timeout} seconds, will try again") + rc = self.rpc_lock.acquire(True, self.lock_timeout) + if not rc: + self.logger.error(f"Failed to acquire lock for guarding RPC, will continue anyway") + return self + + def __exit__(self, typ, value, traceback): + self.rpc_lock.release() + + def is_locked(self): + return self.rpc_lock.locked() + + def raise_exception_if_not_locked(self): + if not self.rpc_lock.locked(): + raise Exception("RPC guard is not locked like it should be") + def __init__(self, config): self.logger = logging.getLogger(__name__) self.config = config + self.rpc_lock = GatewayServer.RPCGuard(self.logger, GatewayServer.RPCGuard.RPC_GUARD_LOCK_TIMEOUT) self.spdk_process = None self.gateway_rpc = None self.server = None @@ -113,7 +144,7 @@ def serve(self): gateway_state = GatewayStateHandler(self.config, local_state, omap_state, self.gateway_rpc_caller) self.gateway_rpc = GatewayService(self.config, gateway_state, - self.spdk_rpc_client) + self.spdk_rpc_client, self.rpc_lock) self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) pb2_grpc.add_GatewayServicer_to_server(self.gateway_rpc, self.server) @@ -328,45 +359,51 @@ def _ping(self): def gateway_rpc_caller(self, requests, is_add_req): """Passes RPC requests to gateway service.""" - for key, val in requests.items(): + requests_copy = copy.deepcopy(requests) + for key, val in requests_copy.items(): if key.startswith(GatewayState.BDEV_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.create_bdev_req()) - self.gateway_rpc.create_bdev(req) - else: - req = json_format.Parse(val, + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.create_bdev_req()) + self.gateway_rpc.create_bdev(req) + else: + req = json_format.Parse(val, pb2.delete_bdev_req(), ignore_unknown_fields=True) - self.gateway_rpc.delete_bdev(req) + self.gateway_rpc.delete_bdev(req) elif key.startswith(GatewayState.SUBSYSTEM_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.create_subsystem_req()) - self.gateway_rpc.create_subsystem(req) - else: - req = json_format.Parse(val, + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.create_subsystem_req()) + self.gateway_rpc.create_subsystem(req) + else: + req = json_format.Parse(val, pb2.delete_subsystem_req(), ignore_unknown_fields=True) - self.gateway_rpc.delete_subsystem(req) + self.gateway_rpc.delete_subsystem(req) elif key.startswith(GatewayState.NAMESPACE_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.add_namespace_req()) - self.gateway_rpc.add_namespace(req) - else: - req = json_format.Parse(val, + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.add_namespace_req()) + self.gateway_rpc.add_namespace(req) + else: + req = json_format.Parse(val, pb2.remove_namespace_req(), ignore_unknown_fields=True) - self.gateway_rpc.remove_namespace(req) + self.gateway_rpc.remove_namespace(req) elif key.startswith(GatewayState.HOST_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.add_host_req()) - self.gateway_rpc.add_host(req) - else: - req = json_format.Parse(val, pb2.remove_host_req()) - self.gateway_rpc.remove_host(req) + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.add_host_req()) + self.gateway_rpc.add_host(req) + else: + req = json_format.Parse(val, pb2.remove_host_req()) + self.gateway_rpc.remove_host(req) elif key.startswith(GatewayState.LISTENER_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.create_listener_req()) - self.gateway_rpc.create_listener(req) - else: - req = json_format.Parse(val, pb2.delete_listener_req()) - self.gateway_rpc.delete_listener(req) + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.create_listener_req()) + self.gateway_rpc.create_listener(req) + else: + req = json_format.Parse(val, pb2.delete_listener_req()) + self.gateway_rpc.delete_listener(req)