Diffstat (limited to 'searx/network')
-rw-r--r--   searx/network/__init__.py             |  58
-rw-r--r--   searx/network/client.py               |  76
-rw-r--r--   searx/network/network.py              |  92
-rw-r--r--   searx/network/raise_for_httperror.py  |  16
4 files changed, 137 insertions, 105 deletions
diff --git a/searx/network/__init__.py b/searx/network/__init__.py
index 6230b9e39..070388d2e 100644
--- a/searx/network/__init__.py
+++ b/searx/network/__init__.py
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 # pylint: disable=missing-module-docstring, global-statement

+__all__ = ["initialize", "check_network_configuration", "raise_for_httperror"]
+
+import typing as t
+
 import asyncio
 import threading
 import concurrent.futures
 from queue import SimpleQueue
 from types import MethodType
 from timeit import default_timer
-from typing import Iterable, NamedTuple, Tuple, List, Dict, Union
+from collections.abc import Iterable
 from contextlib import contextmanager

 import httpx
@@ -32,12 +36,12 @@ def get_time_for_thread():
     return THREADLOCAL.__dict__.get('total_time')


-def set_timeout_for_thread(timeout, start_time=None):
+def set_timeout_for_thread(timeout: float, start_time: float | None = None):
     THREADLOCAL.timeout = timeout
     THREADLOCAL.start_time = start_time


-def set_context_network_name(network_name):
+def set_context_network_name(network_name: str):
     THREADLOCAL.network = get_network(network_name)
@@ -64,9 +68,10 @@ def _record_http_time():
         THREADLOCAL.total_time += time_after_request - time_before_request


-def _get_timeout(start_time, kwargs):
+def _get_timeout(start_time: float, kwargs):
     # pylint: disable=too-many-branches

+    timeout: float | None
     # timeout (httpx)
     if 'timeout' in kwargs:
         timeout = kwargs['timeout']
@@ -91,14 +96,17 @@ def request(method, url, **kwargs) -> SXNG_Response:
     with _record_http_time() as start_time:
         network = get_context_network()
         timeout = _get_timeout(start_time, kwargs)
-        future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
+        future = asyncio.run_coroutine_threadsafe(
+            network.request(method, url, **kwargs),
+            get_loop(),
+        )
         try:
             return future.result(timeout)
         except concurrent.futures.TimeoutError as e:
             raise httpx.TimeoutException('Timeout', request=None) from e


-def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
+def multi_requests(request_list: list["Request"]) -> list[httpx.Response | Exception]:
     """send multiple HTTP requests in parallel. Wait for all requests to finish."""
     with _record_http_time() as start_time:
         # send the requests
@@ -124,74 +132,74 @@ def multi_requests(request_list: List[Union[httpx.Response,
     return responses


-class Request(NamedTuple):
+class Request(t.NamedTuple):
     """Request description for the multi_requests function"""

     method: str
     url: str
-    kwargs: Dict[str, str] = {}
+    kwargs: dict[str, str] = {}

     @staticmethod
-    def get(url, **kwargs):
+    def get(url: str, **kwargs: t.Any):
         return Request('GET', url, kwargs)

     @staticmethod
-    def options(url, **kwargs):
+    def options(url: str, **kwargs: t.Any):
         return Request('OPTIONS', url, kwargs)

     @staticmethod
-    def head(url, **kwargs):
+    def head(url: str, **kwargs: t.Any):
         return Request('HEAD', url, kwargs)

     @staticmethod
-    def post(url, **kwargs):
+    def post(url: str, **kwargs: t.Any):
         return Request('POST', url, kwargs)

     @staticmethod
-    def put(url, **kwargs):
+    def put(url: str, **kwargs: t.Any):
         return Request('PUT', url, kwargs)

     @staticmethod
-    def patch(url, **kwargs):
+    def patch(url: str, **kwargs: t.Any):
         return Request('PATCH', url, kwargs)

     @staticmethod
-    def delete(url, **kwargs):
+    def delete(url: str, **kwargs: t.Any):
         return Request('DELETE', url, kwargs)


-def get(url, **kwargs) -> SXNG_Response:
+def get(url: str, **kwargs: t.Any) -> SXNG_Response:
     kwargs.setdefault('allow_redirects', True)
     return request('get', url, **kwargs)


-def options(url, **kwargs) -> SXNG_Response:
+def options(url: str, **kwargs: t.Any) -> SXNG_Response:
     kwargs.setdefault('allow_redirects', True)
     return request('options', url, **kwargs)


-def head(url, **kwargs) -> SXNG_Response:
+def head(url: str, **kwargs: t.Any) -> SXNG_Response:
     kwargs.setdefault('allow_redirects', False)
     return request('head', url, **kwargs)


-def post(url, data=None, **kwargs) -> SXNG_Response:
+def post(url: str, data=None, **kwargs: t.Any) -> SXNG_Response:
     return request('post', url, data=data, **kwargs)


-def put(url, data=None, **kwargs) -> SXNG_Response:
+def put(url: str, data=None, **kwargs: t.Any) -> SXNG_Response:
     return request('put', url, data=data, **kwargs)


-def patch(url, data=None, **kwargs) -> SXNG_Response:
+def patch(url: str, data=None, **kwargs: t.Any) -> SXNG_Response:
     return request('patch', url, data=data, **kwargs)


-def delete(url, **kwargs) -> SXNG_Response:
+def delete(url: str, **kwargs: t.Any) -> SXNG_Response:
     return request('delete', url, **kwargs)


-async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
+async def stream_chunk_to_queue(network, queue, method: str, url: str, **kwargs: t.Any):
     try:
         async with await network.stream(method, url, **kwargs) as response:
             queue.put(response)
@@ -217,7 +225,7 @@ async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
         queue.put(None)


-def _stream_generator(method, url, **kwargs):
+def _stream_generator(method: str, url: str, **kwargs: t.Any):
     queue = SimpleQueue()
     network = get_context_network()
     future = asyncio.run_coroutine_threadsafe(stream_chunk_to_queue(network, queue, method, url, **kwargs), get_loop())
@@ -242,7 +250,7 @@ def _close_response_method(self):
             continue


-def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
+def stream(method: str, url: str, **kwargs: t.Any) -> tuple[httpx.Response, Iterable[bytes]]:
    """Replace httpx.stream.

    Usage:
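With the new annotations, multi_requests() accepts a list[Request] and returns list[httpx.Response | Exception], i.e. failures come back in-band instead of being raised. A minimal sketch of driving that API, assuming it runs on a SearXNG worker thread whose network context and timeout are already set up; the URLs are made up for the example:

from searx.network import multi_requests, Request

# Build the batch with the typed Request helpers above; extra keyword
# arguments are captured in Request.kwargs and forwarded to httpx.
batch = [
    Request.get("https://example.org/a"),
    Request.post("https://example.org/b", data={"q": "test"}),
]

# A failed request yields its Exception in the result list, so one
# failure does not discard the other responses.
for res in multi_requests(batch):
    if isinstance(res, Exception):
        print("request failed:", res)
    else:
        print(res.status_code)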
diff --git a/searx/network/client.py b/searx/network/client.py
index f35ba2d6e..8e69a9d46 100644
--- a/searx/network/client.py
+++ b/searx/network/client.py
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 # pylint: disable=missing-module-docstring, global-statement

+import typing as t
+from types import TracebackType
+
 import asyncio
 import logging
 import random
 from ssl import SSLContext
 import threading
-from typing import Any, Dict

 import httpx
 from httpx_socks import AsyncProxyTransport
@@ -18,10 +20,13 @@ from searx import logger

 uvloop.install()

+CertTypes = str | tuple[str, str] | tuple[str, str, str]
+SslContextKeyType = tuple[str | None, CertTypes | None, bool, bool]

 logger = logger.getChild('searx.network.client')

-LOOP = None
-SSLCONTEXTS: Dict[Any, SSLContext] = {}
+LOOP: asyncio.AbstractEventLoop = None  # pyright: ignore[reportAssignmentType]
+
+SSLCONTEXTS: dict[SslContextKeyType, SSLContext] = {}
@@ -47,8 +52,10 @@ def shuffle_ciphers(ssl_context: SSLContext):
     ssl_context.set_ciphers(":".join(sc_list + c_list))


-def get_sslcontexts(proxy_url=None, cert=None, verify=True, trust_env=True):
-    key = (proxy_url, cert, verify, trust_env)
+def get_sslcontexts(
+    proxy_url: str | None = None, cert: CertTypes | None = None, verify: bool = True, trust_env: bool = True
+) -> SSLContext:
+    key: SslContextKeyType = (proxy_url, cert, verify, trust_env)
     if key not in SSLCONTEXTS:
         SSLCONTEXTS[key] = httpx.create_ssl_context(verify, cert, trust_env)
         shuffle_ciphers(SSLCONTEXTS[key])
@@ -68,12 +75,12 @@ class AsyncHTTPTransportNoHttp(httpx.AsyncHTTPTransport):
     For reference: https://github.com/encode/httpx/issues/2298
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs):  # type: ignore
         # pylint: disable=super-init-not-called
         # this on purpose if the base class is not called
         pass

-    async def handle_async_request(self, request):
+    async def handle_async_request(self, request: httpx.Request):
         raise httpx.UnsupportedProtocol('HTTP protocol is disabled')

     async def aclose(self) -> None:
@@ -84,9 +91,9 @@ class AsyncHTTPTransportNoHttp(httpx.AsyncHTTPTransport):

     async def __aexit__(
         self,
-        exc_type=None,
-        exc_value=None,
-        traceback=None,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
     ) -> None:
         pass
@@ -97,18 +104,20 @@ class AsyncProxyTransportFixed(AsyncProxyTransport):
     Map python_socks exceptions to httpx.ProxyError exceptions
     """

-    async def handle_async_request(self, request):
+    async def handle_async_request(self, request: httpx.Request):
         try:
             return await super().handle_async_request(request)
         except ProxyConnectionError as e:
-            raise httpx.ProxyError("ProxyConnectionError: " + e.strerror, request=request) from e
+            raise httpx.ProxyError("ProxyConnectionError: " + str(e.strerror), request=request) from e
         except ProxyTimeoutError as e:
             raise httpx.ProxyError("ProxyTimeoutError: " + e.args[0], request=request) from e
         except ProxyError as e:
             raise httpx.ProxyError("ProxyError: " + e.args[0], request=request) from e


-def get_transport_for_socks_proxy(verify, http2, local_address, proxy_url, limit, retries):
+def get_transport_for_socks_proxy(
+    verify: bool, http2: bool, local_address: str, proxy_url: str, limit: httpx.Limits, retries: int
+):
     # support socks5h (requests compatibility):
     # https://requests.readthedocs.io/en/master/user/advanced/#socks
     # socks5:// hostname is resolved on client side
@@ -120,7 +129,7 @@ def get_transport_for_socks_proxy(verify, http2, local_address, proxy_url, limit
         rdns = True

     proxy_type, proxy_host, proxy_port, proxy_username, proxy_password = parse_proxy_url(proxy_url)
-    verify = get_sslcontexts(proxy_url, None, verify, True) if verify is True else verify
+    _verify = get_sslcontexts(proxy_url, None, verify, True) if verify is True else verify
     return AsyncProxyTransportFixed(
         proxy_type=proxy_type,
         proxy_host=proxy_host,
@@ -129,7 +138,7 @@ def get_transport_for_socks_proxy(verify, http2, local_address, proxy_url, limit
         password=proxy_password,
         rdns=rdns,
         loop=get_loop(),
-        verify=verify,
+        verify=_verify,
         http2=http2,
         local_address=local_address,
         limits=limit,
@@ -137,14 +146,16 @@ def get_transport_for_socks_proxy(verify, http2, local_address, proxy_url, limit
     )


-def get_transport(verify, http2, local_address, proxy_url, limit, retries):
-    verify = get_sslcontexts(None, None, verify, True) if verify is True else verify
+def get_transport(
+    verify: bool, http2: bool, local_address: str, proxy_url: str | None, limit: httpx.Limits, retries: int
+):
+    _verify = get_sslcontexts(None, None, verify, True) if verify is True else verify
     return httpx.AsyncHTTPTransport(
         # pylint: disable=protected-access
-        verify=verify,
+        verify=_verify,
         http2=http2,
         limits=limit,
-        proxy=httpx._config.Proxy(proxy_url) if proxy_url else None,
+        proxy=httpx._config.Proxy(proxy_url) if proxy_url else None,  # pyright: ignore[reportPrivateUsage]
         local_address=local_address,
         retries=retries,
     )
@@ -152,18 +163,18 @@ def get_transport(verify, http2, local_address, proxy_url, limit, retries):
 def new_client(
     # pylint: disable=too-many-arguments
-    enable_http,
-    verify,
-    enable_http2,
-    max_connections,
-    max_keepalive_connections,
-    keepalive_expiry,
-    proxies,
-    local_address,
-    retries,
-    max_redirects,
-    hook_log_response,
-):
+    enable_http: bool,
+    verify: bool,
+    enable_http2: bool,
+    max_connections: int,
+    max_keepalive_connections: int,
+    keepalive_expiry: float,
+    proxies: dict[str, str],
+    local_address: str,
+    retries: int,
+    max_redirects: int,
+    hook_log_response: t.Callable[..., t.Any] | None,
+) -> httpx.AsyncClient:
     limit = httpx.Limits(
         max_connections=max_connections,
         max_keepalive_connections=max_keepalive_connections,
@@ -171,6 +182,7 @@ def new_client(
     )
     # See https://www.python-httpx.org/advanced/#routing
     mounts = {}
+    mounts: None | (dict[str, t.Any | None]) = {}
     for pattern, proxy_url in proxies.items():
         if not enable_http and pattern.startswith('http://'):
             continue
@@ -198,7 +210,7 @@ def new_client(
     )


-def get_loop():
+def get_loop() -> asyncio.AbstractEventLoop:
     return LOOP
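get_sslcontexts() now memoizes one SSLContext per (proxy_url, cert, verify, trust_env) key, typed as SslContextKeyType, so all clients created with identical TLS parameters share a single context whose ciphers were shuffled once at creation. A standalone sketch of that caching pattern, mirroring the factory call from the hunk above:

import httpx
from ssl import SSLContext

# cache key mirrors SslContextKeyType: (proxy_url, cert, verify, trust_env)
_SSLCONTEXTS: dict[tuple, SSLContext] = {}

def cached_ssl_context(proxy_url=None, cert=None, verify=True, trust_env=True) -> SSLContext:
    key = (proxy_url, cert, verify, trust_env)
    if key not in _SSLCONTEXTS:
        # same factory call the diff uses in get_sslcontexts()
        _SSLCONTEXTS[key] = httpx.create_ssl_context(verify, cert, trust_env)
    return _SSLCONTEXTS[key]

# two callers with identical TLS parameters reuse one context
assert cached_ssl_context() is cached_ssl_context()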
diff --git a/searx/network/network.py b/searx/network/network.py
index 8e2a1f12d..f52d9f87e 100644
--- a/searx/network/network.py
+++ b/searx/network/network.py
@@ -1,14 +1,13 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 # pylint: disable=global-statement
 # pylint: disable=missing-module-docstring, missing-class-docstring
-from __future__ import annotations

+import typing as t
+from collections.abc import Generator, AsyncIterator

-import typing
 import atexit
 import asyncio
 import ipaddress
 from itertools import cycle
-from typing import Dict

 import httpx
@@ -20,7 +19,7 @@ from .raise_for_httperror import raise_for_httperror
 logger = logger.getChild('network')

 DEFAULT_NAME = '__DEFAULT__'
-NETWORKS: Dict[str, 'Network'] = {}
+NETWORKS: dict[str, "Network"] = {}

 # requests compatibility when reading proxy settings from settings.yml
 PROXY_PATTERN_MAPPING = {
     'http': 'http://',
@@ -38,6 +37,7 @@ PROXY_PATTERN_MAPPING = {
 ADDRESS_MAPPING = {'ipv4': '0.0.0.0', 'ipv6': '::'}


+@t.final
 class Network:

     __slots__ = (
@@ -64,19 +64,19 @@ class Network:
     def __init__(
         # pylint: disable=too-many-arguments
         self,
-        enable_http=True,
-        verify=True,
-        enable_http2=False,
-        max_connections=None,
-        max_keepalive_connections=None,
-        keepalive_expiry=None,
-        proxies=None,
-        using_tor_proxy=False,
-        local_addresses=None,
-        retries=0,
-        retry_on_http_error=None,
-        max_redirects=30,
-        logger_name=None,
+        enable_http: bool = True,
+        verify: bool = True,
+        enable_http2: bool = False,
+        max_connections: int = None,  # pyright: ignore[reportArgumentType]
+        max_keepalive_connections: int = None,  # pyright: ignore[reportArgumentType]
+        keepalive_expiry: float = None,  # pyright: ignore[reportArgumentType]
+        proxies: str | dict[str, str] | None = None,
+        using_tor_proxy: bool = False,
+        local_addresses: str | list[str] | None = None,
+        retries: int = 0,
+        retry_on_http_error: None = None,
+        max_redirects: int = 30,
+        logger_name: str = None,  # pyright: ignore[reportArgumentType]
     ):
         self.enable_http = enable_http
@@ -107,7 +107,7 @@ class Network:
         if self.proxies is not None and not isinstance(self.proxies, (str, dict)):
             raise ValueError('proxies type has to be str, dict or None')

-    def iter_ipaddresses(self):
+    def iter_ipaddresses(self) -> Generator[str]:
         local_addresses = self.local_addresses
         if not local_addresses:
             return
@@ -130,7 +130,7 @@ class Network:
             if count == 0:
                 yield None

-    def iter_proxies(self):
+    def iter_proxies(self) -> Generator[tuple[str, list[str]]]:
         if not self.proxies:
             return
         # https://www.python-httpx.org/compatibility/#proxy-keys
@@ -138,13 +138,13 @@ class Network:
             yield 'all://', [self.proxies]
         else:
             for pattern, proxy_url in self.proxies.items():
-                pattern = PROXY_PATTERN_MAPPING.get(pattern, pattern)
+                pattern: str = PROXY_PATTERN_MAPPING.get(pattern, pattern)
                 if isinstance(proxy_url, str):
                     proxy_url = [proxy_url]
                 yield pattern, proxy_url

-    def get_proxy_cycles(self):
-        proxy_settings = {}
+    def get_proxy_cycles(self) -> Generator[tuple[tuple[str, str], ...], str, str]:  # not sure type is correct
+        proxy_settings: dict[str, t.Any] = {}
         for pattern, proxy_urls in self.iter_proxies():
             proxy_settings[pattern] = cycle(proxy_urls)
         while True:
@@ -170,7 +170,10 @@ class Network:
             if isinstance(transport, AsyncHTTPTransportNoHttp):
                 continue
             if getattr(transport, "_pool") and getattr(
-                transport._pool, "_rdns", False  # pylint: disable=protected-access
+                # pylint: disable=protected-access
+                transport._pool,  # type: ignore
+                "_rdns",
+                False,
             ):
                 continue
             return False
@@ -180,7 +183,7 @@ class Network:
         Network._TOR_CHECK_RESULT[proxies] = result
         return result

-    async def get_client(self, verify=None, max_redirects=None) -> httpx.AsyncClient:
+    async def get_client(self, verify: bool | None = None, max_redirects: int | None = None) -> httpx.AsyncClient:
         verify = self.verify if verify is None else verify
         max_redirects = self.max_redirects if max_redirects is None else max_redirects
         local_address = next(self._local_addresses_cycle)
@@ -217,8 +220,8 @@ class Network:
         await asyncio.gather(*[close_client(client) for client in self._clients.values()], return_exceptions=False)

     @staticmethod
-    def extract_kwargs_clients(kwargs):
-        kwargs_clients = {}
+    def extract_kwargs_clients(kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
+        kwargs_clients: dict[str, t.Any] = {}
         if 'verify' in kwargs:
             kwargs_clients['verify'] = kwargs.pop('verify')
         if 'max_redirects' in kwargs:
@@ -236,9 +239,9 @@ class Network:
             del kwargs['raise_for_httperror']
         return do_raise_for_httperror

-    def patch_response(self, response, do_raise_for_httperror) -> SXNG_Response:
+    def patch_response(self, response: httpx.Response | SXNG_Response, do_raise_for_httperror: bool) -> SXNG_Response:
         if isinstance(response, httpx.Response):
-            response = typing.cast(SXNG_Response, response)
+            response = t.cast(SXNG_Response, response)
             # requests compatibility (response is not streamed)
             # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
             response.ok = not response.is_error
@@ -252,7 +255,7 @@ class Network:
                 raise
         return response

-    def is_valid_response(self, response):
+    def is_valid_response(self, response: SXNG_Response):
         # pylint: disable=too-many-boolean-expressions
         if (
             (self.retry_on_http_error is True and 400 <= response.status_code <= 599)
@@ -262,7 +265,9 @@ class Network:
                 return False
         return True

-    async def call_client(self, stream, method, url, **kwargs) -> SXNG_Response:
+    async def call_client(
+        self, stream: bool, method: str, url: str, **kwargs: t.Any
+    ) -> AsyncIterator[SXNG_Response] | None:
         retries = self.retries
         was_disconnected = False
         do_raise_for_httperror = Network.extract_do_raise_for_httperror(kwargs)
@@ -273,9 +278,9 @@ class Network:
             client.cookies = httpx.Cookies(cookies)
             try:
                 if stream:
-                    response = client.stream(method, url, **kwargs)
+                    response = client.stream(method, url, **kwargs)  # pyright: ignore[reportAny]
                 else:
-                    response = await client.request(method, url, **kwargs)
+                    response = await client.request(method, url, **kwargs)  # pyright: ignore[reportAny]
                 if self.is_valid_response(response) or retries <= 0:
                     return self.patch_response(response, do_raise_for_httperror)
             except httpx.RemoteProtocolError as e:
@@ -293,10 +298,10 @@ class Network:
                 raise e
             retries -= 1

-    async def request(self, method, url, **kwargs):
+    async def request(self, method: str, url: str, **kwargs):
         return await self.call_client(False, method, url, **kwargs)

-    async def stream(self, method, url, **kwargs):
+    async def stream(self, method: str, url: str, **kwargs):
         return await self.call_client(True, method, url, **kwargs)

     @classmethod
@@ -304,8 +309,8 @@ class Network:
         await asyncio.gather(*[network.aclose() for network in NETWORKS.values()], return_exceptions=False)


-def get_network(name=None):
-    return NETWORKS.get(name or DEFAULT_NAME)
+def get_network(name: str | None = None) -> "Network":
+    return NETWORKS.get(name or DEFAULT_NAME)  # pyright: ignore[reportReturnType]


 def check_network_configuration():
@@ -326,7 +331,10 @@ def check_network_configuration():
         raise RuntimeError("Invalid network configuration")


-def initialize(settings_engines=None, settings_outgoing=None):
+def initialize(
+    settings_engines: list[dict[str, t.Any]] = None,  # pyright: ignore[reportArgumentType]
+    settings_outgoing: dict[str, t.Any] = None,  # pyright: ignore[reportArgumentType]
+) -> None:
     # pylint: disable=import-outside-toplevel)
     from searx.engines import engines
     from searx import settings
@@ -338,7 +346,7 @@ def initialize(settings_engines=None, settings_outgoing=None):
     # default parameters for AsyncHTTPTransport
     # see https://github.com/encode/httpx/blob/e05a5372eb6172287458b37447c30f650047e1b8/httpx/_transports/default.py#L108-L121  # pylint: disable=line-too-long
-    default_params = {
+    default_params: dict[str, t.Any] = {
         'enable_http': False,
         'verify': settings_outgoing['verify'],
         'enable_http2': settings_outgoing['enable_http2'],
@@ -353,14 +361,14 @@ def initialize(settings_engines=None, settings_outgoing=None):
         'retry_on_http_error': None,
     }

-    def new_network(params, logger_name=None):
+    def new_network(params: dict[str, t.Any], logger_name: str | None = None):
         nonlocal default_params
         result = {}
-        result.update(default_params)
-        result.update(params)
+        result.update(default_params)  # pyright: ignore[reportUnknownMemberType]
+        result.update(params)  # pyright: ignore[reportUnknownMemberType]
         if logger_name:
             result['logger_name'] = logger_name
-        return Network(**result)
+        return Network(**result)  # type: ignore

     def iter_networks():
         nonlocal settings_engines
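Network.get_proxy_cycles() builds one itertools.cycle per proxy pattern and, on every client creation, yields a tuple of (pattern, next_proxy_url) pairs, which is how a network rotates through several configured proxies. A reduced sketch of that rotation, with made-up proxy URLs:

from itertools import cycle

# one cycle per pattern, as in Network.get_proxy_cycles()
proxies = {'all://': ['socks5h://proxy1:1080', 'socks5h://proxy2:1080']}
proxy_cycles = {pattern: cycle(urls) for pattern, urls in proxies.items()}

def next_proxies() -> tuple[tuple[str, str], ...]:
    # every call advances each per-pattern cycle by one proxy
    return tuple((pattern, next(c)) for pattern, c in proxy_cycles.items())

print(next_proxies())  # (('all://', 'socks5h://proxy1:1080'),)
print(next_proxies())  # (('all://', 'socks5h://proxy2:1080'),)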
diff --git a/searx/network/raise_for_httperror.py b/searx/network/raise_for_httperror.py
index abee2c78b..1a9e3d0d2 100644
--- a/searx/network/raise_for_httperror.py
+++ b/searx/network/raise_for_httperror.py
@@ -3,6 +3,7 @@
 """

+import typing as t
 from searx.exceptions import (
     SearxEngineCaptchaException,
     SearxEngineTooManyRequestsException,
@@ -10,8 +11,11 @@ from searx.exceptions import (
 )
 from searx import get_setting

+if t.TYPE_CHECKING:
+    from searx.extended_types import SXNG_Response

-def is_cloudflare_challenge(resp):
+
+def is_cloudflare_challenge(resp: "SXNG_Response"):
     if resp.status_code in [429, 503]:
         if ('__cf_chl_jschl_tk__=' in resp.text) or (
             '/cdn-cgi/challenge-platform/' in resp.text
@@ -24,11 +28,11 @@ def is_cloudflare_challenge(resp):
     return False


-def is_cloudflare_firewall(resp):
+def is_cloudflare_firewall(resp: "SXNG_Response"):
     return resp.status_code == 403 and '<span class="cf-error-code">1020</span>' in resp.text


-def raise_for_cloudflare_captcha(resp):
+def raise_for_cloudflare_captcha(resp: "SXNG_Response"):
     if resp.headers.get('Server', '').startswith('cloudflare'):
         if is_cloudflare_challenge(resp):
             # https://support.cloudflare.com/hc/en-us/articles/200170136-Understanding-Cloudflare-Challenge-Passage-Captcha-
@@ -44,19 +48,19 @@ def raise_for_cloudflare_captcha(resp):
     )


-def raise_for_recaptcha(resp):
+def raise_for_recaptcha(resp: "SXNG_Response"):
     if resp.status_code == 503 and '"https://www.google.com/recaptcha/' in resp.text:
         raise SearxEngineCaptchaException(
             message='ReCAPTCHA', suspended_time=get_setting('search.suspended_times.recaptcha_SearxEngineCaptcha')
         )


-def raise_for_captcha(resp):
+def raise_for_captcha(resp: "SXNG_Response"):
     raise_for_cloudflare_captcha(resp)
     raise_for_recaptcha(resp)


-def raise_for_httperror(resp):
+def raise_for_httperror(resp: "SXNG_Response") -> None:
     """Raise exception for an HTTP response is an error.

     Args:
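Engine code normally reaches these checks indirectly: Network.extract_do_raise_for_httperror() pops the raise_for_httperror flag from the request kwargs, and Network.patch_response() then runs raise_for_httperror() on the finished response. A sketch of that call path, with a made-up URL:

from searx.network import get
from searx.exceptions import SearxEngineCaptchaException

try:
    # the flag is consumed by the network layer before the kwargs
    # reach httpx (see extract_do_raise_for_httperror in network.py)
    resp = get('https://example.org/search', raise_for_httperror=True)
except SearxEngineCaptchaException:
    # one of the CAPTCHA checks above matched; the engine gets
    # suspended for the configured suspended_time
    pass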