Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regression #239

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import socket
import ssl
from datetime import datetime, timedelta
from io import BytesIO
from json.decoder import JSONDecodeError
from typing import Optional, Tuple

import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
Expand All @@ -27,6 +27,7 @@
from .utils import (
SSL_PROTOCOL,
MocketMode,
MocketSocketCore,
get_mocketize,
hexdump,
hexload,
Expand Down Expand Up @@ -73,15 +74,15 @@


class SuperFakeSSLContext:
"""For Python 3.6"""
"""For Python 3.6 and newer."""

class FakeSetter(int):
def __set__(self, *args):
pass

minimum_version = FakeSetter()
options = FakeSetter()
verify_mode = FakeSetter(ssl.CERT_NONE)
verify_mode = FakeSetter()


class FakeSSLContext(SuperFakeSSLContext):
Expand Down Expand Up @@ -177,6 +178,7 @@ class MocketSocket:
_secure_socket = False
_did_handshake = False
_sent_non_empty_bytes = False
_io = None

def __init__(
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
Expand All @@ -200,10 +202,18 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

@property
def fd(self):
if self._fd is None:
self._fd = BytesIO()
return self._fd
def io(self):
if self._io is None:
self._io = MocketSocketCore((self._host, self._port))
return self._io

def fileno(self):
address = (self._host, self._port)
r_fd, _ = Mocket.get_pair(address)
if not r_fd:
r_fd, w_fd = os.pipe()
Mocket.set_pair(address, (r_fd, w_fd))
return r_fd

def gettimeout(self):
return self.timeout
Expand Down Expand Up @@ -264,19 +274,14 @@ def unwrap(self):
def write(self, data):
return self.send(encode_to_bytes(data))

def fileno(self):
if self.true_socket:
return self.true_socket.fileno()
return self.fd.fileno()

def connect(self, address):
self._address = self._host, self._port = address
Mocket._address = address

def makefile(self, mode="r", bufsize=-1):
self._mode = mode
self._bufsize = bufsize
return self.fd
return self.io

def get_entry(self, data):
return Mocket.get_entry(self._host, self._port, data)
Expand All @@ -292,13 +297,13 @@ def sendall(self, data, entry=None, *args, **kwargs):
response = self.true_sendall(data, *args, **kwargs)

if response is not None:
self.fd.seek(0)
self.fd.write(response)
self.fd.truncate()
self.fd.seek(0)
self.io.seek(0)
self.io.write(response)
self.io.truncate()
self.io.seek(0)

def read(self, buffersize):
rv = self.fd.read(buffersize)
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
Expand All @@ -315,6 +320,9 @@ def recv_into(self, buffer, buffersize=None, flags=None):
return len(data)

def recv(self, buffersize, flags=None):
r_fd, _ = Mocket.get_pair((self._host, self._port))
if r_fd:
return os.read(r_fd, buffersize)
data = self.read(buffersize)
if data:
return data
Expand Down Expand Up @@ -416,8 +424,8 @@ def true_sendall(self, data, *args, **kwargs):

def send(self, data, *args, **kwargs): # pragma: no cover
entry = self.get_entry(data)
kwargs["entry"] = entry
if not entry or (entry and self._entry != entry):
kwargs["entry"] = entry
self.sendall(data, *args, **kwargs)
else:
req = Mocket.last_request()
Expand All @@ -441,12 +449,29 @@ def do_nothing(*args, **kwargs):


class Mocket:
_socket_pairs = {}
_address = (None, None)
_entries = collections.defaultdict(list)
_requests = []
_namespace = text_type(id(_entries))
_truesocket_recording_dir = None

@classmethod
def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
"""
Given the id() of the caller, return a pair of file descriptors
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
return cls._socket_pairs.get(address, (None, None))

@classmethod
def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
"""
Store a pair of file descriptors under the key `id_`
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
cls._socket_pairs[address] = pair

@classmethod
def register(cls, *entries):
for entry in entries:
Expand All @@ -467,6 +492,10 @@ def collect(cls, data):

@classmethod
def reset(cls):
for r_fd, w_fd in cls._socket_pairs.values():
os.close(r_fd)
os.close(w_fd)
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []

Expand Down
2 changes: 1 addition & 1 deletion mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def can_handle(self, data):
"""
try:
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
method, path, version = self._parse_requestline(requestline)
method, path, _ = self._parse_requestline(requestline)
except ValueError:
return self is getattr(Mocket, "_last_entry", None)

Expand Down
17 changes: 17 additions & 0 deletions mocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import binascii
import io
import os
import ssl
from typing import TYPE_CHECKING, Any, Callable, ClassVar

Expand All @@ -14,6 +16,21 @@
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2


class MocketSocketCore(io.BytesIO):
def __init__(self, address) -> None:
self._address = address
super().__init__()

def write(self, content):
from mocket import Mocket

super().write(content)

_, w_fd = Mocket.get_pair(self._address)
if w_fd:
os.write(w_fd, content)


def hexdump(binary_string: bytes) -> str:
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
Expand Down
5 changes: 0 additions & 5 deletions tests/main/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import glob
import json
import socket
import sys
import tempfile

import aiohttp
Expand Down Expand Up @@ -45,10 +44,6 @@ async def test_asyncio_connection():


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Looks like https://github.com/aio-libs/aiohttp/issues/5582",
)
@async_mocketize
async def test_aiohttp():
url = "https://bar.foo/"
Expand Down