second commit

This commit is contained in:
2024-12-27 22:31:23 +09:00
parent 2353324570
commit 10a0f110ca
8819 changed files with 1307198 additions and 28 deletions

View File

@ -0,0 +1,11 @@
from __future__ import annotations
import warnings
# Module-level statement: issue a DeprecationWarning that points users of
# websockets.legacy at the upgrade guide.
warnings.warn(  # deprecated in 14.0 - 2024-11-09
    "websockets.legacy is deprecated; "
    "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html "
    "for upgrade instructions",
    DeprecationWarning,
)

View File

@ -0,0 +1,190 @@
from __future__ import annotations
import functools
import hmac
import http
from collections.abc import Awaitable, Iterable
from typing import Any, Callable, cast
from ..datastructures import Headers
from ..exceptions import InvalidHeader
from ..headers import build_www_authenticate_basic, parse_authorization_basic
from .server import HTTPResponse, WebSocketServerProtocol
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
# Type alias for a ``(username, password)`` pair of strings.
Credentials = tuple[str, str]
def is_credentials(value: Any) -> bool:
    """
    Tell whether ``value`` unpacks into exactly two strings.

    Args:
        value: Any object.

    Returns:
        :obj:`True` if ``value`` can be unpacked into two items that are both
        :class:`str` instances; :obj:`False` otherwise, including when
        ``value`` isn't iterable or doesn't contain exactly two items.

    """
    try:
        first, second = value
    except (TypeError, ValueError):
        # TypeError: not iterable; ValueError: wrong number of items.
        return False
    return all(isinstance(part, str) for part in (first, second))
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    Server-side WebSocket protocol that requires HTTP Basic Auth.

    The opening handshake is rejected with an HTTP 401 response unless the
    request carries an ``Authorization`` header whose credentials are accepted
    by :meth:`check_credentials`.

    """

    realm: str = ""
    """
    Scope of protection.

    If provided, it should contain only ASCII characters because the
    encoding of non-ASCII characters is undefined.
    """

    username: str | None = None
    """Username of the authenticated user."""

    def __init__(
        self,
        *args: Any,
        realm: str | None = None,
        check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
        **kwargs: Any,
    ) -> None:
        # Only override the class-level default when a realm is provided.
        if realm is not None:
            self.realm = realm
        self._check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    async def check_credentials(self, username: str, password: str) -> bool:
        """
        Decide whether the given credentials are authorized.

        Subclasses may override this coroutine, for instance to look up
        credentials in a database or to query an external service.

        Args:
            username: HTTP Basic Auth username.
            password: HTTP Basic Auth password.

        Returns:
            :obj:`True` if the handshake should continue;
            :obj:`False` if it should fail with an HTTP 401 error.

        """
        checker = self._check_credentials
        if checker is None:
            # Without a checker, nobody is authorized.
            return False
        return await checker(username, password)

    async def process_request(
        self,
        path: str,
        request_headers: Headers,
    ) -> HTTPResponse | None:
        """
        Enforce HTTP Basic Auth, answering with HTTP 401 when it fails.

        On success, the authenticated user name is stored in ``username`` and
        processing is delegated to the parent class.

        """

        def unauthorized(body: bytes) -> HTTPResponse:
            # Every rejection shares the same status and challenge header.
            challenge = build_www_authenticate_basic(self.realm)
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", challenge)],
                body,
            )

        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return unauthorized(b"Missing credentials\n")

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return unauthorized(b"Unsupported credentials\n")

        authorized = await self.check_credentials(username, password)
        if not authorized:
            return unauthorized(b"Invalid credentials\n")

        self.username = username

        return await super().process_request(path, request_headers)
def basic_auth_protocol_factory(
    realm: str | None = None,
    credentials: Credentials | Iterable[Credentials] | None = None,
    check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
    create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    :func:`basic_auth_protocol_factory` is designed to integrate with
    :func:`~websockets.legacy.server.serve` like this::

        serve(
            ...,
            create_protocol=basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    Args:
        realm: Scope of protection. It should contain only ASCII characters
            because the encoding of non-ASCII characters is undefined.
            Refer to section 2.2 of :rfc:`7235` for details.
        credentials: Hard coded authorized credentials. It can be a
            ``(username, password)`` pair or a list of such pairs.
            Passwords may contain non-ASCII characters.
        check_credentials: Coroutine that verifies credentials.
            It receives ``username`` and ``password`` arguments
            and returns a :class:`bool`. One of ``credentials`` or
            ``check_credentials`` must be provided but not both.
        create_protocol: Factory that creates the protocol. By default, this
            is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
            by a subclass.

    Returns:
        A callable that forwards its arguments to ``create_protocol`` along
        with the ``realm`` and ``check_credentials`` keyword arguments.

    Raises:
        TypeError: If the ``credentials`` or ``check_credentials`` argument is
            wrong.

    """
    # Exactly one of credentials and check_credentials must be given.
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        if is_credentials(credentials):
            credentials_list = [cast("Credentials", credentials)]
        elif isinstance(credentials, Iterable):
            credentials_list = list(cast("Iterable[Credentials]", credentials))
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        # hmac.compare_digest() rejects str arguments containing non-ASCII
        # characters with a TypeError. Store and compare UTF-8 encoded
        # passwords so that any password can be checked in constant time.
        credentials_dict = {
            username: password.encode()
            for username, password in credentials_list
        }

        async def check_credentials(username: str, password: str) -> bool:
            try:
                expected_password = credentials_dict[username]
            except KeyError:
                return False
            return hmac.compare_digest(expected_password, password.encode())

    if create_protocol is None:
        create_protocol = BasicAuthWebSocketServerProtocol

    # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
    # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc]
    # The target type is quoted because it only matters to the type checker.
    create_protocol = cast(
        "Callable[..., BasicAuthWebSocketServerProtocol]", create_protocol
    )
    return functools.partial(
        create_protocol,
        realm=realm,
        check_credentials=check_credentials,
    )

View File

@ -0,0 +1,704 @@
from __future__ import annotations
import asyncio
import functools
import logging
import os
import random
import traceback
import urllib.parse
import warnings
from collections.abc import AsyncIterator, Generator, Sequence
from types import TracebackType
from typing import Any, Callable, cast
from ..asyncio.compatibility import asyncio_timeout
from ..datastructures import Headers, HeadersLike
from ..exceptions import (
InvalidHeader,
InvalidHeaderValue,
NegotiationError,
SecurityError,
)
from ..extensions import ClientExtensionFactory, Extension
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import (
build_authorization_basic,
build_extension,
build_host,
build_subprotocol,
parse_extension,
parse_subprotocol,
validate_subprotocols,
)
from ..http11 import USER_AGENT
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
from ..uri import WebSocketURI, parse_uri
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
from .handshake import build_request, check_response
from .http import read_response
from .protocol import WebSocketCommonProtocol
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
class WebSocketClientProtocol(WebSocketCommonProtocol):
    """
    WebSocket client connection.

    :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
    coroutines for receiving and sending messages.

    It supports asynchronous iteration to receive messages::

        async for message in websocket:
            await process(message)

    The iterator exits normally when the connection is closed with close code
    1000 (OK) or 1001 (going away) or without a close code. It raises
    a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
    is closed with any other code.

    See :func:`connect` for the documentation of ``logger``, ``origin``,
    ``extensions``, ``subprotocols``, ``extra_headers``, and
    ``user_agent_header``.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    """

    is_client = True
    side = "client"

    def __init__(
        self,
        *,
        logger: LoggerLike | None = None,
        origin: Origin | None = None,
        extensions: Sequence[ClientExtensionFactory] | None = None,
        subprotocols: Sequence[Subprotocol] | None = None,
        extra_headers: HeadersLike | None = None,
        user_agent_header: str | None = USER_AGENT,
        **kwargs: Any,
    ) -> None:
        if logger is None:
            logger = logging.getLogger("websockets.client")
        super().__init__(logger=logger, **kwargs)
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        self.user_agent_header = user_agent_header

    def write_http_request(self, path: str, headers: Headers) -> None:
        """
        Write request line and headers to the HTTP request.

        The path and headers are also recorded in the ``path`` and
        ``request_headers`` attributes.

        """
        self.path = path
        self.request_headers = headers

        if self.debug:
            self.logger.debug("> GET %s HTTP/1.1", path)
            for key, value in headers.raw_items():
                self.logger.debug("> %s: %s", key, value)

        # Since the path and headers only contain ASCII characters,
        # we can keep this simple.
        request = f"GET {path} HTTP/1.1\r\n"
        request += str(headers)

        self.transport.write(request.encode())

    async def read_http_response(self) -> tuple[int, Headers]:
        """
        Read status line and headers from the HTTP response.

        If the response contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        Returns:
            The status code and the response headers. The headers are also
            recorded in the ``response_headers`` attribute.

        Raises:
            InvalidMessage: If reading or parsing the HTTP response fails.

        """
        try:
            status_code, reason, headers = await read_response(self.reader)
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP response") from exc

        if self.debug:
            self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
            for key, value in headers.raw_items():
                self.logger.debug("< %s: %s", key, value)

        self.response_headers = headers

        return status_code, self.response_headers

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Sequence[ClientExtensionFactory] | None,
    ) -> list[Extension]:
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        Return the list of accepted extensions.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
        connection.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        """
        accepted_extensions: list[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values:
            if available_extensions is None:
                raise NegotiationError("no extensions supported")

            # Flatten the extensions parsed from every header value.
            parsed_header_values: list[ExtensionHeader] = [
                extension_header
                for header_value in header_values
                for extension_header in parse_extension(header_value)
            ]

            for name, response_params in parsed_header_values:
                for extension_factory in available_extensions:
                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        f"Unsupported extension: "
                        f"name = {name}, params = {response_params}"
                    )

        return accepted_extensions

    @staticmethod
    def process_subprotocol(
        headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
    ) -> Subprotocol | None:
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        Check that it contains exactly one supported subprotocol.

        Return the selected subprotocol, or :obj:`None` when the header is
        absent.

        Raises:
            NegotiationError: If no subprotocols are configured or if the
                selected subprotocol isn't one of them.
            InvalidHeaderValue: If the server selected several subprotocols.

        """
        subprotocol: Subprotocol | None = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values:
            if available_subprotocols is None:
                raise NegotiationError("no subprotocols supported")

            # Flatten the subprotocols parsed from every header value.
            parsed_header_values: Sequence[Subprotocol] = [
                parsed_subprotocol
                for header_value in header_values
                for parsed_subprotocol in parse_subprotocol(header_value)
            ]

            if len(parsed_header_values) > 1:
                raise InvalidHeaderValue(
                    "Sec-WebSocket-Protocol",
                    f"multiple values: {', '.join(parsed_header_values)}",
                )

            subprotocol = parsed_header_values[0]

            if subprotocol not in available_subprotocols:
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")

        return subprotocol

    async def handshake(
        self,
        wsuri: WebSocketURI,
        origin: Origin | None = None,
        available_extensions: Sequence[ClientExtensionFactory] | None = None,
        available_subprotocols: Sequence[Subprotocol] | None = None,
        extra_headers: HeadersLike | None = None,
    ) -> None:
        """
        Perform the client side of the opening handshake.

        Args:
            wsuri: URI of the WebSocket server.
            origin: Value of the ``Origin`` header.
            available_extensions: List of supported extensions, in order in
                which they should be negotiated and run.
            available_subprotocols: List of supported subprotocols, in order
                of decreasing preference.
            extra_headers: Arbitrary HTTP headers to add to the handshake
                request. When omitted, the ``extra_headers`` given to the
                constructor are used.

        Raises:
            InvalidHandshake: If the handshake fails.

        """
        request_headers = Headers()

        request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)

        if wsuri.user_info:
            request_headers["Authorization"] = build_authorization_basic(
                *wsuri.user_info
            )

        if origin is not None:
            request_headers["Origin"] = origin

        key = build_request(request_headers)

        if available_extensions is not None:
            extensions_header = build_extension(
                [
                    (extension_factory.name, extension_factory.get_request_params())
                    for extension_factory in available_extensions
                ]
            )
            request_headers["Sec-WebSocket-Extensions"] = extensions_header

        if available_subprotocols is not None:
            protocol_header = build_subprotocol(available_subprotocols)
            request_headers["Sec-WebSocket-Protocol"] = protocol_header

        # Honor the extra_headers argument instead of silently ignoring it;
        # fall back to the headers configured on the instance.
        if extra_headers is None:
            extra_headers = self.extra_headers
        if extra_headers is not None:
            request_headers.update(extra_headers)

        if self.user_agent_header:
            request_headers.setdefault("User-Agent", self.user_agent_header)

        self.write_http_request(wsuri.resource_name, request_headers)

        status_code, response_headers = await self.read_http_response()
        if status_code in (301, 302, 303, 307, 308):
            if "Location" not in response_headers:
                raise InvalidHeader("Location")
            raise RedirectHandshake(response_headers["Location"])
        elif status_code != 101:
            raise InvalidStatusCode(status_code, response_headers)

        check_response(response_headers, key)

        self.extensions = self.process_extensions(
            response_headers, available_extensions
        )

        self.subprotocol = self.process_subprotocol(
            response_headers, available_subprotocols
        )

        self.connection_open()
class Connect:
    """
    Connect to the WebSocket server at ``uri``.

    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
    can then be used to send and receive messages.

    :func:`connect` can be used as an asynchronous context manager::

        async with connect(...) as websocket:
            ...

    The connection is closed automatically when exiting the context.

    :func:`connect` can be used as an infinite asynchronous iterator to
    reconnect automatically on errors::

        async for websocket in connect(...):
            try:
                ...
            except websockets.exceptions.ConnectionClosed:
                continue

    The connection is closed automatically after each iteration of the loop.

    If an error occurs while establishing the connection, :func:`connect`
    retries with exponential backoff. The delay is multiplied by
    ``BACKOFF_FACTOR`` after each failure and capped at ``BACKOFF_MAX``
    (90 seconds unless overridden through the environment).

    If an error occurs in the body of the loop, you can handle the exception
    and :func:`connect` will reconnect with the next iteration; or you can
    let the exception bubble up and break out of the loop. This lets you
    decide which errors trigger a reconnection and which errors are fatal.

    Args:
        uri: URI of the WebSocket server.
        create_protocol: Factory for the :class:`asyncio.Protocol` managing
            the connection. It defaults to :class:`WebSocketClientProtocol`.
            Set it to a wrapper or a subclass to customize connection handling.
        logger: Logger for this client.
            It defaults to ``logging.getLogger("websockets.client")``.
            See the :doc:`logging guide <../../topics/logging>` for details.
        compression: The "permessage-deflate" extension is enabled by default.
            Set ``compression`` to :obj:`None` to disable it. See the
            :doc:`compression guide <../../topics/compression>` for details.
        origin: Value of the ``Origin`` header, for servers that require it.
        extensions: List of supported extensions, in order in which they
            should be negotiated and run.
        subprotocols: List of supported subprotocols, in order of decreasing
            preference.
        extra_headers: Arbitrary HTTP headers to add to the handshake request.
        user_agent_header: Value of the ``User-Agent`` request header.
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
            Setting it to :obj:`None` removes the header.
        open_timeout: Timeout for opening the connection in seconds.
            :obj:`None` disables the timeout.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    Any other keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_connection` method.

    For example:

    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
      settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
      provided, a TLS context is created
      with :func:`~ssl.create_default_context`.

    * You can set ``host`` and ``port`` to connect to a different host and
      port from those found in ``uri``. This only changes the destination of
      the TCP connection. The host name from ``uri`` is still used in the TLS
      handshake for secure connections and in the ``Host`` header.

    Raises:
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
        OSError: If the TCP connection fails.
        InvalidHandshake: If the opening handshake fails.
        ~asyncio.TimeoutError: If the opening handshake times out.

    """

    # Upper bound on the number of connection attempts made while following
    # redirects; configurable with the WEBSOCKETS_MAX_REDIRECTS variable.
    MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))

    def __init__(
        self,
        uri: str,
        *,
        create_protocol: Callable[..., WebSocketClientProtocol] | None = None,
        logger: LoggerLike | None = None,
        compression: str | None = "deflate",
        origin: Origin | None = None,
        extensions: Sequence[ClientExtensionFactory] | None = None,
        subprotocols: Sequence[Subprotocol] | None = None,
        extra_headers: HeadersLike | None = None,
        user_agent_header: str | None = USER_AGENT,
        open_timeout: float | None = 10,
        ping_interval: float | None = 20,
        ping_timeout: float | None = 20,
        close_timeout: float | None = None,
        max_size: int | None = 2**20,
        max_queue: int | None = 2**5,
        read_limit: int = 2**16,
        write_limit: int = 2**16,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        timeout: float | None = kwargs.pop("timeout", None)
        if timeout is None:
            timeout = 10
        else:
            warnings.warn("rename timeout to close_timeout", DeprecationWarning)
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None)
        if klass is None:
            klass = WebSocketClientProtocol
        else:
            warnings.warn("rename klass to create_protocol", DeprecationWarning)
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        # Backwards compatibility: recv() used to return None on closed connections
        legacy_recv: bool = kwargs.pop("legacy_recv", False)

        # Backwards compatibility: the loop parameter used to be supported.
        _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
        if _loop is None:
            loop = asyncio.get_event_loop()
        else:
            loop = _loop
            warnings.warn("remove loop argument", DeprecationWarning)

        wsuri = parse_uri(uri)
        # Secure URIs default to ssl=True; an ssl argument is rejected for
        # non-secure URIs.
        if wsuri.secure:
            kwargs.setdefault("ssl", True)
        elif kwargs.get("ssl") is not None:
            raise ValueError(
                "connect() received a ssl argument for a ws:// URI, "
                "use a wss:// URI to enable TLS"
            )

        if compression == "deflate":
            extensions = enable_client_permessage_deflate(extensions)
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        if subprotocols is not None:
            validate_subprotocols(subprotocols)

        # Help mypy and avoid this error: "type[WebSocketClientProtocol] |
        # Callable[..., WebSocketClientProtocol]" not callable [misc]
        create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol)
        # Protocol factory with every protocol-level option bound; the event
        # loop calls it without arguments.
        factory = functools.partial(
            create_protocol,
            logger=logger,
            origin=origin,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
            user_agent_header=user_agent_header,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            host=wsuri.host,
            port=wsuri.port,
            secure=wsuri.secure,
            legacy_recv=legacy_recv,
            loop=_loop,
        )

        # The remaining kwargs are forwarded to the event loop's connection
        # method; "unix" selects create_unix_connection.
        if kwargs.pop("unix", False):
            path: str | None = kwargs.pop("path", None)
            create_connection = functools.partial(
                loop.create_unix_connection, factory, path, **kwargs
            )
        else:
            host: str | None
            port: int | None
            if kwargs.get("sock") is None:
                host, port = wsuri.host, wsuri.port
            else:
                # If sock is given, host and port shouldn't be specified.
                host, port = None, None
                if kwargs.get("ssl"):
                    kwargs.setdefault("server_hostname", wsuri.host)
            # If host and port are given, override values from the URI.
            host = kwargs.pop("host", host)
            port = kwargs.pop("port", port)
            create_connection = functools.partial(
                loop.create_connection, factory, host, port, **kwargs
            )

        self.open_timeout = open_timeout
        if logger is None:
            logger = logging.getLogger("websockets.client")
        self.logger = logger

        # This is a coroutine function.
        self._create_connection = create_connection
        self._uri = uri
        self._wsuri = wsuri

    def handle_redirect(self, uri: str) -> None:
        """
        Update this instance so that the next attempt connects to ``uri``.

        Args:
            uri: Redirect target; it is resolved against the current URI.

        Raises:
            SecurityError: If the redirect goes from a secure URI to a
                non-secure URI.

        """
        # Update the state of this instance to connect to a new URI.
        old_uri = self._uri
        old_wsuri = self._wsuri
        new_uri = urllib.parse.urljoin(old_uri, uri)
        new_wsuri = parse_uri(new_uri)

        # Forbid TLS downgrade.
        if old_wsuri.secure and not new_wsuri.secure:
            raise SecurityError("redirect from WSS to WS")

        same_origin = (
            old_wsuri.secure == new_wsuri.secure
            and old_wsuri.host == new_wsuri.host
            and old_wsuri.port == new_wsuri.port
        )

        # Rewrite secure, host, and port for cross-origin redirects.
        # This preserves connection overrides with the host and port
        # arguments if the redirect points to the same host and port.
        if not same_origin:
            factory = self._create_connection.args[0]
            # Support TLS upgrade.
            if not old_wsuri.secure and new_wsuri.secure:
                factory.keywords["secure"] = True
                self._create_connection.keywords.setdefault("ssl", True)
            # Replace secure, host, and port arguments of the protocol factory.
            factory = functools.partial(
                factory.func,
                *factory.args,
                **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
            )
            # Replace secure, host, and port arguments of create_connection.
            self._create_connection = functools.partial(
                self._create_connection.func,
                *(factory, new_wsuri.host, new_wsuri.port),
                **self._create_connection.keywords,
            )

        # Set the new WebSocket URI. This suffices for same-origin redirects.
        self._uri = new_uri
        self._wsuri = new_wsuri

    # async for ... in connect(...):

    # Backoff parameters, in seconds except for the factor; each can be
    # overridden with the environment variable named in its definition.
    BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
    BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
    BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
    BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))

    async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
        """
        Yield connections forever, reconnecting after each one ends.

        Exceptions derived from :exc:`Exception` that reach this generator
        are logged and followed by a sleep before the next attempt.

        """
        backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
        while True:
            try:
                async with self as protocol:
                    yield protocol
            except Exception as exc:
                # Add a random initial delay between 0 and 5 seconds.
                # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
                if backoff_delay == self.BACKOFF_MIN:
                    initial_delay = random.random() * self.BACKOFF_INITIAL
                    self.logger.info(
                        "connect failed; reconnecting in %.1f seconds: %s",
                        initial_delay,
                        # Remove first argument when dropping Python 3.9.
                        traceback.format_exception_only(type(exc), exc)[0].strip(),
                    )
                    await asyncio.sleep(initial_delay)
                else:
                    self.logger.info(
                        "connect failed again; retrying in %d seconds: %s",
                        int(backoff_delay),
                        # Remove first argument when dropping Python 3.9.
                        traceback.format_exception_only(type(exc), exc)[0].strip(),
                    )
                    await asyncio.sleep(int(backoff_delay))
                # Increase delay with truncated exponential backoff.
                backoff_delay = backoff_delay * self.BACKOFF_FACTOR
                backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
                continue
            else:
                # Connection succeeded - reset backoff delay
                backoff_delay = self.BACKOFF_MIN

    # async with connect(...) as ...:

    async def __aenter__(self) -> WebSocketClientProtocol:
        """Open the connection and return its protocol."""
        return await self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the connection opened by ``__aenter__``."""
        await self.protocol.close()

    # ... = await connect(...)

    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
        """Make instances awaitable by delegating to ``__await_impl__``."""
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> WebSocketClientProtocol:
        """
        Open a connection and perform the handshake, following redirects.

        The whole sequence runs under the ``open_timeout`` timeout.

        Raises:
            SecurityError: If ``MAX_REDIRECTS_ALLOWED`` attempts all ended
                with a redirect.

        """
        async with asyncio_timeout(self.open_timeout):
            for _redirects in range(self.MAX_REDIRECTS_ALLOWED):
                _transport, protocol = await self._create_connection()
                try:
                    await protocol.handshake(
                        self._wsuri,
                        origin=protocol.origin,
                        available_extensions=protocol.available_extensions,
                        available_subprotocols=protocol.available_subprotocols,
                        extra_headers=protocol.extra_headers,
                    )
                except RedirectHandshake as exc:
                    # Drop this connection, then retry against the new URI.
                    protocol.fail_connection()
                    await protocol.wait_closed()
                    self.handle_redirect(exc.uri)
                # Avoid leaking a connected socket when the handshake fails.
                except (Exception, asyncio.CancelledError):
                    protocol.fail_connection()
                    await protocol.wait_closed()
                    raise
                else:
                    self.protocol = protocol
                    return protocol
            else:
                raise SecurityError("too many redirects")

    # ... = yield from connect(...) - remove when dropping Python < 3.10

    __iter__ = __await__


# Public, function-style name for the Connect class.
connect = Connect
def unix_connect(
    path: str | None = None,
    uri: str = "ws://localhost/",
    **kwargs: Any,
) -> Connect:
    """
    Connect to a WebSocket server listening on a Unix socket.

    This behaves like :func:`connect`, except that it relies on the event
    loop's :meth:`~asyncio.loop.create_unix_connection` method. It is only
    available on Unix.

    It's mainly useful for debugging servers listening on Unix sockets.

    Args:
        path: File system path to the Unix socket.
        uri: URI of the WebSocket server; the host is used in the TLS
            handshake for secure connections and in the ``Host`` header.

    """
    # The unix flag makes connect() use create_unix_connection with path.
    return connect(uri, unix=True, path=path, **kwargs)

View File

@ -0,0 +1,78 @@
import http
from .. import datastructures
from ..exceptions import (
InvalidHandshake,
ProtocolError as WebSocketProtocolError, # noqa: F401
)
from ..typing import StatusLike
class InvalidMessage(InvalidHandshake):
    """
    Raised when a handshake request or response is malformed.

    """
class InvalidStatusCode(InvalidHandshake):
    """
    Raised when the status code of a handshake response is invalid.

    Attributes:
        status_code (int): HTTP status code of the response.
        headers (Headers): HTTP response headers.

    """

    def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
        self.status_code, self.headers = status_code, headers

    def __str__(self) -> str:
        code = self.status_code
        return f"server rejected WebSocket connection: HTTP {code}"
class AbortHandshake(InvalidHandshake):
    """
    Raised to deliberately abort the handshake and send an HTTP response.

    This exception is an implementation detail.

    The public API is
    :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.

    Attributes:
        status (~http.HTTPStatus): HTTP status code.
        headers (Headers): HTTP response headers.
        body (bytes): HTTP response body.

    """

    def __init__(
        self,
        status: StatusLike,
        headers: datastructures.HeadersLike,
        body: bytes = b"",
    ) -> None:
        # Normalize a plain int into an HTTPStatus on behalf of the caller.
        self.status = http.HTTPStatus(status)
        self.headers = datastructures.Headers(headers)
        self.body = body

    def __str__(self) -> str:
        header_count = len(self.headers)
        body_size = len(self.body)
        return f"HTTP {self.status:d}, {header_count} headers, {body_size} bytes"
class RedirectHandshake(InvalidHandshake):
    """
    Raised when the server answers a handshake with a redirect.

    This exception is an implementation detail.

    Attributes:
        uri (str): Target of the redirect.

    """

    def __init__(self, uri: str) -> None:
        self.uri = uri

    def __str__(self) -> str:
        target = self.uri
        return f"redirect to {target}"

View File

@ -0,0 +1,225 @@
from __future__ import annotations
import struct
from collections.abc import Awaitable, Sequence
from typing import Any, Callable, NamedTuple
from .. import extensions, frames
from ..exceptions import PayloadTooBig, ProtocolError
from ..frames import BytesLike
from ..typing import Data
try:
from ..speedups import apply_mask
except ImportError:
from ..utils import apply_mask
class Frame(NamedTuple):
    """
    WebSocket frame stored as a named tuple.

    Validation, string formatting, and serialization are delegated to the
    equivalent :class:`frames.Frame`, available as :attr:`new_frame`.

    """

    fin: bool
    opcode: frames.Opcode
    data: bytes
    rsv1: bool = False
    rsv2: bool = False
    rsv3: bool = False

    @property
    def new_frame(self) -> frames.Frame:
        """Equivalent :class:`frames.Frame` built from this tuple's fields."""
        opcode, data, fin = self.opcode, self.data, self.fin
        return frames.Frame(opcode, data, fin, self.rsv1, self.rsv2, self.rsv3)

    def __str__(self) -> str:
        equivalent = self.new_frame
        return str(equivalent)

    def check(self) -> None:
        """Validate the frame by delegating to :class:`frames.Frame`."""
        equivalent = self.new_frame
        return equivalent.check()

    @classmethod
    async def read(
        cls,
        reader: Callable[[int], Awaitable[bytes]],
        *,
        mask: bool,
        max_size: int | None = None,
        extensions: Sequence[extensions.Extension] | None = None,
    ) -> Frame:
        """
        Read a WebSocket frame.

        Args:
            reader: Coroutine that reads exactly the requested number of
                bytes, unless the end of file is reached.
            mask: Whether the frame should be masked i.e. whether the read
                happens on the server side.
            max_size: Maximum payload size in bytes.
            extensions: List of extensions, applied in reverse order.

        Raises:
            PayloadTooBig: If the frame exceeds ``max_size``.
            ProtocolError: If the frame contains incorrect values.

        """
        # The first two bytes carry the flags, the opcode, the mask bit, and
        # the short form of the payload length.
        header = await reader(2)
        head1, head2 = struct.unpack("!BB", header)

        fin = bool(head1 & 0x80)
        rsv1 = bool(head1 & 0x40)
        rsv2 = bool(head1 & 0x20)
        rsv3 = bool(head1 & 0x10)

        try:
            opcode = frames.Opcode(head1 & 0x0F)
        except ValueError as exc:
            raise ProtocolError("invalid opcode") from exc

        if bool(head2 & 0x80) != mask:
            raise ProtocolError("incorrect masking")

        # 126 and 127 announce a 16-bit or a 64-bit extended length.
        length = head2 & 0x7F
        if length == 126:
            (length,) = struct.unpack("!H", await reader(2))
        elif length == 127:
            (length,) = struct.unpack("!Q", await reader(8))
        if max_size is not None and length > max_size:
            raise PayloadTooBig(length, max_size)

        # The masking key, when present, comes right before the payload.
        if mask:
            mask_bits = await reader(4)
            payload = await reader(length)
            payload = apply_mask(payload, mask_bits)
        else:
            payload = await reader(length)

        decoded = frames.Frame(opcode, payload, fin, rsv1, rsv2, rsv3)

        active_extensions = [] if extensions is None else extensions
        for extension in reversed(active_extensions):
            decoded = extension.decode(decoded, max_size=max_size)

        decoded.check()

        return cls(
            decoded.fin,
            decoded.opcode,
            decoded.data,
            decoded.rsv1,
            decoded.rsv2,
            decoded.rsv3,
        )

    def write(
        self,
        write: Callable[[bytes], Any],
        *,
        mask: bool,
        extensions: Sequence[extensions.Extension] | None = None,
    ) -> None:
        """
        Write this WebSocket frame.

        Args:
            write: Function that writes bytes.
            mask: Whether the frame should be masked i.e. whether the write
                happens on the client side.
            extensions: List of extensions, applied in order.

        Raises:
            ProtocolError: If the frame contains incorrect values.

        """
        # The frame is written in a single call to write in order to prevent
        # TCP fragmentation. See #68 for details. This also makes it safe to
        # send frames concurrently from multiple coroutines.
        serialized = self.new_frame.serialize(mask=mask, extensions=extensions)
        write(serialized)
def prepare_data(data: Data) -> tuple[int, bytes]:
    """
    Turn the payload of a data frame into an ``(opcode, payload)`` pair.

    A :class:`str` is encoded in UTF-8 and paired with ``OP_TEXT``.

    A bytes-like object is returned unchanged and paired with ``OP_BINARY``.

    Raises:
        TypeError: If ``data`` doesn't have a supported type.

    """
    # Text is checked first; anything that is neither text nor bytes-like
    # falls through to the error at the bottom.
    if isinstance(data, str):
        encoded = data.encode()
        return frames.Opcode.TEXT, encoded
    if isinstance(data, BytesLike):
        return frames.Opcode.BINARY, data
    raise TypeError("data must be str or bytes-like")
def prepare_ctrl(data: Data) -> bytes:
    """
    Turn the payload of a ping or pong frame into :class:`bytes`.

    A :class:`str` is encoded in UTF-8.

    A bytes-like object is copied into a new :class:`bytes` object.

    Raises:
        TypeError: If ``data`` doesn't have a supported type.

    """
    # Text is checked first; anything that is neither text nor bytes-like
    # falls through to the error at the bottom.
    if isinstance(data, str):
        encoded = data.encode()
        return encoded
    if isinstance(data, BytesLike):
        return bytes(data)
    raise TypeError("data must be str or bytes-like")
# Backwards compatibility with previously documented public APIs:
# ``encode_data`` is kept as an alias bound to the same function object as
# ``prepare_ctrl``.
encode_data = prepare_ctrl
# Backwards compatibility with previously documented public APIs:
# ``Close`` is imported here, after the definitions above (hence the E402
# suppression). It is used by ``parse_close`` and ``serialize_close`` in this
# module; the F401 suppression suggests it is also meant to remain importable
# from here -- presumably as a re-export, confirm against callers.
from ..frames import Close  # noqa: E402 F401, I001
def parse_close(data: bytes) -> tuple[int, str]:
    """
    Decode the payload of a close frame.

    Args:
        data: Payload of the close frame.

    Returns:
        Close code and reason.

    Raises:
        ProtocolError: If data is ill-formed.
        UnicodeDecodeError: If the reason isn't valid UTF-8.

    """
    # Parsing and validation are delegated to ``Close.parse``; this wrapper
    # only unpacks the result into a plain tuple.
    parsed = Close.parse(data)
    code, reason = parsed.code, parsed.reason
    return code, reason
def serialize_close(code: int, reason: str) -> bytes:
    """
    Encode the payload of a close frame.

    Args:
        code: Close code.
        reason: Close reason.

    Returns:
        Payload produced by ``Close(code, reason).serialize()``.

    """
    close = Close(code, reason)
    return close.serialize()

View File

@ -0,0 +1,158 @@
from __future__ import annotations
import base64
import binascii
from ..datastructures import Headers, MultipleValuesError
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
from ..headers import parse_connection, parse_upgrade
from ..typing import ConnectionOption, UpgradeProtocol
from ..utils import accept_key as accept, generate_key
__all__ = ["build_request", "check_request", "build_response", "check_response"]
def build_request(headers: Headers) -> str:
    """
    Fill in the headers of a handshake request to send to the server.

    The ``headers`` object passed in argument is updated in place.

    Args:
        headers: Handshake request headers.

    Returns:
        ``key`` that must be passed to :func:`check_response`.

    """
    nonce = generate_key()
    # Fields are set in this exact order: Upgrade, Connection, key, version.
    handshake_fields = (
        ("Upgrade", "websocket"),
        ("Connection", "Upgrade"),
        ("Sec-WebSocket-Key", nonce),
        ("Sec-WebSocket-Version", "13"),
    )
    for name, value in handshake_fields:
        headers[name] = value
    return nonce
def check_request(headers: Headers) -> str:
    """
    Validate a handshake request received from the client.

    This function doesn't verify that the request is an HTTP/1.1 or higher GET
    request and doesn't perform ``Host`` and ``Origin`` checks. These controls
    are usually performed earlier in the HTTP request handling code. They're
    the responsibility of the caller.

    Args:
        headers: Handshake request headers.

    Returns:
        ``key`` that must be passed to :func:`build_response`.

    Raises:
        InvalidHandshake: If the handshake request is invalid.
            Then, the server must return a 400 Bad Request error.

    """
    # A header may be repeated: merge the options of every occurrence.
    connection: list[ConnectionOption] = []
    for header in headers.get_all("Connection"):
        connection.extend(parse_connection(header))
    if all(option.lower() != "upgrade" for option in connection):
        raise InvalidUpgrade("Connection", ", ".join(connection))

    upgrade: list[UpgradeProtocol] = []
    for header in headers.get_all("Upgrade"):
        upgrade.extend(parse_upgrade(header))
    # For compatibility with non-strict implementations, ignore case when
    # checking the Upgrade header. The RFC always uses "websocket", except
    # in section 11.2. (IANA registration) where it uses "WebSocket".
    if len(upgrade) != 1 or upgrade[0].lower() != "websocket":
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))

    try:
        key = headers["Sec-WebSocket-Key"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Key") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc

    # The key must be valid base64 that decodes to exactly 16 bytes.
    try:
        decoded_key = base64.b64decode(key.encode(), validate=True)
    except binascii.Error as exc:
        raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc
    if len(decoded_key) != 16:
        raise InvalidHeaderValue("Sec-WebSocket-Key", key)

    try:
        version = headers["Sec-WebSocket-Version"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Version") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
    # Only protocol version 13 is accepted.
    if version != "13":
        raise InvalidHeaderValue("Sec-WebSocket-Version", version)

    return key
def build_response(headers: Headers, key: str) -> None:
    """
    Fill in the headers of a handshake response to send to the client.

    The ``headers`` object passed in argument is updated in place.

    Args:
        headers: Handshake response headers.
        key: Returned by :func:`check_request`.

    """
    for name, value in (("Upgrade", "websocket"), ("Connection", "Upgrade")):
        headers[name] = value
    # The accept value is derived from the client's key by ``accept``; it is
    # computed only after the two fixed fields have been set.
    headers["Sec-WebSocket-Accept"] = accept(key)
def check_response(headers: Headers, key: str) -> None:
    """
    Check a handshake response received from the server.

    This function doesn't verify that the response is an HTTP/1.1 or higher
    response with a 101 status code. These controls are the responsibility of
    the caller.

    Args:
        headers: Handshake response headers.
        key: Returned by :func:`build_request`.

    Raises:
        InvalidHandshake: If the handshake response is invalid. The
            exceptions raised here are ``InvalidUpgrade``, ``InvalidHeader``
            and ``InvalidHeaderValue``.

    """
    # A header may be repeated: merge the options of every occurrence.
    connection: list[ConnectionOption] = [
        option
        for value in headers.get_all("Connection")
        for option in parse_connection(value)
    ]
    if not any(value.lower() == "upgrade" for value in connection):
        # Connection options are a comma-separated list; report them with the
        # same ", " separator as check_request and the Upgrade check below.
        raise InvalidUpgrade("Connection", ", ".join(connection))

    upgrade: list[UpgradeProtocol] = [
        protocol
        for value in headers.get_all("Upgrade")
        for protocol in parse_upgrade(value)
    ]
    # For compatibility with non-strict implementations, ignore case when
    # checking the Upgrade header. The RFC always uses "websocket", except
    # in section 11.2. (IANA registration) where it uses "WebSocket".
    if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))

    try:
        s_w_accept = headers["Sec-WebSocket-Accept"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Accept") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc

    # The server must echo the accept value derived from the key we sent.
    if s_w_accept != accept(key):
        raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)

View File

@ -0,0 +1,201 @@
from __future__ import annotations
import asyncio
import os
import re
from ..datastructures import Headers
from ..exceptions import SecurityError
__all__ = ["read_request", "read_response"]
# Security limits, read once from the environment when this module is
# imported; each falls back to the default shown when the variable is unset.
# Maximum number of header lines; used by read_headers to bound its loop.
MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
# Maximum length of a line in bytes; read_line compares it with the length
# of the raw line, i.e. before the trailing CRLF is stripped.
MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
def d(value: bytes) -> str:
    """
    Render a bytestring as text suitable for an error message.

    Bytes that aren't valid UTF-8 are shown as backslash escapes instead of
    raising an exception.

    """
    text = value.decode("utf-8", "backslashreplace")
    return text
# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
# Neither pattern below is anchored; this module always applies them with
# ``fullmatch`` so that the whole name or value is validated.
# Regex for validating header names.
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
# Regex for validating header values (also applied to the reason phrase of
# the status line in read_response).
# We don't attempt to support obsolete line folding.
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
# The ABNF is complicated because it attempts to express that optional
# whitespace is ignored. We strip whitespace and don't revalidate that.
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]:
    """
    Read an HTTP/1.1 GET request and return ``(path, headers)``.

    ``path`` isn't URL-decoded or validated in any way.

    ``path`` and ``headers`` are expected to contain only ASCII characters.
    Other characters are represented with surrogate escapes.

    :func:`read_request` doesn't attempt to read the request body because
    WebSocket handshake requests don't have one. If the request contains a
    body, it may be read from ``stream`` after this coroutine returns.

    Args:
        stream: Input to read the request from.

    Raises:
        EOFError: If the connection is closed without a full HTTP request.
        SecurityError: If the request exceeds a security limit.
        ValueError: If the request isn't well formatted.

    """
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1

    # Parsing is simple because fixed values are expected for method and
    # version and because path isn't checked. Since WebSocket software tends
    # to implement HTTP/1.1 strictly, there's little need for lenient parsing.

    try:
        start_line = await read_line(stream)
    except EOFError as exc:
        raise EOFError("connection closed while reading HTTP request line") from exc

    # The request line has exactly three space-separated components.
    try:
        method, target, version = start_line.split(b" ", 2)
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
        raise ValueError(f"invalid HTTP request line: {d(start_line)}") from None

    if method != b"GET":
        raise ValueError(f"unsupported HTTP method: {d(method)}")
    if version != b"HTTP/1.1":
        raise ValueError(f"unsupported HTTP version: {d(version)}")

    # Decoding with surrogate escapes cannot fail, whatever bytes are sent.
    return target.decode("ascii", "surrogateescape"), await read_headers(stream)
async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]:
    """
    Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.

    ``reason`` and ``headers`` are expected to contain only ASCII characters.
    Other characters are represented with surrogate escapes.

    :func:`read_response` doesn't attempt to read the response body because
    WebSocket handshake responses don't have one. If the response contains a
    body, it may be read from ``stream`` after this coroutine returns.

    Args:
        stream: Input to read the response from.

    Raises:
        EOFError: If the connection is closed without a full HTTP response.
        SecurityError: If the response exceeds a security limit.
        ValueError: If the response isn't well formatted.

    """
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2

    # As in read_request, parsing is simple because a fixed value is expected
    # for version, status_code is a 3-digit number, and reason can be ignored.

    try:
        status_line = await read_line(stream)
    except EOFError as exc:
        raise EOFError("connection closed while reading HTTP status line") from exc

    try:
        version, raw_status_code, raw_reason = status_line.split(b" ", 2)
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
        raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None

    if version != b"HTTP/1.1":
        raise ValueError(f"unsupported HTTP version: {d(version)}")
    try:
        status_code = int(raw_status_code)
    except ValueError:  # invalid literal for int() with base 10
        raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
    if not 100 <= status_code < 1000:
        raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
    if not _value_re.fullmatch(raw_reason):
        raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
    # The regex above accepts obs-text bytes (\x80-\xff), which aren't
    # necessarily valid UTF-8. Decode like header values do, mapping
    # non-ASCII bytes to surrogate escapes, so that a reason phrase that
    # passed validation can never fail with UnicodeDecodeError here.
    reason = raw_reason.decode("ascii", "surrogateescape")

    headers = await read_headers(stream)

    return status_code, reason, headers
async def read_headers(stream: asyncio.StreamReader) -> Headers:
    """
    Read HTTP headers from ``stream`` up to the blank line that ends them.

    Non-ASCII characters are represented with surrogate escapes.

    Raises:
        EOFError: If the connection is closed before the end of the headers.
        SecurityError: If more than ``MAX_NUM_HEADERS`` headers are received.
        ValueError: If a header line isn't well formatted.

    """
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2

    # We don't attempt to support obsolete line folding.

    headers = Headers()
    # One extra line is allowed for the blank line that terminates headers.
    lines_left = MAX_NUM_HEADERS + 1
    while lines_left > 0:
        lines_left -= 1
        try:
            line = await read_line(stream)
        except EOFError as exc:
            raise EOFError("connection closed while reading HTTP headers") from exc
        if not line:
            # Blank line: end of the header section.
            return headers

        try:
            name_bytes, value_bytes = line.split(b":", 1)
        except ValueError:  # not enough values to unpack (expected 2, got 1)
            raise ValueError(f"invalid HTTP header line: {d(line)}") from None
        if not _token_re.fullmatch(name_bytes):
            raise ValueError(f"invalid HTTP header name: {d(name_bytes)}")
        # Optional whitespace around the value isn't significant.
        value_bytes = value_bytes.strip(b" \t")
        if not _value_re.fullmatch(value_bytes):
            raise ValueError(f"invalid HTTP header value: {d(value_bytes)}")

        # The name is guaranteed to be ASCII at this point; the value may
        # contain obs-text, which is mapped to surrogate escapes.
        name = name_bytes.decode("ascii")
        value = value_bytes.decode("ascii", "surrogateescape")
        headers[name] = value

    # The line budget ran out before the terminating blank line.
    raise SecurityError("too many HTTP headers")
async def read_line(stream: asyncio.StreamReader) -> bytes:
    """
    Read one CRLF-terminated line from ``stream``.

    The trailing CRLF is removed from the return value.

    Raises:
        SecurityError: If the line is longer than ``MAX_LINE_LENGTH`` bytes.
        EOFError: If the line doesn't end with CRLF.

    """
    # Security: the amount of data buffered here is bounded by the limit
    # configured on the StreamReader.
    raw = await stream.readline()
    # Security: this guarantees header values are small; the length checked
    # includes the line terminator.
    if len(raw) > MAX_LINE_LENGTH:
        raise SecurityError("line too long")
    # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
    if raw[-2:] != b"\r\n":
        raise EOFError("line without CRLF")
    return raw[:-2]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff