Skip to content

Commit

Permalink
Add a lock for GRPC calls to prevent corruption and exceptions on gat…
Browse files Browse the repository at this point in the history
…eway restart.

Fixes #255

Signed-off-by: Gil Bregman <[email protected]>
  • Loading branch information
gbregman committed Oct 10, 2023
1 parent d24a890 commit 316c2f9
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 44 deletions.
109 changes: 97 additions & 12 deletions control/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ class GatewayService(pb2_grpc.GatewayServicer):
spdk_rpc_client: Client of SPDK RPC server
"""

def __init__(self, config, gateway_state, spdk_rpc_client) -> None:
def __init__(self, config, gateway_state, spdk_rpc_client, rpc_lock) -> None:
"""Constructor"""
self.logger = logging.getLogger(__name__)
ver = os.getenv("NVMEOF_VERSION")
if ver:
self.logger.info(f"Using NVMeoF gateway version {ver}")
self.config = config
self.logger.info(f"Using configuration file {config.filepath}")
self.rpc_lock = rpc_lock
self.gateway_state = gateway_state
self.spdk_rpc_client = spdk_rpc_client
self.gateway_name = self.config.get("gateway", "name")
Expand Down Expand Up @@ -91,7 +92,7 @@ def _alloc_cluster(self) -> str:
)
return name

def create_bdev(self, request, context=None):
def create_bdev_safe(self, request, context=None):
"""Creates a bdev from an RBD image."""

if not request.uuid:
Expand Down Expand Up @@ -132,7 +133,15 @@ def create_bdev(self, request, context=None):

return pb2.bdev(bdev_name=bdev_name, status=True)

def delete_bdev(self, request, context=None):
def create_bdev(self, request, context=None):
if context:
with self.rpc_lock:
return self.create_bdev_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.create_bdev_safe(request, context)

def delete_bdev_safe(self, request, context=None):
"""Deletes a bdev."""

self.logger.info(f"Received request to delete bdev {request.bdev_name}")
Expand Down Expand Up @@ -191,7 +200,15 @@ def delete_bdev(self, request, context=None):

return pb2.req_status(status=ret)

def create_subsystem(self, request, context=None):
def delete_bdev(self, request, context=None):
if context:
with self.rpc_lock:
return self.delete_bdev_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.delete_bdev_safe(request, context)

def create_subsystem_safe(self, request, context=None):
"""Creates a subsystem."""

self.logger.info(
Expand Down Expand Up @@ -233,7 +250,15 @@ def create_subsystem(self, request, context=None):

return pb2.req_status(status=ret)

def delete_subsystem(self, request, context=None):
def create_subsystem(self, request, context=None):
if context:
with self.rpc_lock:
return self.create_subsystem_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.create_subsystem_safe(request, context)

def delete_subsystem_safe(self, request, context=None):
"""Deletes a subsystem."""

self.logger.info(
Expand Down Expand Up @@ -262,7 +287,15 @@ def delete_subsystem(self, request, context=None):

return pb2.req_status(status=ret)

def add_namespace(self, request, context=None):
def delete_subsystem(self, request, context=None):
if context:
with self.rpc_lock:
return self.delete_subsystem_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.delete_subsystem_safe(request, context)

def add_namespace_safe(self, request, context=None):
"""Adds a namespace to a subsystem."""

self.logger.info(f"Received request to add {request.bdev_name} to"
Expand Down Expand Up @@ -298,7 +331,15 @@ def add_namespace(self, request, context=None):

return pb2.nsid(nsid=nsid, status=True)

def remove_namespace(self, request, context=None):
def add_namespace(self, request, context=None):
if context:
with self.rpc_lock:
return self.add_namespace_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.add_namespace_safe(request, context)

def remove_namespace_safe(self, request, context=None):
"""Removes a namespace from a subsystem."""

self.logger.info(f"Received request to remove {request.nsid} from"
Expand Down Expand Up @@ -329,7 +370,15 @@ def remove_namespace(self, request, context=None):

return pb2.req_status(status=ret)

def add_host(self, request, context=None):
def remove_namespace(self, request, context=None):
if context:
with self.rpc_lock:
return self.remove_namespace_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.remove_namespace_safe(request, context)

def add_host_safe(self, request, context=None):
"""Adds a host to a subsystem."""

try:
Expand Down Expand Up @@ -373,7 +422,15 @@ def add_host(self, request, context=None):

return pb2.req_status(status=ret)

def remove_host(self, request, context=None):
def add_host(self, request, context=None):
if context:
with self.rpc_lock:
return self.add_host_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.add_host_safe(request, context)

def remove_host_safe(self, request, context=None):
"""Removes a host from a subsystem."""

try:
Expand Down Expand Up @@ -415,7 +472,15 @@ def remove_host(self, request, context=None):

return pb2.req_status(status=ret)

def create_listener(self, request, context=None):
def remove_host(self, request, context=None):
if context:
with self.rpc_lock:
return self.remove_host_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.remove_host_safe(request, context)

def create_listener_safe(self, request, context=None):
"""Creates a listener for a subsystem at a given IP/Port."""

ret = True
Expand Down Expand Up @@ -459,7 +524,15 @@ def create_listener(self, request, context=None):

return pb2.req_status(status=ret)

def delete_listener(self, request, context=None):
def create_listener(self, request, context=None):
if context:
with self.rpc_lock:
return self.create_listener_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.create_listener_safe(request, context)

def delete_listener_safe(self, request, context=None):
"""Deletes a listener from a subsystem at a given IP/Port."""

ret = True
Expand Down Expand Up @@ -502,7 +575,15 @@ def delete_listener(self, request, context=None):

return pb2.req_status(status=ret)

def get_subsystems(self, request, context):
def delete_listener(self, request, context=None):
if context:
with self.rpc_lock:
return self.delete_listener_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.delete_listener_safe(request, context)

def get_subsystems_safe(self, request, context):
"""Gets subsystems."""

self.logger.info(f"Received request to get subsystems")
Expand All @@ -516,3 +597,7 @@ def get_subsystems(self, request, context):
return pb2.subsystems_info()

return pb2.subsystems_info(subsystems=json.dumps(ret))

def get_subsystems(self, request, context):
with self.rpc_lock:
return self.get_subsystems_safe(request, context)
101 changes: 69 additions & 32 deletions control/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import json
import logging
import signal
import threading
import copy
from concurrent import futures
from google.protobuf import json_format

Expand Down Expand Up @@ -45,6 +47,7 @@ def sigchld_handler(signum, frame):
# GW process should exit now
raise SystemExit(f"Gateway subprocess terminated {pid=} {exit_code=}")


class GatewayServer:
"""Runs SPDK and receives client requests for the gateway service.
Expand All @@ -59,9 +62,37 @@ class GatewayServer:
discovery_pid: Subprocess running Ceph nvmeof discovery service
"""

class RPCGuard:
RPC_GUARD_LOCK_TIMEOUT = 300

def __init__(self, logger, timeout = None) -> None:
self.rpc_lock = threading.Lock()
self.lock_timeout = timeout
self.logger = logger

def __enter__(self):
rc = self.rpc_lock.acquire(True, self.lock_timeout)
if not rc:
self.logger.warning(f"Couldn't acquire lock after {self.lock_timeout} seconds, will try again")
rc = self.rpc_lock.acquire(True, self.lock_timeout)
if not rc:
self.logger.error(f"Failed to acquire lock for guarding RPC, will continue anyway")
return self

def __exit__(self, typ, value, traceback):
self.rpc_lock.release()

def is_locked(self):
return self.rpc_lock.locked()

def raise_exception_if_not_locked(self):
if not self.rpc_lock.locked():
raise Exception("RPC guard is not locked like it should be")

def __init__(self, config):
self.logger = logging.getLogger(__name__)
self.config = config
self.rpc_lock = GatewayServer.RPCGuard(self.logger, GatewayServer.RPCGuard.RPC_GUARD_LOCK_TIMEOUT)
self.spdk_process = None
self.gateway_rpc = None
self.server = None
Expand Down Expand Up @@ -113,7 +144,7 @@ def serve(self):
gateway_state = GatewayStateHandler(self.config, local_state,
omap_state, self.gateway_rpc_caller)
self.gateway_rpc = GatewayService(self.config, gateway_state,
self.spdk_rpc_client)
self.spdk_rpc_client, self.rpc_lock)
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
pb2_grpc.add_GatewayServicer_to_server(self.gateway_rpc, self.server)

Expand Down Expand Up @@ -328,45 +359,51 @@ def _ping(self):

def gateway_rpc_caller(self, requests, is_add_req):
"""Passes RPC requests to gateway service."""
for key, val in requests.items():
requests_copy = copy.deepcopy(requests)
for key, val in requests_copy.items():
if key.startswith(GatewayState.BDEV_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.create_bdev_req())
self.gateway_rpc.create_bdev(req)
else:
req = json_format.Parse(val,
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.create_bdev_req())
self.gateway_rpc.create_bdev(req)
else:
req = json_format.Parse(val,
pb2.delete_bdev_req(),
ignore_unknown_fields=True)
self.gateway_rpc.delete_bdev(req)
self.gateway_rpc.delete_bdev(req)
elif key.startswith(GatewayState.SUBSYSTEM_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.create_subsystem_req())
self.gateway_rpc.create_subsystem(req)
else:
req = json_format.Parse(val,
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.create_subsystem_req())
self.gateway_rpc.create_subsystem(req)
else:
req = json_format.Parse(val,
pb2.delete_subsystem_req(),
ignore_unknown_fields=True)
self.gateway_rpc.delete_subsystem(req)
self.gateway_rpc.delete_subsystem(req)
elif key.startswith(GatewayState.NAMESPACE_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.add_namespace_req())
self.gateway_rpc.add_namespace(req)
else:
req = json_format.Parse(val,
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.add_namespace_req())
self.gateway_rpc.add_namespace(req)
else:
req = json_format.Parse(val,
pb2.remove_namespace_req(),
ignore_unknown_fields=True)
self.gateway_rpc.remove_namespace(req)
self.gateway_rpc.remove_namespace(req)
elif key.startswith(GatewayState.HOST_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.add_host_req())
self.gateway_rpc.add_host(req)
else:
req = json_format.Parse(val, pb2.remove_host_req())
self.gateway_rpc.remove_host(req)
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.add_host_req())
self.gateway_rpc.add_host(req)
else:
req = json_format.Parse(val, pb2.remove_host_req())
self.gateway_rpc.remove_host(req)
elif key.startswith(GatewayState.LISTENER_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.create_listener_req())
self.gateway_rpc.create_listener(req)
else:
req = json_format.Parse(val, pb2.delete_listener_req())
self.gateway_rpc.delete_listener(req)
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.create_listener_req())
self.gateway_rpc.create_listener(req)
else:
req = json_format.Parse(val, pb2.delete_listener_req())
self.gateway_rpc.delete_listener(req)

0 comments on commit 316c2f9

Please sign in to comment.