Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

External contribution #219

Merged
merged 9 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@
except ImportError:
pyopenssl_override = False

try: # pragma: no cover
from aiohttp import TCPConnector

aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
except (ImportError, AttributeError):
aiohttp_make_ssl_context_cache_clear = None


true_socket = socket.socket
true_create_connection = socket.create_connection
Expand Down Expand Up @@ -85,6 +92,7 @@ class FakeSSLContext(SuperFakeSSLContext):
"load_verify_locations",
"set_alpn_protocols",
"set_ciphers",
"set_default_verify_paths",
)
sock = None
post_handshake_auth = None
Expand Down Expand Up @@ -180,6 +188,8 @@ def __init__(
self.type = int(type)
self.proto = int(proto)
self._truesocket_recording_dir = None
self._did_handshake = False
self._sent_non_empty_bytes = False
self.kwargs = kwargs

def __str__(self):
Expand Down Expand Up @@ -218,7 +228,7 @@ def getsockopt(level, optname, buflen=None):
return socket.SOCK_STREAM

def do_handshake(self):
pass
self._did_handshake = True

def getpeername(self):
return self._address
Expand Down Expand Up @@ -257,6 +267,8 @@ def write(self, data):

@staticmethod
def fileno():
if Mocket.r_fd is not None:
return Mocket.r_fd
Mocket.r_fd, Mocket.w_fd = os.pipe()
return Mocket.r_fd

Expand Down Expand Up @@ -292,10 +304,21 @@ def sendall(self, data, entry=None, *args, **kwargs):
self.fd.seek(0)

def read(self, buffersize):
return self.fd.read(buffersize)
rv = self.fd.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
raise ssl.SSLWantReadError("The operation did not complete (read)")
return rv

def recv_into(self, buffer, buffersize=None, flags=None):
return buffer.write(self.read(buffersize))
if hasattr(buffer, "write"):
return buffer.write(self.read(buffersize))
# buffer is a memoryview
data = self.read(buffersize)
if data:
buffer[: len(data)] = data
return len(data)

def recv(self, buffersize, flags=None):
if Mocket.r_fd and Mocket.w_fd:
Expand Down Expand Up @@ -455,8 +478,12 @@ def collect(cls, data):

@classmethod
def reset(cls):
cls.r_fd = None
cls.w_fd = None
if cls.r_fd is not None:
os.close(cls.r_fd)
cls.r_fd = None
if cls.w_fd is not None:
os.close(cls.w_fd)
cls.w_fd = None
cls._entries = collections.defaultdict(list)
cls._requests = []

Expand Down Expand Up @@ -527,6 +554,8 @@ def enable(namespace=None, truesocket_recording_dir=None):
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()
if aiohttp_make_ssl_context_cache_clear: # pragma: no cover
aiohttp_make_ssl_context_cache_clear()

@staticmethod
def disable():
Expand Down Expand Up @@ -563,6 +592,8 @@ def disable():
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
if aiohttp_make_ssl_context_cache_clear: # pragma: no cover
aiohttp_make_ssl_context_cache_clear()

@classmethod
def get_namespace(cls):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dynamic = ["version"]
[project.optional-dependencies]
test = [
"pre-commit",
"psutil",
"pytest",
"pytest-cov",
"pytest-asyncio",
Expand Down
18 changes: 18 additions & 0 deletions tests/main/test_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from unittest import TestCase
from unittest.mock import patch

import httpx
import psutil
import pytest

from mocket import Mocket, MocketEntry, Mocketizer, mocketize
Expand Down Expand Up @@ -190,3 +192,19 @@ def test_patch(
):
method_patch.return_value = "foo"
assert os.getcwd() == "foo"


@pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test")
@pytest.mark.asyncio
async def test_no_dangling_fds():
url = "http://httpbin.local/ip"

proc = psutil.Process(os.getpid())

prev_num_fds = proc.num_fds()

async with Mocketizer(strict_mode=False):
async with httpx.AsyncClient() as client:
await client.get(url)

assert proc.num_fds() == prev_num_fds
59 changes: 38 additions & 21 deletions tests/tests38/test_http_aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from unittest import IsolatedAsyncioTestCase

import httpx
import pytest

from mocket.async_mocket import async_mocketize
from mocket.mocket import Mocket
from mocket.mocket import Mocket, Mocketizer
from mocket.mockhttp import Entry
from mocket.plugins.httpretty import HTTPretty, async_httprettified

Expand Down Expand Up @@ -46,6 +45,23 @@ async def test_http_session(self):

self.assertEqual(len(Mocket.request_list()), 2)

@async_httprettified
async def test_httprettish_session(self):
HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(dict(origin="127.0.0.1")),
)

async with aiohttp.ClientSession(timeout=self.timeout) as session:
async with session.get(self.target_url) as get_response:
assert get_response.status == 200
assert await get_response.text() == '{"origin": "127.0.0.1"}'

class AioHttpsEntryTestCase(IsolatedAsyncioTestCase):
timeout = aiohttp.ClientTimeout(total=3)
target_url = "https://httpbin.localhost/anything/"

@async_mocketize
async def test_https_session(self):
body = "asd" * 100
Expand All @@ -67,7 +83,14 @@ async def test_https_session(self):

self.assertEqual(len(Mocket.request_list()), 2)

@pytest.mark.xfail
@async_mocketize
async def test_no_verify(self):
Entry.single_register(Entry.GET, self.target_url, status=404)

async with aiohttp.ClientSession(timeout=self.timeout) as session:
async with session.get(self.target_url, ssl=False) as get_response:
assert get_response.status == 404

@async_httprettified
async def test_httprettish_session(self):
HTTPretty.register_uri(
Expand All @@ -81,21 +104,15 @@ async def test_httprettish_session(self):
assert get_response.status == 200
assert await get_response.text() == '{"origin": "127.0.0.1"}'


class HttpxEntryTestCase(IsolatedAsyncioTestCase):
target_url = "http://httpbin.local/ip"

@async_httprettified
async def test_httprettish_httpx_session(self):
expected_response = {"origin": "127.0.0.1"}

HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(expected_response),
)

async with httpx.AsyncClient() as client:
response = await client.get(self.target_url)
assert response.status_code == 200
assert response.json() == expected_response
@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)')
async def test_mocked_https_request_after_unmocked_https_request(self):
async with aiohttp.ClientSession(timeout=self.timeout) as session:
response = await session.get(self.target_url + "real", ssl=False)
assert response.status == 200

async with Mocketizer(None):
Entry.single_register(Entry.GET, self.target_url + "mocked", status=404)
async with aiohttp.ClientSession(timeout=self.timeout) as session:
response = await session.get(self.target_url + "mocked", ssl=False)
assert response.status == 404
self.assertEqual(len(Mocket.request_list()), 1)
44 changes: 44 additions & 0 deletions tests/tests38/test_http_httpx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json
from unittest import IsolatedAsyncioTestCase

import httpx

from mocket.plugins.httpretty import HTTPretty, async_httprettified


class HttpxEntryTestCase(IsolatedAsyncioTestCase):
target_url = "http://httpbin.local/ip"

@async_httprettified
async def test_httprettish_httpx_session(self):
expected_response = {"origin": "127.0.0.1"}

HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(expected_response),
)

async with httpx.AsyncClient() as client:
response = await client.get(self.target_url)
assert response.status_code == 200
assert response.json() == expected_response


class HttpxHttpsEntryTestCase(IsolatedAsyncioTestCase):
target_url = "https://httpbin.local/ip"

@async_httprettified
async def test_httprettish_httpx_session(self):
expected_response = {"origin": "127.0.0.1"}

HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(expected_response),
)

async with httpx.AsyncClient() as client:
response = await client.get(self.target_url)
assert response.status_code == 200
assert response.json() == expected_response