From f6c8f82f791ef0917886ed4d0d628db9a724cefb Mon Sep 17 00:00:00 2001 From: Will Toohey Date: Thu, 17 Oct 2024 10:22:28 +1000 Subject: [PATCH] Add some type annotations, most notably to `smbclient.open_file` --- src/smbclient/_os.py | 136 ++++++++++++++++++++++++++++++++++++- src/smbclient/_pool.py | 11 +-- src/smbprotocol/session.py | 10 ++- 3 files changed, 150 insertions(+), 7 deletions(-) diff --git a/src/smbclient/_os.py b/src/smbclient/_os.py index 88917c6..3c91f93 100644 --- a/src/smbclient/_os.py +++ b/src/smbclient/_os.py @@ -306,6 +306,140 @@ def makedirs(path, exist_ok=False, **kwargs): create_queue.pop(-1) +# Taken from stdlib typeshed but removed the unused 'U' flag +OpenTextModeUpdating: t.TypeAlias = t.Literal[ + "r+", + "+r", + "rt+", + "r+t", + "+rt", + "tr+", + "t+r", + "+tr", + "w+", + "+w", + "wt+", + "w+t", + "+wt", + "tw+", + "t+w", + "+tw", + "a+", + "+a", + "at+", + "a+t", + "+at", + "ta+", + "t+a", + "+ta", + "x+", + "+x", + "xt+", + "x+t", + "+xt", + "tx+", + "t+x", + "+tx", +] +OpenTextModeWriting: t.TypeAlias = t.Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"] +OpenTextModeReading: t.TypeAlias = t.Literal["r", "rt", "tr"] +OpenTextMode: t.TypeAlias = OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading +OpenBinaryModeUpdating: t.TypeAlias = t.Literal[ + "rb+", + "r+b", + "+rb", + "br+", + "b+r", + "+br", + "wb+", + "w+b", + "+wb", + "bw+", + "b+w", + "+bw", + "ab+", + "a+b", + "+ab", + "ba+", + "b+a", + "+ba", + "xb+", + "x+b", + "+xb", + "bx+", + "b+x", + "+bx", +] +OpenBinaryModeWriting: t.TypeAlias = t.Literal["wb", "bw", "ab", "ba", "xb", "bx"] +OpenBinaryModeReading: t.TypeAlias = t.Literal["rb", "br"] +OpenBinaryMode: t.TypeAlias = OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting +FileType: t.TypeAlias = t.Literal["file", "dir", "pipe"] + + +# Text mode: always returns a TextIOWrapper +@t.overload +def open_file( + path, + mode: OpenTextMode = "r", + buffering=-1, + file_type: FileType = "file", + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> io.TextIOWrapper[io.BufferedRandom | io.BufferedReader | io.BufferedWriter]: ... + + +# Otherwise return BufferedRandom, BufferedReader, or BufferedWriter +# NOTE: This incorrectly returns unbuffered opens as Buffered types, due to difficulties +# in annotating that case +@t.overload +def open_file( + path, + mode: OpenBinaryModeUpdating, + buffering=-1, + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + file_type: FileType = "file", + **kwargs, +) -> io.BufferedRandom: ... +@t.overload +def open_file( + path, + mode: OpenBinaryModeReading, + buffering=-1, + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + file_type: FileType = "file", + **kwargs, +) -> io.BufferedReader: ... +@t.overload +def open_file( + path, + mode: OpenBinaryModeWriting, + buffering=-1, + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + file_type: FileType = "file", + **kwargs, +) -> io.BufferedWriter: ... + + def open_file( path, mode="r", @@ -316,7 +450,7 @@ def open_file( share_access=None, desired_access=None, file_attributes=None, - file_type="file", + file_type: t.Literal["file", "dir", "pipe"] = "file", **kwargs, ): """ diff --git a/src/smbclient/_pool.py b/src/smbclient/_pool.py index 06f06f1..dfe79d2 100644 --- a/src/smbclient/_pool.py +++ b/src/smbclient/_pool.py @@ -7,6 +7,7 @@ import logging import ntpath import uuid +from typing import Literal, Optional from smbprotocol._text import to_text from smbprotocol.connection import Capabilities, Connection @@ -366,14 +367,14 @@ def get_smb_tree( def register_session( - server, - username=None, - password=None, + server: str, + username: Optional[str] = None, + password: Optional[str] = None, port=445, - encrypt=None, + encrypt: Optional[bool] = None, connection_timeout=60, connection_cache=None, - auth_protocol="negotiate", + auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate", require_signing=True, ): """ diff --git a/src/smbprotocol/session.py b/src/smbprotocol/session.py index ac99a02..3dcd543 100644 --- a/src/smbprotocol/session.py +++ b/src/smbprotocol/session.py @@ -5,6 +5,7 @@ import logging import random from collections import OrderedDict +from typing import Literal, Optional import spnego from cryptography.hazmat.backends import default_backend @@ -170,7 +171,14 @@ def __init__(self): class Session: - def __init__(self, connection, username=None, password=None, require_encryption=True, auth_protocol="negotiate"): + def __init__( + self, + connection, + username: Optional[str] = None, + password: Optional[str] = None, + require_encryption=True, + auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate", + ): """ [MS-SMB2] v53.0 2017-09-15