Skip to content

Commit

Permalink
Add some type annotations, most notably to smbclient.open_file
Browse files Browse the repository at this point in the history
  • Loading branch information
mon committed Oct 15, 2024
1 parent 42804ca commit fc95d18
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 5 deletions.
193 changes: 192 additions & 1 deletion src/smbclient/_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,17 +306,208 @@ def makedirs(path, exist_ok=False, **kwargs):
create_queue.pop(-1)


# Taken from stdlib typeshed but removed the unused 'U' flag
OpenTextModeUpdating: t.TypeAlias = t.Literal[
"r+",
"+r",
"rt+",
"r+t",
"+rt",
"tr+",
"t+r",
"+tr",
"w+",
"+w",
"wt+",
"w+t",
"+wt",
"tw+",
"t+w",
"+tw",
"a+",
"+a",
"at+",
"a+t",
"+at",
"ta+",
"t+a",
"+ta",
"x+",
"+x",
"xt+",
"x+t",
"+xt",
"tx+",
"t+x",
"+tx",
]
OpenTextModeWriting: t.TypeAlias = t.Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"]
OpenTextModeReading: t.TypeAlias = t.Literal["r", "rt", "tr"]
OpenTextMode: t.TypeAlias = OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
OpenBinaryModeUpdating: t.TypeAlias = t.Literal[
"rb+",
"r+b",
"+rb",
"br+",
"b+r",
"+br",
"wb+",
"w+b",
"+wb",
"bw+",
"b+w",
"+bw",
"ab+",
"a+b",
"+ab",
"ba+",
"b+a",
"+ba",
"xb+",
"x+b",
"+xb",
"bx+",
"b+x",
"+bx",
]
OpenBinaryModeWriting: t.TypeAlias = t.Literal["wb", "bw", "ab", "ba", "xb", "bx"]
OpenBinaryModeReading: t.TypeAlias = t.Literal["rb", "br"]
OpenBinaryMode: t.TypeAlias = OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
FileType: t.TypeAlias = t.Literal["file", "dir", "pipe"]


class SMBBufferedRandom(io.BufferedRandom):
raw: SMBRawIO # type: ignore


class SMBBufferedReader(io.BufferedReader):
raw: SMBRawIO # type: ignore


class SMBBufferedWriter(io.BufferedWriter):
raw: SMBRawIO # type: ignore


# Text mode: always returns a TextIOWrapper
@t.overload
def open_file(
path,
mode: OpenTextMode = "r",
buffering=-1,
file_type: FileType = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> io.TextIOWrapper[SMBBufferedRandom | SMBBufferedReader | SMBBufferedWriter]: ...


# Unbuffered binary mode: returns a File/Directory/Pipe IO object
@t.overload
def open_file(
path,
mode: OpenBinaryMode,
buffering: t.Literal[0],
file_type: t.Literal["file"] = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> SMBFileIO: ...
@t.overload
def open_file(
path,
mode: OpenBinaryMode,
buffering: t.Literal[0],
file_type: t.Literal["dir"],
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> SMBDirectoryIO: ...
@t.overload
def open_file(
path,
mode: OpenBinaryMode,
buffering: t.Literal[0],
file_type: t.Literal["pipe"],
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> SMBPipeIO: ...


# Buffering is on: return BufferedRandom, BufferedReader, or BufferedWriter
# NOTE: This cannot handle explicit buffer sizes (>0) because of limitations
# in Python's type system
@t.overload
def open_file(
path,
mode: OpenBinaryModeUpdating,
buffering: t.Literal[-1] = -1,
file_type: FileType = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> SMBBufferedRandom: ...
@t.overload
def open_file(
path,
mode: OpenBinaryModeReading,
buffering: t.Literal[-1] = -1,
file_type: FileType = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> SMBBufferedReader: ...
@t.overload
def open_file(
path,
mode: OpenBinaryModeWriting,
buffering: t.Literal[-1] = -1,
file_type: FileType = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> SMBBufferedWriter: ...


def open_file(
path,
mode="r",
buffering=-1,
file_type: t.Literal["file", "dir", "pipe"] = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type="file",
**kwargs,
):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/smbprotocol/open.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Structure,
StructureField,
)
from smbprotocol.tree import TreeConnect

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -1008,7 +1009,7 @@ def __init__(self):


class Open:
def __init__(self, tree, name):
def __init__(self, tree: TreeConnect, name: str):
"""
[MS-SMB2] v53.0 2017-09-15
Expand Down
12 changes: 10 additions & 2 deletions src/smbprotocol/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import logging
import random
from collections import OrderedDict
from typing import Literal, Optional

import spnego
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.kbkdf import KBKDFHMAC, CounterLocation, Mode

from smbprotocol import Dialects
from smbprotocol.connection import Capabilities, Ciphers, SecurityMode
from smbprotocol.connection import Capabilities, Ciphers, Connection, SecurityMode
from smbprotocol.exceptions import (
MoreProcessingRequired,
SMBAuthenticationError,
Expand Down Expand Up @@ -170,7 +171,14 @@ def __init__(self):


class Session:
def __init__(self, connection, username=None, password=None, require_encryption=True, auth_protocol="negotiate"):
def __init__(
self,
connection: Connection,
username: Optional[str] = None,
password: Optional[str] = None,
require_encryption=True,
auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate",
):
"""
[MS-SMB2] v53.0 2017-09-15
Expand Down
3 changes: 2 additions & 1 deletion src/smbprotocol/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SMB2ValidateNegotiateInfoRequest,
SMB2ValidateNegotiateInfoResponse,
)
from smbprotocol.session import Session
from smbprotocol.structure import BytesField, EnumField, FlagField, IntField, Structure

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -204,7 +205,7 @@ def __init__(self):


class TreeConnect:
def __init__(self, session, share_name):
def __init__(self, session: Session, share_name: str):
"""
[MS-SMB2] v53.0 2017-09-15
Expand Down

0 comments on commit fc95d18

Please sign in to comment.