summaryrefslogtreecommitdiff
path: root/searx/network/network.py
diff options
context:
space:
mode:
Diffstat (limited to 'searx/network/network.py')
-rw-r--r--searx/network/network.py92
1 files changed, 50 insertions, 42 deletions
diff --git a/searx/network/network.py b/searx/network/network.py
index 8e2a1f12d..f52d9f87e 100644
--- a/searx/network/network.py
+++ b/searx/network/network.py
@@ -1,14 +1,13 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# pylint: disable=global-statement
# pylint: disable=missing-module-docstring, missing-class-docstring
-from __future__ import annotations
+import typing as t
+from collections.abc import Generator, AsyncIterator
-import typing
import atexit
import asyncio
import ipaddress
from itertools import cycle
-from typing import Dict
import httpx
@@ -20,7 +19,7 @@ from .raise_for_httperror import raise_for_httperror
logger = logger.getChild('network')
DEFAULT_NAME = '__DEFAULT__'
-NETWORKS: Dict[str, 'Network'] = {}
+NETWORKS: dict[str, "Network"] = {}
# requests compatibility when reading proxy settings from settings.yml
PROXY_PATTERN_MAPPING = {
'http': 'http://',
@@ -38,6 +37,7 @@ PROXY_PATTERN_MAPPING = {
ADDRESS_MAPPING = {'ipv4': '0.0.0.0', 'ipv6': '::'}
+@t.final
class Network:
__slots__ = (
@@ -64,19 +64,19 @@ class Network:
def __init__(
# pylint: disable=too-many-arguments
self,
- enable_http=True,
- verify=True,
- enable_http2=False,
- max_connections=None,
- max_keepalive_connections=None,
- keepalive_expiry=None,
- proxies=None,
- using_tor_proxy=False,
- local_addresses=None,
- retries=0,
- retry_on_http_error=None,
- max_redirects=30,
- logger_name=None,
+ enable_http: bool = True,
+ verify: bool = True,
+ enable_http2: bool = False,
+ max_connections: int = None, # pyright: ignore[reportArgumentType]
+ max_keepalive_connections: int = None, # pyright: ignore[reportArgumentType]
+ keepalive_expiry: float = None, # pyright: ignore[reportArgumentType]
+ proxies: str | dict[str, str] | None = None,
+ using_tor_proxy: bool = False,
+ local_addresses: str | list[str] | None = None,
+ retries: int = 0,
+ retry_on_http_error: None = None,
+ max_redirects: int = 30,
+ logger_name: str = None, # pyright: ignore[reportArgumentType]
):
self.enable_http = enable_http
@@ -107,7 +107,7 @@ class Network:
if self.proxies is not None and not isinstance(self.proxies, (str, dict)):
raise ValueError('proxies type has to be str, dict or None')
- def iter_ipaddresses(self):
+ def iter_ipaddresses(self) -> Generator[str]:
local_addresses = self.local_addresses
if not local_addresses:
return
@@ -130,7 +130,7 @@ class Network:
if count == 0:
yield None
- def iter_proxies(self):
+ def iter_proxies(self) -> Generator[tuple[str, list[str]]]:
if not self.proxies:
return
# https://www.python-httpx.org/compatibility/#proxy-keys
@@ -138,13 +138,13 @@ class Network:
yield 'all://', [self.proxies]
else:
for pattern, proxy_url in self.proxies.items():
- pattern = PROXY_PATTERN_MAPPING.get(pattern, pattern)
+ pattern: str = PROXY_PATTERN_MAPPING.get(pattern, pattern)
if isinstance(proxy_url, str):
proxy_url = [proxy_url]
yield pattern, proxy_url
- def get_proxy_cycles(self):
- proxy_settings = {}
+ def get_proxy_cycles(self) -> Generator[tuple[tuple[str, str], ...], str, str]: # not sure type is correct
+ proxy_settings: dict[str, t.Any] = {}
for pattern, proxy_urls in self.iter_proxies():
proxy_settings[pattern] = cycle(proxy_urls)
while True:
@@ -170,7 +170,10 @@ class Network:
if isinstance(transport, AsyncHTTPTransportNoHttp):
continue
if getattr(transport, "_pool") and getattr(
- transport._pool, "_rdns", False # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ transport._pool, # type: ignore
+ "_rdns",
+ False,
):
continue
return False
@@ -180,7 +183,7 @@ class Network:
Network._TOR_CHECK_RESULT[proxies] = result
return result
- async def get_client(self, verify=None, max_redirects=None) -> httpx.AsyncClient:
+ async def get_client(self, verify: bool | None = None, max_redirects: int | None = None) -> httpx.AsyncClient:
verify = self.verify if verify is None else verify
max_redirects = self.max_redirects if max_redirects is None else max_redirects
local_address = next(self._local_addresses_cycle)
@@ -217,8 +220,8 @@ class Network:
await asyncio.gather(*[close_client(client) for client in self._clients.values()], return_exceptions=False)
@staticmethod
- def extract_kwargs_clients(kwargs):
- kwargs_clients = {}
+ def extract_kwargs_clients(kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
+ kwargs_clients: dict[str, t.Any] = {}
if 'verify' in kwargs:
kwargs_clients['verify'] = kwargs.pop('verify')
if 'max_redirects' in kwargs:
@@ -236,9 +239,9 @@ class Network:
del kwargs['raise_for_httperror']
return do_raise_for_httperror
- def patch_response(self, response, do_raise_for_httperror) -> SXNG_Response:
+ def patch_response(self, response: httpx.Response | SXNG_Response, do_raise_for_httperror: bool) -> SXNG_Response:
if isinstance(response, httpx.Response):
- response = typing.cast(SXNG_Response, response)
+ response = t.cast(SXNG_Response, response)
# requests compatibility (response is not streamed)
# see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
response.ok = not response.is_error
@@ -252,7 +255,7 @@ class Network:
raise
return response
- def is_valid_response(self, response):
+ def is_valid_response(self, response: SXNG_Response):
# pylint: disable=too-many-boolean-expressions
if (
(self.retry_on_http_error is True and 400 <= response.status_code <= 599)
@@ -262,7 +265,9 @@ class Network:
return False
return True
- async def call_client(self, stream, method, url, **kwargs) -> SXNG_Response:
+ async def call_client(
+ self, stream: bool, method: str, url: str, **kwargs: t.Any
+ ) -> AsyncIterator[SXNG_Response] | None:
retries = self.retries
was_disconnected = False
do_raise_for_httperror = Network.extract_do_raise_for_httperror(kwargs)
@@ -273,9 +278,9 @@ class Network:
client.cookies = httpx.Cookies(cookies)
try:
if stream:
- response = client.stream(method, url, **kwargs)
+ response = client.stream(method, url, **kwargs) # pyright: ignore[reportAny]
else:
- response = await client.request(method, url, **kwargs)
+ response = await client.request(method, url, **kwargs) # pyright: ignore[reportAny]
if self.is_valid_response(response) or retries <= 0:
return self.patch_response(response, do_raise_for_httperror)
except httpx.RemoteProtocolError as e:
@@ -293,10 +298,10 @@ class Network:
raise e
retries -= 1
- async def request(self, method, url, **kwargs):
+ async def request(self, method: str, url: str, **kwargs):
return await self.call_client(False, method, url, **kwargs)
- async def stream(self, method, url, **kwargs):
+ async def stream(self, method: str, url: str, **kwargs):
return await self.call_client(True, method, url, **kwargs)
@classmethod
@@ -304,8 +309,8 @@ class Network:
await asyncio.gather(*[network.aclose() for network in NETWORKS.values()], return_exceptions=False)
-def get_network(name=None):
- return NETWORKS.get(name or DEFAULT_NAME)
+def get_network(name: str | None = None) -> "Network":
+ return NETWORKS.get(name or DEFAULT_NAME) # pyright: ignore[reportReturnType]
def check_network_configuration():
@@ -326,7 +331,10 @@ def check_network_configuration():
raise RuntimeError("Invalid network configuration")
-def initialize(settings_engines=None, settings_outgoing=None):
+def initialize(
+ settings_engines: list[dict[str, t.Any]] = None, # pyright: ignore[reportArgumentType]
+ settings_outgoing: dict[str, t.Any] = None, # pyright: ignore[reportArgumentType]
+) -> None:
# pylint: disable=import-outside-toplevel)
from searx.engines import engines
from searx import settings
@@ -338,7 +346,7 @@ def initialize(settings_engines=None, settings_outgoing=None):
# default parameters for AsyncHTTPTransport
# see https://github.com/encode/httpx/blob/e05a5372eb6172287458b37447c30f650047e1b8/httpx/_transports/default.py#L108-L121 # pylint: disable=line-too-long
- default_params = {
+ default_params: dict[str, t.Any] = {
'enable_http': False,
'verify': settings_outgoing['verify'],
'enable_http2': settings_outgoing['enable_http2'],
@@ -353,14 +361,14 @@ def initialize(settings_engines=None, settings_outgoing=None):
'retry_on_http_error': None,
}
- def new_network(params, logger_name=None):
+ def new_network(params: dict[str, t.Any], logger_name: str | None = None):
nonlocal default_params
result = {}
- result.update(default_params)
- result.update(params)
+ result.update(default_params) # pyright: ignore[reportUnknownMemberType]
+ result.update(params) # pyright: ignore[reportUnknownMemberType]
if logger_name:
result['logger_name'] = logger_name
- return Network(**result)
+ return Network(**result) # type: ignore
def iter_networks():
nonlocal settings_engines