diff options
Diffstat (limited to 'searx/network/network.py')
| -rw-r--r-- | searx/network/network.py | 92 |
1 files changed, 50 insertions, 42 deletions
diff --git a/searx/network/network.py b/searx/network/network.py index 8e2a1f12d..f52d9f87e 100644 --- a/searx/network/network.py +++ b/searx/network/network.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # pylint: disable=global-statement # pylint: disable=missing-module-docstring, missing-class-docstring -from __future__ import annotations +import typing as t +from collections.abc import Generator, AsyncIterator -import typing import atexit import asyncio import ipaddress from itertools import cycle -from typing import Dict import httpx @@ -20,7 +19,7 @@ from .raise_for_httperror import raise_for_httperror logger = logger.getChild('network') DEFAULT_NAME = '__DEFAULT__' -NETWORKS: Dict[str, 'Network'] = {} +NETWORKS: dict[str, "Network"] = {} # requests compatibility when reading proxy settings from settings.yml PROXY_PATTERN_MAPPING = { 'http': 'http://', @@ -38,6 +37,7 @@ PROXY_PATTERN_MAPPING = { ADDRESS_MAPPING = {'ipv4': '0.0.0.0', 'ipv6': '::'} +@t.final class Network: __slots__ = ( @@ -64,19 +64,19 @@ class Network: def __init__( # pylint: disable=too-many-arguments self, - enable_http=True, - verify=True, - enable_http2=False, - max_connections=None, - max_keepalive_connections=None, - keepalive_expiry=None, - proxies=None, - using_tor_proxy=False, - local_addresses=None, - retries=0, - retry_on_http_error=None, - max_redirects=30, - logger_name=None, + enable_http: bool = True, + verify: bool = True, + enable_http2: bool = False, + max_connections: int = None, # pyright: ignore[reportArgumentType] + max_keepalive_connections: int = None, # pyright: ignore[reportArgumentType] + keepalive_expiry: float = None, # pyright: ignore[reportArgumentType] + proxies: str | dict[str, str] | None = None, + using_tor_proxy: bool = False, + local_addresses: str | list[str] | None = None, + retries: int = 0, + retry_on_http_error: None = None, + max_redirects: int = 30, + logger_name: str = None, # pyright: ignore[reportArgumentType] ): self.enable_http = enable_http @@ -107,7 +107,7 @@ class Network: if self.proxies is not None and not isinstance(self.proxies, (str, dict)): raise ValueError('proxies type has to be str, dict or None') - def iter_ipaddresses(self): + def iter_ipaddresses(self) -> Generator[str]: local_addresses = self.local_addresses if not local_addresses: return @@ -130,7 +130,7 @@ class Network: if count == 0: yield None - def iter_proxies(self): + def iter_proxies(self) -> Generator[tuple[str, list[str]]]: if not self.proxies: return # https://www.python-httpx.org/compatibility/#proxy-keys @@ -138,13 +138,13 @@ class Network: yield 'all://', [self.proxies] else: for pattern, proxy_url in self.proxies.items(): - pattern = PROXY_PATTERN_MAPPING.get(pattern, pattern) + pattern: str = PROXY_PATTERN_MAPPING.get(pattern, pattern) if isinstance(proxy_url, str): proxy_url = [proxy_url] yield pattern, proxy_url - def get_proxy_cycles(self): - proxy_settings = {} + def get_proxy_cycles(self) -> Generator[tuple[tuple[str, str], ...], str, str]: # not sure type is correct + proxy_settings: dict[str, t.Any] = {} for pattern, proxy_urls in self.iter_proxies(): proxy_settings[pattern] = cycle(proxy_urls) while True: @@ -170,7 +170,10 @@ class Network: if isinstance(transport, AsyncHTTPTransportNoHttp): continue if getattr(transport, "_pool") and getattr( - transport._pool, "_rdns", False # pylint: disable=protected-access + # pylint: disable=protected-access + transport._pool, # type: ignore + "_rdns", + False, ): continue return False @@ -180,7 +183,7 @@ class Network: Network._TOR_CHECK_RESULT[proxies] = result return result - async def get_client(self, verify=None, max_redirects=None) -> httpx.AsyncClient: + async def get_client(self, verify: bool | None = None, max_redirects: int | None = None) -> httpx.AsyncClient: verify = self.verify if verify is None else verify max_redirects = self.max_redirects if max_redirects is None else max_redirects local_address = next(self._local_addresses_cycle) @@ -217,8 +220,8 @@ class Network: await asyncio.gather(*[close_client(client) for client in self._clients.values()], return_exceptions=False) @staticmethod - def extract_kwargs_clients(kwargs): - kwargs_clients = {} + def extract_kwargs_clients(kwargs: dict[str, t.Any]) -> dict[str, t.Any]: + kwargs_clients: dict[str, t.Any] = {} if 'verify' in kwargs: kwargs_clients['verify'] = kwargs.pop('verify') if 'max_redirects' in kwargs: @@ -236,9 +239,9 @@ class Network: del kwargs['raise_for_httperror'] return do_raise_for_httperror - def patch_response(self, response, do_raise_for_httperror) -> SXNG_Response: + def patch_response(self, response: httpx.Response | SXNG_Response, do_raise_for_httperror: bool) -> SXNG_Response: if isinstance(response, httpx.Response): - response = typing.cast(SXNG_Response, response) + response = t.cast(SXNG_Response, response) # requests compatibility (response is not streamed) # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses response.ok = not response.is_error @@ -252,7 +255,7 @@ class Network: raise return response - def is_valid_response(self, response): + def is_valid_response(self, response: SXNG_Response): # pylint: disable=too-many-boolean-expressions if ( (self.retry_on_http_error is True and 400 <= response.status_code <= 599) @@ -262,7 +265,9 @@ class Network: return False return True - async def call_client(self, stream, method, url, **kwargs) -> SXNG_Response: + async def call_client( + self, stream: bool, method: str, url: str, **kwargs: t.Any + ) -> AsyncIterator[SXNG_Response] | None: retries = self.retries was_disconnected = False do_raise_for_httperror = Network.extract_do_raise_for_httperror(kwargs) @@ -273,9 +278,9 @@ class Network: client.cookies = httpx.Cookies(cookies) try: if stream: - response = client.stream(method, url, **kwargs) + response = client.stream(method, url, **kwargs) # pyright: ignore[reportAny] else: - response = await client.request(method, url, **kwargs) + response = await client.request(method, url, **kwargs) # pyright: ignore[reportAny] if self.is_valid_response(response) or retries <= 0: return self.patch_response(response, do_raise_for_httperror) except httpx.RemoteProtocolError as e: @@ -293,10 +298,10 @@ class Network: raise e retries -= 1 - async def request(self, method, url, **kwargs): + async def request(self, method: str, url: str, **kwargs): return await self.call_client(False, method, url, **kwargs) - async def stream(self, method, url, **kwargs): + async def stream(self, method: str, url: str, **kwargs): return await self.call_client(True, method, url, **kwargs) @classmethod @@ -304,8 +309,8 @@ class Network: await asyncio.gather(*[network.aclose() for network in NETWORKS.values()], return_exceptions=False) -def get_network(name=None): - return NETWORKS.get(name or DEFAULT_NAME) +def get_network(name: str | None = None) -> "Network": + return NETWORKS.get(name or DEFAULT_NAME) # pyright: ignore[reportReturnType] def check_network_configuration(): @@ -326,7 +331,10 @@ def check_network_configuration(): raise RuntimeError("Invalid network configuration") -def initialize(settings_engines=None, settings_outgoing=None): +def initialize( + settings_engines: list[dict[str, t.Any]] = None, # pyright: ignore[reportArgumentType] + settings_outgoing: dict[str, t.Any] = None, # pyright: ignore[reportArgumentType] +) -> None: # pylint: disable=import-outside-toplevel) from searx.engines import engines from searx import settings @@ -338,7 +346,7 @@ def initialize(settings_engines=None, settings_outgoing=None): # default parameters for AsyncHTTPTransport # see https://github.com/encode/httpx/blob/e05a5372eb6172287458b37447c30f650047e1b8/httpx/_transports/default.py#L108-L121 # pylint: disable=line-too-long - default_params = { + default_params: dict[str, t.Any] = { 'enable_http': False, 'verify': settings_outgoing['verify'], 'enable_http2': settings_outgoing['enable_http2'], @@ -353,14 +361,14 @@ def initialize(settings_engines=None, settings_outgoing=None): 'retry_on_http_error': None, } - def new_network(params, logger_name=None): + def new_network(params: dict[str, t.Any], logger_name: str | None = None): nonlocal default_params result = {} - result.update(default_params) - result.update(params) + result.update(default_params) # pyright: ignore[reportUnknownMemberType] + result.update(params) # pyright: ignore[reportUnknownMemberType] if logger_name: result['logger_name'] = logger_name - return Network(**result) + return Network(**result) # type: ignore def iter_networks(): nonlocal settings_engines |