From 414ac66b5694336a8bd3a6cdfdd32398dc31d2c0 Mon Sep 17 00:00:00 2001
From: Vlada Dusek
Date: Tue, 24 Sep 2024 15:11:55 +0200
Subject: [PATCH] do not use arbitrary types

---
 src/crawlee/_request.py                      | 77 ++++++++++++++++++-
 src/crawlee/_types.py                        | 45 +----------
 src/crawlee/basic_crawler/_basic_crawler.py  |  4 +-
 src/crawlee/http_clients/_base.py            |  3 +-
 src/crawlee/http_clients/_httpx.py           | 12 +--
 src/crawlee/http_clients/curl_impersonate.py |  7 +-
 .../unit/basic_crawler/test_basic_crawler.py |  4 +-
 7 files changed, 92 insertions(+), 60 deletions(-)

diff --git a/src/crawlee/_request.py b/src/crawlee/_request.py
index e78b4699e..ffafffb51 100644
--- a/src/crawlee/_request.py
+++ b/src/crawlee/_request.py
@@ -2,11 +2,12 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterator, MutableMapping
+from _collections_abc import dict_items, dict_keys
+from collections.abc import Iterator, Mapping, MutableMapping
 from datetime import datetime
 from decimal import Decimal
 from enum import Enum
-from typing import Annotated, Any, cast
+from typing import Annotated, Any, cast, overload
 
 from pydantic import (
     BaseModel,
@@ -16,11 +17,12 @@
     JsonValue,
     PlainSerializer,
     PlainValidator,
+    RootModel,
     TypeAdapter,
 )
 from typing_extensions import Self
 
-from crawlee._types import EnqueueStrategy, HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
+from crawlee._types import EnqueueStrategy, HttpMethod, HttpPayload, HttpQueryParams
 from crawlee._utils.requests import compute_unique_key, unique_key_to_request_id
 from crawlee._utils.urls import extract_query_params, validate_http_url
 
@@ -96,10 +98,77 @@ def __len__(self) -> int:
 user_data_adapter = TypeAdapter(UserData)
 
 
+class HttpHeaders(RootModel):
+    """An immutable mapping for HTTP headers that ensures case-insensitivity for header names."""
+
+    def __init__(self, headers: Mapping[str, str] | None = None) -> None:
+        """Create a new instance.
+
+        Args:
+            headers: A mapping of header names to values.
+        """
+        # Ensure immutability by sorting and fixing the order.
+        headers = headers or {}
+        headers = {k.lower(): v for k, v in headers.items()}
+        self._headers = dict(sorted(headers.items()))
+
+    @property
+    def __dict__(self) -> dict[str, str]:
+        """Return the headers as a dictionary."""
+        # We have to implement this because of `BaseModel.__iter__` implementation.
+        return dict(self._headers)
+
+    @__dict__.setter
+    def __dict__(self, value: dict[str, str]) -> None:
+        """Set the headers from a dictionary."""
+        self._headers = {k.lower(): v for k, v in value.items()}
+
+    def __len__(self) -> int:
+        """Return the number of headers."""
+        return len(self._headers)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the object."""
+        return f'{self.__class__.__name__}({self._headers})'
+
+    def __getitem__(self, key: str) -> str:
+        """Get the value of a header by its name, case-insensitive."""
+        return self._headers[key.lower()]
+
+    def __setitem__(self, key: str, value: str) -> None:
+        """Prevent setting a header, as the object is immutable."""
+        raise TypeError(f'{self.__class__.__name__} is immutable')
+
+    def __delitem__(self, key: str) -> None:
+        """Prevent deleting a header, as the object is immutable."""
+        raise TypeError(f'{self.__class__.__name__} is immutable')
+
+    def keys(self) -> dict_keys[str, str]:
+        """Return a view of the header names."""
+        return self._headers.keys()
+
+    def items(self) -> dict_items[str, str]:
+        """Return a view of the header names and values."""
+        return self._headers.items()
+
+    @overload
+    def get(self, key: str) -> str | None: ...
+
+    @overload
+    def get(self, key: str, default: str) -> str: ...
+
+    @overload
+    def get(self, key: str, default: None) -> None: ...
+
+    def get(self, key: str, default: str | None = None) -> str | None:
+        """Return the value of the header if it exists, otherwise the default."""
+        return self._headers.get(key.lower(), default)
+
+
 class BaseRequestData(BaseModel):
     """Data needed to create a new crawling request."""
 
-    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
+    model_config = ConfigDict(populate_by_name=True)
 
     url: Annotated[str, BeforeValidator(validate_http_url), Field()]
     """URL of the web page to crawl"""
diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py
index f9107c1ef..73ab90fbe 100644
--- a/src/crawlee/_types.py
+++ b/src/crawlee/_types.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-from collections.abc import Coroutine, Iterator, Mapping, Sequence
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Literal, Protocol, Union
@@ -10,9 +9,10 @@
 if TYPE_CHECKING:
     import logging
     import re
+    from collections.abc import Coroutine, Sequence
 
     from crawlee import Glob
-    from crawlee._request import BaseRequestData, Request
+    from crawlee._request import BaseRequestData, HttpHeaders, Request
     from crawlee.base_storage_client._models import DatasetItemsListPage
     from crawlee.http_clients import HttpResponse
     from crawlee.proxy_configuration import ProxyInfo
@@ -26,7 +26,7 @@
 
 HttpMethod: TypeAlias = Literal['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'CONNECT', 'OPTIONS', 'TRACE', 'PATCH']
 
-HttpQueryParams: TypeAlias = dict[str, Any]
+HttpQueryParams: TypeAlias = dict[str, str]
 
 HttpPayload: TypeAlias = dict[str, Any]
 
@@ -221,42 +221,3 @@ async def add_requests(
     ) -> None:
         """Track a call to the `add_requests` context helper."""
         self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))
-
-
-class HttpHeaders(Mapping[str, str]):
-    """An immutable mapping for HTTP headers that ensures case-insensitivity for header names."""
-
-    def __init__(self, headers: Mapping[str, str] | None = None) -> None:
-        """Create a new instance.
-
-        Args:
-            headers: A mapping of header names to values.
-        """
-        # Ensure immutability by sorting and fixing the order.
-        headers = headers or {}
-        headers = {k.lower(): v for k, v in headers.items()}
-        self._headers = dict(sorted(headers.items()))
-
-    def __getitem__(self, key: str) -> str:
-        """Get the value of a header by its name, case-insensitive."""
-        return self._headers[key.lower()]
-
-    def __iter__(self) -> Iterator[str]:
-        """Return an iterator over the header names."""
-        return iter(self._headers)
-
-    def __len__(self) -> int:
-        """Return the number of headers."""
-        return len(self._headers)
-
-    def __repr__(self) -> str:
-        """Return a string representation of the object."""
-        return f'{self.__class__.__name__}({self._headers})'
-
-    def __setitem__(self, key: str, value: str) -> None:
-        """Prevent setting a header, as the object is immutable."""
-        raise TypeError(f'{self.__class__.__name__} is immutable')
-
-    def __delitem__(self, key: str) -> None:
-        """Prevent deleting a header, as the object is immutable."""
-        raise TypeError(f'{self.__class__.__name__} is immutable')
diff --git a/src/crawlee/basic_crawler/_basic_crawler.py b/src/crawlee/basic_crawler/_basic_crawler.py
index 212649883..ebcbf67bf 100644
--- a/src/crawlee/basic_crawler/_basic_crawler.py
+++ b/src/crawlee/basic_crawler/_basic_crawler.py
@@ -23,8 +23,8 @@
 from crawlee._autoscaling.snapshotter import Snapshotter
 from crawlee._autoscaling.system_status import SystemStatus
 from crawlee._log_config import configure_logger, get_configured_log_level
-from crawlee._request import BaseRequestData, Request, RequestState
-from crawlee._types import BasicCrawlingContext, HttpHeaders, RequestHandlerRunResult, SendRequestFunction
+from crawlee._request import BaseRequestData, HttpHeaders, Request, RequestState
+from crawlee._types import BasicCrawlingContext, RequestHandlerRunResult, SendRequestFunction
 from crawlee._utils.byte_size import ByteSize
 from crawlee._utils.http import is_status_code_client_error
 from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute
diff --git a/src/crawlee/http_clients/_base.py b/src/crawlee/http_clients/_base.py
index c754a3c7f..5ad592c93 100644
--- a/src/crawlee/http_clients/_base.py
+++ b/src/crawlee/http_clients/_base.py
@@ -10,7 +10,8 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
-    from crawlee._types import HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
+    from crawlee._request import HttpHeaders
+    from crawlee._types import HttpMethod, HttpPayload, HttpQueryParams
     from crawlee.base_storage_client._models import Request
     from crawlee.proxy_configuration import ProxyInfo
     from crawlee.sessions import Session
diff --git a/src/crawlee/http_clients/_httpx.py b/src/crawlee/http_clients/_httpx.py
index f0636212c..b0c7f4e2e 100644
--- a/src/crawlee/http_clients/_httpx.py
+++ b/src/crawlee/http_clients/_httpx.py
@@ -6,7 +6,7 @@
 import httpx
 from typing_extensions import override
 
-from crawlee._types import HttpHeaders, HttpPayload, HttpQueryParams
+from crawlee._request import HttpHeaders
 from crawlee._utils.blocked import ROTATE_PROXY_ERRORS
 from crawlee.errors import ProxyError
 from crawlee.fingerprint_suite import HeaderGenerator
@@ -16,7 +16,7 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
-    from crawlee._types import HttpMethod
+    from crawlee._types import HttpMethod, HttpPayload, HttpQueryParams
     from crawlee.base_storage_client._models import Request
     from crawlee.proxy_configuration import ProxyInfo
     from crawlee.statistics import Statistics
@@ -125,12 +125,12 @@ async def crawl(
         statistics: Statistics | None = None,
     ) -> HttpCrawlingResult:
         client = self._get_client(proxy_info.url if proxy_info else None)
-        headers = self._combine_headers(HttpHeaders(request.headers))
+        headers = self._combine_headers(request.headers)
 
         http_request = client.build_request(
             url=request.url,
             method=request.method,
-            headers=headers,
+            headers=dict(headers) if headers else None,
             params=request.query_params,
             data=request.payload,
             cookies=session.cookies if session else None,
@@ -177,7 +177,7 @@ async def send_request(
         http_request = client.build_request(
             url=url,
             method=method,
-            headers=headers,
+            headers=dict(headers) if headers else None,
             params=query_params,
             data=payload,
             extensions={'crawlee_session': session if self._persist_cookies_per_session else None},
@@ -230,7 +230,7 @@ def _combine_headers(self, explicit_headers: HttpHeaders | None) -> HttpHeaders
         headers = HttpHeaders(common_headers)
 
         if explicit_headers:
-            headers = HttpHeaders({**headers, **explicit_headers})
+            headers = HttpHeaders({**dict(headers), **dict(explicit_headers)})
 
         return headers if headers else None
 
diff --git a/src/crawlee/http_clients/curl_impersonate.py b/src/crawlee/http_clients/curl_impersonate.py
index a360a847c..da08503ce 100644
--- a/src/crawlee/http_clients/curl_impersonate.py
+++ b/src/crawlee/http_clients/curl_impersonate.py
@@ -25,7 +25,8 @@
 
     from curl_cffi.requests import Response
 
-    from crawlee._types import HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
+    from crawlee._request import HttpHeaders
+    from crawlee._types import HttpMethod, HttpPayload, HttpQueryParams
     from crawlee.base_storage_client._models import Request
     from crawlee.proxy_configuration import ProxyInfo
     from crawlee.sessions import Session
@@ -118,7 +119,7 @@ async def crawl(
         response = await client.request(
             url=request.url,
             method=request.method.upper(),  # type: ignore # curl-cffi requires uppercase method
-            headers=request.headers,
+            headers=dict(request.headers) if request.headers else None,
             params=request.query_params,
             data=request.payload,
             cookies=session.cookies if session else None,
@@ -163,7 +164,7 @@ async def send_request(
         response = await client.request(
             url=url,
             method=method.upper(),  # type: ignore # curl-cffi requires uppercase method
-            headers=headers,
+            headers=dict(headers) if headers else None,
             params=query_params,
             data=payload,
             cookies=session.cookies if session else None,
diff --git a/tests/unit/basic_crawler/test_basic_crawler.py b/tests/unit/basic_crawler/test_basic_crawler.py
index 5bc9c68f3..3c60bfcd9 100644
--- a/tests/unit/basic_crawler/test_basic_crawler.py
+++ b/tests/unit/basic_crawler/test_basic_crawler.py
@@ -15,8 +15,8 @@
 import pytest
 
 from crawlee import ConcurrencySettings, EnqueueStrategy, Glob
-from crawlee._request import BaseRequestData, Request
-from crawlee._types import AddRequestsKwargs, BasicCrawlingContext, HttpHeaders
+from crawlee._request import BaseRequestData, HttpHeaders, Request
+from crawlee._types import AddRequestsKwargs, BasicCrawlingContext
 from crawlee.basic_crawler import BasicCrawler
 from crawlee.configuration import Configuration
 from crawlee.errors import SessionError, UserDefinedErrorHandlerError
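
The heart of the change: `HttpHeaders` moves from a plain `Mapping` subclass onto pydantic's `RootModel`, so any model that carries headers, such as `BaseRequestData` above, can drop `arbitrary_types_allowed=True` from its `model_config` and let pydantic validate and serialize the field natively. Below is a minimal sketch of that mechanism, assuming pydantic v2; it declares the `root` field directly rather than overriding `__dict__` as the patch does, and the `Page` model is a hypothetical stand-in, not part of crawlee:

    from __future__ import annotations

    from collections.abc import Mapping

    from pydantic import BaseModel, RootModel


    class HttpHeaders(RootModel):
        """Simplified immutable, case-insensitive header mapping."""

        root: dict[str, str]

        def __init__(self, headers: Mapping[str, str] | None = None) -> None:
            # Normalize header names to lowercase and freeze the order by sorting.
            normalized = {k.lower(): v for k, v in (headers or {}).items()}
            super().__init__(root=dict(sorted(normalized.items())))

        def __getitem__(self, key: str) -> str:
            # Case-insensitive lookup.
            return self.root[key.lower()]


    class Page(BaseModel):
        # No arbitrary_types_allowed needed: HttpHeaders is itself a pydantic model.
        url: str
        headers: HttpHeaders


    page = Page(url='https://example.com', headers=HttpHeaders({'X-Token': 'abc'}))
    print(page.headers['x-token'])  # -> 'abc'
    print(page.model_dump())        # headers serialize as a plain dict

Making the type a pydantic model is what lets validation and serialization come for free, which is exactly what removing `arbitrary_types_allowed` relies on.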
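
Because the new `HttpHeaders` is no longer a `Mapping` subclass, the client code above converts it with `dict(headers)` before handing it to httpx and curl-cffi; the conversion works because the class still exposes `keys()` and `__getitem__`. A short sketch of the intended header semantics, using throwaway example values:

    headers = HttpHeaders({'User-Agent': 'crawlee', 'Accept': 'text/html'})

    assert headers['USER-AGENT'] == 'crawlee'  # lookups are case-insensitive
    assert headers.get('x-missing') is None    # get() mirrors dict.get()

    try:
        headers['accept'] = 'application/json'
    except TypeError:
        pass  # immutable: __setitem__ always raises

    print(dict(headers))  # plain dict, as passed to the HTTP libraries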