diff --git a/src/smbclient/_os.py b/src/smbclient/_os.py index 88917c6..339a642 100644 --- a/src/smbclient/_os.py +++ b/src/smbclient/_os.py @@ -306,17 +306,208 @@ def makedirs(path, exist_ok=False, **kwargs): create_queue.pop(-1) +# Taken from stdlib typeshed but removed the unused 'U' flag +OpenTextModeUpdating: t.TypeAlias = t.Literal[ + "r+", + "+r", + "rt+", + "r+t", + "+rt", + "tr+", + "t+r", + "+tr", + "w+", + "+w", + "wt+", + "w+t", + "+wt", + "tw+", + "t+w", + "+tw", + "a+", + "+a", + "at+", + "a+t", + "+at", + "ta+", + "t+a", + "+ta", + "x+", + "+x", + "xt+", + "x+t", + "+xt", + "tx+", + "t+x", + "+tx", +] +OpenTextModeWriting: t.TypeAlias = t.Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"] +OpenTextModeReading: t.TypeAlias = t.Literal["r", "rt", "tr"] +OpenTextMode: t.TypeAlias = OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading +OpenBinaryModeUpdating: t.TypeAlias = t.Literal[ + "rb+", + "r+b", + "+rb", + "br+", + "b+r", + "+br", + "wb+", + "w+b", + "+wb", + "bw+", + "b+w", + "+bw", + "ab+", + "a+b", + "+ab", + "ba+", + "b+a", + "+ba", + "xb+", + "x+b", + "+xb", + "bx+", + "b+x", + "+bx", +] +OpenBinaryModeWriting: t.TypeAlias = t.Literal["wb", "bw", "ab", "ba", "xb", "bx"] +OpenBinaryModeReading: t.TypeAlias = t.Literal["rb", "br"] +OpenBinaryMode: t.TypeAlias = OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting +FileType: t.TypeAlias = t.Literal["file", "dir", "pipe"] + + +class SMBBufferedRandom(io.BufferedRandom): + raw: SMBRawIO # type: ignore + + +class SMBBufferedReader(io.BufferedReader): + raw: SMBRawIO # type: ignore + + +class SMBBufferedWriter(io.BufferedWriter): + raw: SMBRawIO # type: ignore + + +# Text mode: always returns a TextIOWrapper +@t.overload +def open_file( + path, + mode: OpenTextMode = "r", + buffering=-1, + file_type: FileType = "file", + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> io.TextIOWrapper[SMBBufferedRandom | SMBBufferedReader | SMBBufferedWriter]: ... + + +# Unbuffered binary mode: returns a File/Directory/Pipe IO object +@t.overload +def open_file( + path, + mode: OpenBinaryMode, + buffering: t.Literal[0], + file_type: t.Literal["file"] = "file", + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> SMBFileIO: ... +@t.overload +def open_file( + path, + mode: OpenBinaryMode, + buffering: t.Literal[0], + file_type: t.Literal["dir"], + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> SMBDirectoryIO: ... +@t.overload +def open_file( + path, + mode: OpenBinaryMode, + buffering: t.Literal[0], + file_type: t.Literal["pipe"], + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> SMBPipeIO: ... + + +# Buffering is on: return BufferedRandom, BufferedReader, or BufferedWriter +# NOTE: This cannot handle explicit buffer sizes (>0) because of limitations +# in Python's type system +@t.overload +def open_file( + path, + mode: OpenBinaryModeUpdating, + buffering: t.Literal[-1] = -1, + file_type: FileType = "file", + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> SMBBufferedRandom: ... +@t.overload +def open_file( + path, + mode: OpenBinaryModeReading, + buffering: t.Literal[-1] = -1, + file_type: FileType = "file", + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> SMBBufferedReader: ... +@t.overload +def open_file( + path, + mode: OpenBinaryModeWriting, + buffering: t.Literal[-1] = -1, + file_type: FileType = "file", + encoding=None, + errors=None, + newline=None, + share_access=None, + desired_access=None, + file_attributes=None, + **kwargs, +) -> SMBBufferedWriter: ... + + def open_file( path, mode="r", buffering=-1, + file_type: t.Literal["file", "dir", "pipe"] = "file", encoding=None, errors=None, newline=None, share_access=None, desired_access=None, file_attributes=None, - file_type="file", **kwargs, ): """ diff --git a/src/smbprotocol/open.py b/src/smbprotocol/open.py index 14467d4..644457a 100644 --- a/src/smbprotocol/open.py +++ b/src/smbprotocol/open.py @@ -31,6 +31,7 @@ Structure, StructureField, ) +from smbprotocol.tree import TreeConnect log = logging.getLogger(__name__) @@ -1008,7 +1009,7 @@ def __init__(self): class Open: - def __init__(self, tree, name): + def __init__(self, tree: TreeConnect, name: str): """ [MS-SMB2] v53.0 2017-09-15 diff --git a/src/smbprotocol/session.py b/src/smbprotocol/session.py index ac99a02..5b59183 100644 --- a/src/smbprotocol/session.py +++ b/src/smbprotocol/session.py @@ -5,6 +5,7 @@ import logging import random from collections import OrderedDict +from typing import Literal, Optional import spnego from cryptography.hazmat.backends import default_backend @@ -12,7 +13,7 @@ from cryptography.hazmat.primitives.kdf.kbkdf import KBKDFHMAC, CounterLocation, Mode from smbprotocol import Dialects -from smbprotocol.connection import Capabilities, Ciphers, SecurityMode +from smbprotocol.connection import Capabilities, Ciphers, Connection, SecurityMode from smbprotocol.exceptions import ( MoreProcessingRequired, SMBAuthenticationError, @@ -170,7 +171,14 @@ def __init__(self): class Session: - def __init__(self, connection, username=None, password=None, require_encryption=True, auth_protocol="negotiate"): + def __init__( + self, + connection: Connection, + username: Optional[str] = None, + password: Optional[str] = None, + require_encryption=True, + auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate", + ): """ [MS-SMB2] v53.0 2017-09-15 diff --git a/src/smbprotocol/tree.py b/src/smbprotocol/tree.py index 7383aeb..3383c20 100644 --- a/src/smbprotocol/tree.py +++ b/src/smbprotocol/tree.py @@ -20,6 +20,7 @@ SMB2ValidateNegotiateInfoRequest, SMB2ValidateNegotiateInfoResponse, ) +from smbprotocol.session import Session from smbprotocol.structure import BytesField, EnumField, FlagField, IntField, Structure log = logging.getLogger(__name__) @@ -204,7 +205,7 @@ def __init__(self): class TreeConnect: - def __init__(self, session, share_name): + def __init__(self, session: Session, share_name: str): """ [MS-SMB2] v53.0 2017-09-15