Skip to content

Commit

Permalink
Make Mocket work with big requests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mindflayer committed May 13, 2024
1 parent a14071e commit 3c97864
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 38 deletions.
30 changes: 9 additions & 21 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import socket
import ssl
from datetime import datetime, timedelta
from io import BytesIO
from json.decoder import JSONDecodeError

import urllib3
Expand All @@ -26,7 +27,6 @@
from .utils import (
SSL_PROTOCOL,
MocketMode,
MocketSocketCore,
get_mocketize,
hexdump,
hexload,
Expand Down Expand Up @@ -175,6 +175,8 @@ class MocketSocket:
_mode = None
_bufsize = None
_secure_socket = False
_did_handshake = False
_sent_non_empty_bytes = False

def __init__(
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
Expand All @@ -186,8 +188,6 @@ def __init__(
self.type = int(type)
self.proto = int(proto)
self._truesocket_recording_dir = None
self._did_handshake = False
self._sent_non_empty_bytes = False
self.kwargs = kwargs

def __str__(self):
Expand All @@ -202,7 +202,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
@property
def fd(self):
if self._fd is None:
self._fd = MocketSocketCore()
self._fd = BytesIO()
return self._fd

def gettimeout(self):
Expand Down Expand Up @@ -264,12 +264,10 @@ def unwrap(self):
def write(self, data):
return self.send(encode_to_bytes(data))

@staticmethod
def fileno():
if Mocket.r_fd is not None:
return Mocket.r_fd
Mocket.r_fd, Mocket.w_fd = os.pipe()
return Mocket.r_fd
def fileno(self):
if self.true_socket:
return self.true_socket.fileno()
return self.fd.fileno()

def connect(self, address):
self._address = self._host, self._port = address
Expand Down Expand Up @@ -317,8 +315,6 @@ def recv_into(self, buffer, buffersize=None, flags=None):
return len(data)

def recv(self, buffersize, flags=None):
if Mocket.r_fd and Mocket.w_fd:
return os.read(Mocket.r_fd, buffersize)
data = self.read(buffersize)
if data:
return data
Expand Down Expand Up @@ -436,7 +432,7 @@ def close(self):
self._fd = None

def __getattr__(self, name):
"""Do nothing catchall function, for methods like close() and shutdown()"""
"""Do nothing catchall function, for methods like shutdown()"""

def do_nothing(*args, **kwargs):
pass
Expand All @@ -450,8 +446,6 @@ class Mocket:
_requests = []
_namespace = text_type(id(_entries))
_truesocket_recording_dir = None
r_fd = None
w_fd = None

@classmethod
def register(cls, *entries):
Expand All @@ -473,12 +467,6 @@ def collect(cls, data):

@classmethod
def reset(cls):
if cls.r_fd is not None:
os.close(cls.r_fd)
cls.r_fd = None
if cls.w_fd is not None:
os.close(cls.w_fd)
cls.w_fd = None
cls._entries = collections.defaultdict(list)
cls._requests = []

Expand Down
16 changes: 0 additions & 16 deletions mocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import binascii
import io
import os
import ssl
from typing import TYPE_CHECKING, Any, Callable, ClassVar

Expand All @@ -12,24 +10,10 @@
if TYPE_CHECKING: # pragma: no cover
from typing import NoReturn

from _typeshed import ReadableBuffer

SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2


class MocketSocketCore(io.BytesIO):
def write( # type: ignore[override] # BytesIO returns int
self,
content: ReadableBuffer,
) -> None:
super().write(content)

from mocket import Mocket

if Mocket.r_fd and Mocket.w_fd:
os.write(Mocket.w_fd, content)


def hexdump(binary_string: bytes) -> str:
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
Expand Down
71 changes: 70 additions & 1 deletion tests/main/test_httpx.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import datetime
import json

import httpx
import pytest
from asgiref.sync import async_to_sync

from mocket.mocket import Mocket, mocketize
from mocket import Mocket, async_mocketize, mocketize
from mocket.mockhttp import Entry
from mocket.plugins.httpretty import httprettified, httpretty

Expand Down Expand Up @@ -55,3 +56,71 @@ async def perform_async_transactions():

perform_async_transactions()
assert len(httpretty.latest_requests) == 1


@mocketize(strict_mode=True)
def test_sync_case():
test_uri = "https://abc.de/testdata/"
base_timestamp = int(datetime.datetime.now().timestamp())
response = [
{"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(30_000)
]
Entry.single_register(
method=Entry.POST,
uri=test_uri,
body=json.dumps(
response,
),
headers={"content-type": "application/json"},
)

with httpx.Client() as client:
response = client.post(test_uri)

assert len(response.json())


@pytest.mark.asyncio
@async_mocketize(strict_mode=True)
async def test_async_case_low_number():
test_uri = "https://abc.de/testdata/"
base_timestamp = int(datetime.datetime.now().timestamp())
response = [
{"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(100)
]
Entry.single_register(
method=Entry.POST,
uri=test_uri,
body=json.dumps(
response,
),
headers={"content-type": "application/json"},
)

async with httpx.AsyncClient() as client:
response = await client.post(test_uri)

assert len(response.json())


@pytest.mark.asyncio
@async_mocketize(strict_mode=True)
async def test_async_case_high_number():
test_uri = "https://abc.de/testdata/"
base_timestamp = int(datetime.datetime.now().timestamp())
response = [
{"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(30_000)
]
Entry.single_register(
method=Entry.POST,
uri=test_uri,
body=json.dumps(
response,
),
headers={"content-type": "application/json"},
)

async with httpx.AsyncClient() as client:
response = await client.post(test_uri)

assert len(response.json())

0 comments on commit 3c97864

Please sign in to comment.