diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 95ca70e..686b6ad 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -42,7 +42,10 @@ def __init__(self, self.listenport = None self.sock = None # FIXME: What about multiple roots? - self.root = os.path.abspath(tftproot) + if tftproot != None: + self.root = os.path.abspath(tftproot) + else: + self.root = None self.dyn_file_func = dyn_file_func self.upload_open = upload_open # A dict of sessions, where each session is keyed by a string like @@ -59,22 +62,26 @@ def __init__(self, attr = getattr(self, name) if attr and not callable(attr): raise TftpException("{} supplied, but it is not callable.".format(name)) - if os.path.exists(self.root): - log.debug("tftproot %s does exist", self.root) - if not os.path.isdir(self.root): - raise TftpException("The tftproot must be a directory.") - else: - log.debug("tftproot %s is a directory" % self.root) - if os.access(self.root, os.R_OK): - log.debug("tftproot %s is readable" % self.root) - else: - raise TftpException("The tftproot must be readable") - if os.access(self.root, os.W_OK): - log.debug("tftproot %s is writable" % self.root) + if self.root != None: + if os.path.exists(self.root): + log.debug("tftproot %s does exist", self.root) + if not os.path.isdir(self.root): + raise TftpException("The tftproot must be a directory.") else: - log.warning("The tftproot %s is not writable" % self.root) + log.debug("tftproot %s is a directory" % self.root) + if os.access(self.root, os.R_OK): + log.debug("tftproot %s is readable" % self.root) + else: + raise TftpException("The tftproot must be readable") + if os.access(self.root, os.W_OK): + log.debug("tftproot %s is writable" % self.root) + else: + log.warning("The tftproot %s is not writable" % self.root) + else: + raise TftpException("The tftproot does not exist.") else: - raise TftpException("The tftproot does not exist.") + if dyn_file_func == None: + raise TftpException("No tftproot and no dyn_file_func given") def listen(self, listenip="", listenport=DEF_TFTP_PORT, timeout=SOCK_TIMEOUT): diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 42bac1d..a7530b2 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -258,33 +258,33 @@ def serverInitial(self, pkt, raddress, rport): log.debug("Requested filename is %s", pkt.filename) - # Build the filename on this server and ensure it is contained - # in the specified root directory. - # - # Filenames that begin with server root are accepted. It's - # assumed the client and server are tightly connected and this - # provides backwards compatibility. - # - # Filenames otherwise are relative to the server root. If they - # begin with a '/' strip it off as otherwise os.path.join will - # treat it as absolute (regardless of whether it is ntpath or - # posixpath module - if pkt.filename.startswith(self.context.root): - full_path = pkt.filename - else: - full_path = os.path.join(self.context.root, pkt.filename.lstrip('/')) - - # Use abspath to eliminate any remaining relative elements - # (e.g. '..') and ensure that is still within the server's - # root directory - self.full_path = os.path.abspath(full_path) - log.debug("full_path is %s", full_path) - if self.full_path.startswith(self.context.root): - log.info("requested file is in the server root - good") - else: - log.warning("requested file is not within the server root - bad") - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException("bad file path") + if self.context.root != None: + # Build the filename on this server and ensure it is contained + # in the specified root directory. + # + # Filenames that begin with server root are accepted. It's + # assumed the client and server are tightly connected and this + # provides backwards compatibility. + # + # Filenames otherwise are relative to the server root. If they + # begin with a '/' strip it off as otherwise os.path.join will + # treat it as absolute (regardless of whether it is ntpath or + # posixpath module + if pkt.filename.startswith(self.context.root): + full_path = pkt.filename + else: + full_path = os.path.join(self.context.root, pkt.filename.lstrip('/')) + # Use abspath to eliminate any remaining relative elements + # (e.g. '..') and ensure that is still within the server's + # root directory + self.full_path = os.path.abspath(full_path) + log.debug("full_path is %s", full_path) + if self.full_path.startswith(self.context.root): + log.info("requested file is in the server root - good") + else: + log.warning("requested file is not within the server root - bad") + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException("bad file path") self.context.file_to_transfer = pkt.filename @@ -300,7 +300,7 @@ def handle(self, pkt, raddress, rport): sendoack = self.serverInitial(pkt, raddress, rport) path = self.full_path log.info("Opening file %s for reading" % path) - if os.path.exists(path): + if path != None and os.path.exists(path): # Note: Open in binary mode for win32 portability, since win32 # blows. self.context.fileobj = open(path, "rb")