second commit

2024-12-27 22:31:23 +09:00
parent 2353324570
commit 10a0f110ca
8819 changed files with 1307198 additions and 28 deletions

View File: httpx/__init__.py

@@ -0,0 +1,105 @@
from .__version__ import __description__, __title__, __version__
from ._api import *
from ._auth import *
from ._client import *
from ._config import *
from ._content import *
from ._exceptions import *
from ._models import *
from ._status_codes import *
from ._transports import *
from ._types import *
from ._urls import *
try:
from ._main import main
except ImportError: # pragma: no cover
def main() -> None: # type: ignore
import sys
print(
"The httpx command line client could not run because the required "
"dependencies were not installed.\nMake sure you've installed "
"everything with: pip install 'httpx[cli]'"
)
sys.exit(1)
__all__ = [
"__description__",
"__title__",
"__version__",
"ASGITransport",
"AsyncBaseTransport",
"AsyncByteStream",
"AsyncClient",
"AsyncHTTPTransport",
"Auth",
"BaseTransport",
"BasicAuth",
"ByteStream",
"Client",
"CloseError",
"codes",
"ConnectError",
"ConnectTimeout",
"CookieConflict",
"Cookies",
"create_ssl_context",
"DecodingError",
"delete",
"DigestAuth",
"get",
"head",
"Headers",
"HTTPError",
"HTTPStatusError",
"HTTPTransport",
"InvalidURL",
"Limits",
"LocalProtocolError",
"main",
"MockTransport",
"NetRCAuth",
"NetworkError",
"options",
"patch",
"PoolTimeout",
"post",
"ProtocolError",
"Proxy",
"ProxyError",
"put",
"QueryParams",
"ReadError",
"ReadTimeout",
"RemoteProtocolError",
"request",
"Request",
"RequestError",
"RequestNotRead",
"Response",
"ResponseNotRead",
"stream",
"StreamClosed",
"StreamConsumed",
"StreamError",
"SyncByteStream",
"Timeout",
"TimeoutException",
"TooManyRedirects",
"TransportError",
"UnsupportedProtocol",
"URL",
"USE_CLIENT_DEFAULT",
"WriteError",
"WriteTimeout",
"WSGITransport",
]
__locals = locals()
for __name in __all__:
if not __name.startswith("__"):
setattr(__locals[__name], "__module__", "httpx") # noqa
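
The loop above rewrites each exported name's `__module__`, so public objects present themselves under the `httpx` namespace rather than their private submodules. A quick sketch of the observable effect:

```
import httpx

assert httpx.Client.__module__ == "httpx"  # not "httpx._client"
assert httpx.ConnectError.__module__ == "httpx"  # not "httpx._exceptions"
```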

View File: httpx/__version__.py

@@ -0,0 +1,3 @@
__title__ = "httpx"
__description__ = "A next generation HTTP client, for Python 3."
__version__ = "0.28.1"

View File: httpx/_api.py

@@ -0,0 +1,438 @@
from __future__ import annotations
import typing
from contextlib import contextmanager
from ._client import Client
from ._config import DEFAULT_TIMEOUT_CONFIG
from ._models import Response
from ._types import (
AuthTypes,
CookieTypes,
HeaderTypes,
ProxyTypes,
QueryParamTypes,
RequestContent,
RequestData,
RequestFiles,
TimeoutTypes,
)
from ._urls import URL
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
__all__ = [
"delete",
"get",
"head",
"options",
"patch",
"post",
"put",
"request",
"stream",
]
def request(
method: str,
url: URL | str,
*,
params: QueryParamTypes | None = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
trust_env: bool = True,
) -> Response:
"""
Sends an HTTP request.
**Parameters:**
* **method** - HTTP method for the new `Request` object: `GET`, `OPTIONS`,
`HEAD`, `POST`, `PUT`, `PATCH`, or `DELETE`.
* **url** - URL for the new `Request` object.
* **params** - *(optional)* Query parameters to include in the URL, as a
string, dictionary, or sequence of two-tuples.
* **content** - *(optional)* Binary content to include in the body of the
request, as bytes or a byte iterator.
* **data** - *(optional)* Form data to include in the body of the request,
as a dictionary.
* **files** - *(optional)* A dictionary of upload files to include in the
body of the request.
* **json** - *(optional)* A JSON serializable object to include in the body
of the request.
* **headers** - *(optional)* Dictionary of HTTP headers to include in the
request.
* **cookies** - *(optional)* Dictionary of Cookie items to include in the
request.
* **auth** - *(optional)* An authentication class to use when sending the
request.
* **proxy** - *(optional)* A proxy URL where all the traffic should be routed.
* **timeout** - *(optional)* The timeout configuration to use when sending
the request.
* **follow_redirects** - *(optional)* Enables or disables HTTP redirects.
* **verify** - *(optional)* Either `True` to use an SSL context with the
default CA bundle, `False` to disable verification, or an instance of
`ssl.SSLContext` to use a custom context.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
**Returns:** `Response`
Usage:
```
>>> import httpx
>>> response = httpx.request('GET', 'https://httpbin.org/get')
>>> response
<Response [200 OK]>
```
"""
with Client(
cookies=cookies,
proxy=proxy,
verify=verify,
timeout=timeout,
trust_env=trust_env,
) as client:
return client.request(
method=method,
url=url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
auth=auth,
follow_redirects=follow_redirects,
)
@contextmanager
def stream(
method: str,
url: URL | str,
*,
params: QueryParamTypes | None = None,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
trust_env: bool = True,
) -> typing.Iterator[Response]:
"""
Alternative to `httpx.request()` that streams the response body
instead of loading it into memory at once.
**Parameters**: See `httpx.request`.
See also: [Streaming Responses][0]
[0]: /quickstart#streaming-responses
"""
with Client(
cookies=cookies,
proxy=proxy,
verify=verify,
timeout=timeout,
trust_env=trust_env,
) as client:
with client.stream(
method=method,
url=url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
auth=auth,
follow_redirects=follow_redirects,
) as response:
yield response
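
A minimal usage sketch (the URL is illustrative); `Response.iter_bytes()` yields the body incrementally instead of loading it into memory:

```
import httpx

with httpx.stream("GET", "https://httpbin.org/stream-bytes/4096") as response:
    for chunk in response.iter_bytes():
        print(f"received {len(chunk)} bytes")
```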
def get(
url: URL | str,
*,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
"""
Sends a `GET` request.
**Parameters**: See `httpx.request`.
Note that the `data`, `files`, `json` and `content` parameters are not available
on this function, as `GET` requests should not include a request body.
"""
return request(
"GET",
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)
def options(
url: URL | str,
*,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
"""
Sends an `OPTIONS` request.
**Parameters**: See `httpx.request`.
Note that the `data`, `files`, `json` and `content` parameters are not available
on this function, as `OPTIONS` requests should not include a request body.
"""
return request(
"OPTIONS",
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)
def head(
url: URL | str,
*,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
"""
Sends a `HEAD` request.
**Parameters**: See `httpx.request`.
Note that the `data`, `files`, `json` and `content` parameters are not available
on this function, as `HEAD` requests should not include a request body.
"""
return request(
"HEAD",
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)
def post(
url: URL | str,
*,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
"""
Sends a `POST` request.
**Parameters**: See `httpx.request`.
"""
return request(
"POST",
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)
def put(
url: URL | str,
*,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
"""
Sends a `PUT` request.
**Parameters**: See `httpx.request`.
"""
return request(
"PUT",
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)
def patch(
url: URL | str,
*,
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: typing.Any | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
verify: ssl.SSLContext | str | bool = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
trust_env: bool = True,
) -> Response:
"""
Sends a `PATCH` request.
**Parameters**: See `httpx.request`.
"""
return request(
"PATCH",
url,
content=content,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)
def delete(
url: URL | str,
*,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | None = None,
proxy: ProxyTypes | None = None,
follow_redirects: bool = False,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
verify: ssl.SSLContext | str | bool = True,
trust_env: bool = True,
) -> Response:
"""
Sends a `DELETE` request.
**Parameters**: See `httpx.request`.
Note that the `data`, `files`, `json` and `content` parameters are not available
on this function, as `DELETE` requests should not include a request body.
"""
return request(
"DELETE",
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
proxy=proxy,
follow_redirects=follow_redirects,
verify=verify,
timeout=timeout,
trust_env=trust_env,
)

View File: httpx/_auth.py

@@ -0,0 +1,348 @@
from __future__ import annotations
import hashlib
import os
import re
import time
import typing
from base64 import b64encode
from urllib.request import parse_http_list
from ._exceptions import ProtocolError
from ._models import Cookies, Request, Response
from ._utils import to_bytes, to_str, unquote
if typing.TYPE_CHECKING: # pragma: no cover
from hashlib import _Hash
__all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"]
class Auth:
"""
Base class for all authentication schemes.
To implement a custom authentication scheme, subclass `Auth` and override
the `.auth_flow()` method.
If the authentication scheme does I/O such as disk access or network calls, or uses
synchronization primitives such as locks, you should override `.sync_auth_flow()`
and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
implementations that will be used by `Client` and `AsyncClient` respectively.
"""
requires_request_body = False
requires_response_body = False
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
"""
Execute the authentication flow.
To dispatch a request, `yield` it:
```
yield request
```
The client will `.send()` the response back into the flow generator. You can
access it like so:
```
response = yield request
```
A `return` (or reaching the end of the generator) will result in the
client returning the last response obtained from the server.
You can dispatch as many requests as is necessary.
"""
yield request
def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
"""
Execute the authentication flow synchronously.
By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
request.read()
flow = self.auth_flow(request)
request = next(flow)
while True:
response = yield request
if self.requires_response_body:
response.read()
try:
request = flow.send(response)
except StopIteration:
break
async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
"""
Execute the authentication flow asynchronously.
By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
await request.aread()
flow = self.auth_flow(request)
request = next(flow)
while True:
response = yield request
if self.requires_response_body:
await response.aread()
try:
request = flow.send(response)
except StopIteration:
break
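
A sketch of a custom scheme built on the flow protocol described above. The `TokenAuth` class and `fetch_token` helper are hypothetical, purely for illustration:

```
import httpx

def fetch_token() -> str:
    return "example-token"  # hypothetical token source

class TokenAuth(httpx.Auth):
    """Attach a bearer token, refreshing it once if the server returns a 401."""

    def __init__(self) -> None:
        self.token = fetch_token()

    def auth_flow(self, request):
        request.headers["Authorization"] = f"Bearer {self.token}"
        response = yield request
        if response.status_code == 401:
            # The client sends the 401 response back into the flow,
            # so we can retry with a fresh token.
            self.token = fetch_token()
            request.headers["Authorization"] = f"Bearer {self.token}"
            yield request
```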
class FunctionAuth(Auth):
"""
Allows the 'auth' argument to be passed as a simple callable function
that takes the request and returns a new, modified request.
"""
def __init__(self, func: typing.Callable[[Request], Request]) -> None:
self._func = func
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
yield self._func(request)
class BasicAuth(Auth):
"""
Allows the 'auth' argument to be passed as a (username, password) pair,
and uses HTTP Basic authentication.
"""
def __init__(self, username: str | bytes, password: str | bytes) -> None:
self._auth_header = self._build_auth_header(username, password)
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
request.headers["Authorization"] = self._auth_header
yield request
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
userpass = b":".join((to_bytes(username), to_bytes(password)))
token = b64encode(userpass).decode()
return f"Basic {token}"
class NetRCAuth(Auth):
"""
Use a 'netrc' file to look up basic auth credentials based on the URL host.
"""
def __init__(self, file: str | None = None) -> None:
# Lazily import 'netrc'.
# There's no need for us to load this module unless 'NetRCAuth' is being used.
import netrc
self._netrc_info = netrc.netrc(file)
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
auth_info = self._netrc_info.authenticators(request.url.host)
if auth_info is None or not auth_info[2]:
# The netrc file did not have authentication credentials for this host.
yield request
else:
# Build a basic auth header with credentials from the netrc file.
request.headers["Authorization"] = self._build_auth_header(
username=auth_info[0], password=auth_info[2]
)
yield request
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
userpass = b":".join((to_bytes(username), to_bytes(password)))
token = b64encode(userpass).decode()
return f"Basic {token}"
class DigestAuth(Auth):
_ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
"SHA": hashlib.sha1,
"SHA-SESS": hashlib.sha1,
"SHA-256": hashlib.sha256,
"SHA-256-SESS": hashlib.sha256,
"SHA-512": hashlib.sha512,
"SHA-512-SESS": hashlib.sha512,
}
def __init__(self, username: str | bytes, password: str | bytes) -> None:
self._username = to_bytes(username)
self._password = to_bytes(password)
self._last_challenge: _DigestAuthChallenge | None = None
self._nonce_count = 1
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
if self._last_challenge:
request.headers["Authorization"] = self._build_auth_header(
request, self._last_challenge
)
response = yield request
if response.status_code != 401 or "www-authenticate" not in response.headers:
# If the response is not a 401 then we don't
# need to build an authenticated request.
return
for auth_header in response.headers.get_list("www-authenticate"):
if auth_header.lower().startswith("digest "):
break
else:
# If the response does not include a 'WWW-Authenticate: Digest ...'
# header, then we don't need to build an authenticated request.
return
self._last_challenge = self._parse_challenge(request, response, auth_header)
self._nonce_count = 1
request.headers["Authorization"] = self._build_auth_header(
request, self._last_challenge
)
if response.cookies:
Cookies(response.cookies).set_cookie_header(request=request)
yield request
def _parse_challenge(
self, request: Request, response: Response, auth_header: str
) -> _DigestAuthChallenge:
"""
Returns a challenge from a Digest WWW-Authenticate header.
These take the form of:
`Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
"""
scheme, _, fields = auth_header.partition(" ")
# This method should only ever have been called with a Digest auth header.
assert scheme.lower() == "digest"
header_dict: dict[str, str] = {}
for field in parse_http_list(fields):
key, value = field.strip().split("=", 1)
header_dict[key] = unquote(value)
try:
realm = header_dict["realm"].encode()
nonce = header_dict["nonce"].encode()
algorithm = header_dict.get("algorithm", "MD5")
opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
qop = header_dict["qop"].encode() if "qop" in header_dict else None
return _DigestAuthChallenge(
realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop
)
except KeyError as exc:
message = "Malformed Digest WWW-Authenticate header"
raise ProtocolError(message, request=request) from exc
def _build_auth_header(
self, request: Request, challenge: _DigestAuthChallenge
) -> str:
hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]
def digest(data: bytes) -> bytes:
return hash_func(data).hexdigest().encode()
A1 = b":".join((self._username, challenge.realm, self._password))
path = request.url.raw_path
A2 = b":".join((request.method.encode(), path))
# TODO: implement auth-int
HA2 = digest(A2)
nc_value = b"%08x" % self._nonce_count
cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
self._nonce_count += 1
HA1 = digest(A1)
if challenge.algorithm.lower().endswith("-sess"):
HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))
qop = self._resolve_qop(challenge.qop, request=request)
if qop is None:
# Following RFC 2069
digest_data = [HA1, challenge.nonce, HA2]
else:
# Following RFC 2617/7616
digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]
format_args = {
"username": self._username,
"realm": challenge.realm,
"nonce": challenge.nonce,
"uri": path,
"response": digest(b":".join(digest_data)),
"algorithm": challenge.algorithm.encode(),
}
if challenge.opaque:
format_args["opaque"] = challenge.opaque
if qop:
format_args["qop"] = b"auth"
format_args["nc"] = nc_value
format_args["cnonce"] = cnonce
return "Digest " + self._get_header_value(format_args)
def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
s = str(nonce_count).encode()
s += nonce
s += time.ctime().encode()
s += os.urandom(8)
return hashlib.sha1(s).hexdigest()[:16].encode()
def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
QUOTED_TEMPLATE = '{}="{}"'
NON_QUOTED_TEMPLATE = "{}={}"
header_value = ""
for i, (field, value) in enumerate(header_fields.items()):
if i > 0:
header_value += ", "
template = (
QUOTED_TEMPLATE
if field not in NON_QUOTED_FIELDS
else NON_QUOTED_TEMPLATE
)
header_value += template.format(field, to_str(value))
return header_value
def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
if qop is None:
return None
qops = re.split(b", ?", qop)
if b"auth" in qops:
return b"auth"
if qops == [b"auth-int"]:
raise NotImplementedError("Digest auth-int support is not yet implemented")
message = f'Unexpected qop value "{qop!r}" in digest auth'
raise ProtocolError(message, request=request)
class _DigestAuthChallenge(typing.NamedTuple):
realm: bytes
nonce: bytes
algorithm: str
opaque: bytes | None
qop: bytes | None
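
A usage sketch for the class above (the httpbin URL is illustrative). The first request is sent without credentials; the challenge in the 401 response is parsed and used to build the `Authorization` header for the retry:

```
import httpx

auth = httpx.DigestAuth(username="user", password="pass")
response = httpx.get("https://httpbin.org/digest-auth/auth/user/pass", auth=auth)
assert response.status_code == 200
```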

File diff suppressed because it is too large (httpx/_client.py).

View File: httpx/_config.py

@@ -0,0 +1,248 @@
from __future__ import annotations
import os
import typing
from ._models import Headers
from ._types import CertTypes, HeaderTypes, TimeoutTypes
from ._urls import URL
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
__all__ = ["Limits", "Proxy", "Timeout", "create_ssl_context"]
class UnsetType:
pass # pragma: no cover
UNSET = UnsetType()
def create_ssl_context(
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
) -> ssl.SSLContext:
import ssl
import warnings
import certifi
if verify is True:
if trust_env and os.environ.get("SSL_CERT_FILE"): # pragma: nocover
ctx = ssl.create_default_context(cafile=os.environ["SSL_CERT_FILE"])
elif trust_env and os.environ.get("SSL_CERT_DIR"): # pragma: nocover
ctx = ssl.create_default_context(capath=os.environ["SSL_CERT_DIR"])
else:
# Default case...
ctx = ssl.create_default_context(cafile=certifi.where())
elif verify is False:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
elif isinstance(verify, str): # pragma: nocover
message = (
"`verify=<str>` is deprecated. "
"Use `verify=ssl.create_default_context(cafile=...)` "
"or `verify=ssl.create_default_context(capath=...)` instead."
)
warnings.warn(message, DeprecationWarning)
if os.path.isdir(verify):
return ssl.create_default_context(capath=verify)
return ssl.create_default_context(cafile=verify)
else:
ctx = verify
if cert: # pragma: nocover
message = (
"`cert=...` is deprecated. Use `verify=<ssl_context>` instead,"
"with `.load_cert_chain()` to configure the certificate chain."
)
warnings.warn(message, DeprecationWarning)
if isinstance(cert, str):
ctx.load_cert_chain(cert)
else:
ctx.load_cert_chain(*cert)
return ctx
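
As the deprecation warnings above indicate, the supported pattern is to construct the `ssl.SSLContext` yourself and pass it via `verify=`. A minimal sketch, with placeholder file paths:

```
import ssl
import httpx

ctx = ssl.create_default_context(cafile="/path/to/ca-bundle.pem")  # placeholder path
ctx.load_cert_chain(certfile="/path/to/client.pem")  # placeholder path
client = httpx.Client(verify=ctx)
```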
class Timeout:
"""
Timeout configuration.
**Usage**:
Timeout(None) # No timeouts.
Timeout(5.0) # 5s timeout on all operations.
Timeout(None, connect=5.0) # 5s timeout on connect, no other timeouts.
Timeout(5.0, connect=10.0) # 10s timeout on connect. 5s timeout elsewhere.
Timeout(5.0, pool=None) # No timeout on acquiring connection from pool.
# 5s timeout elsewhere.
"""
def __init__(
self,
timeout: TimeoutTypes | UnsetType = UNSET,
*,
connect: None | float | UnsetType = UNSET,
read: None | float | UnsetType = UNSET,
write: None | float | UnsetType = UNSET,
pool: None | float | UnsetType = UNSET,
) -> None:
if isinstance(timeout, Timeout):
# Passed as a single explicit Timeout.
assert connect is UNSET
assert read is UNSET
assert write is UNSET
assert pool is UNSET
self.connect = timeout.connect # type: typing.Optional[float]
self.read = timeout.read # type: typing.Optional[float]
self.write = timeout.write # type: typing.Optional[float]
self.pool = timeout.pool # type: typing.Optional[float]
elif isinstance(timeout, tuple):
# Passed as a tuple.
self.connect = timeout[0]
self.read = timeout[1]
self.write = None if len(timeout) < 3 else timeout[2]
self.pool = None if len(timeout) < 4 else timeout[3]
elif not (
isinstance(connect, UnsetType)
or isinstance(read, UnsetType)
or isinstance(write, UnsetType)
or isinstance(pool, UnsetType)
):
self.connect = connect
self.read = read
self.write = write
self.pool = pool
else:
if isinstance(timeout, UnsetType):
raise ValueError(
"httpx.Timeout must either include a default, or set all "
"four parameters explicitly."
)
self.connect = timeout if isinstance(connect, UnsetType) else connect
self.read = timeout if isinstance(read, UnsetType) else read
self.write = timeout if isinstance(write, UnsetType) else write
self.pool = timeout if isinstance(pool, UnsetType) else pool
def as_dict(self) -> dict[str, float | None]:
return {
"connect": self.connect,
"read": self.read,
"write": self.write,
"pool": self.pool,
}
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.connect == other.connect
and self.read == other.read
and self.write == other.write
and self.pool == other.pool
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
if len({self.connect, self.read, self.write, self.pool}) == 1:
return f"{class_name}(timeout={self.connect})"
return (
f"{class_name}(connect={self.connect}, "
f"read={self.read}, write={self.write}, pool={self.pool})"
)
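
A runnable sketch of the docstring patterns above, applied to a client:

```
import httpx

# 10s to establish a connection; 5s for reads, writes, and pool acquisition.
timeout = httpx.Timeout(5.0, connect=10.0)
client = httpx.Client(timeout=timeout)
```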
class Limits:
"""
Configuration for limits to various client behaviors.
**Parameters:**
* **max_connections** - The maximum number of concurrent connections that may be
established.
* **max_keepalive_connections** - Allow the connection pool to maintain
keep-alive connections below this point. Should be less than or equal
to `max_connections`.
* **keepalive_expiry** - Time limit on idle keep-alive connections in seconds.
"""
def __init__(
self,
*,
max_connections: int | None = None,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = 5.0,
) -> None:
self.max_connections = max_connections
self.max_keepalive_connections = max_keepalive_connections
self.keepalive_expiry = keepalive_expiry
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.max_connections == other.max_connections
and self.max_keepalive_connections == other.max_keepalive_connections
and self.keepalive_expiry == other.keepalive_expiry
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
return (
f"{class_name}(max_connections={self.max_connections}, "
f"max_keepalive_connections={self.max_keepalive_connections}, "
f"keepalive_expiry={self.keepalive_expiry})"
)
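
Usage sketch; these values mirror `DEFAULT_LIMITS` at the bottom of this file:

```
import httpx

limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)
client = httpx.Client(limits=limits)
```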
class Proxy:
def __init__(
self,
url: URL | str,
*,
ssl_context: ssl.SSLContext | None = None,
auth: tuple[str, str] | None = None,
headers: HeaderTypes | None = None,
) -> None:
url = URL(url)
headers = Headers(headers)
if url.scheme not in ("http", "https", "socks5", "socks5h"):
raise ValueError(f"Unknown scheme for proxy URL {url!r}")
if url.username or url.password:
# Remove any auth credentials from the URL.
auth = (url.username, url.password)
url = url.copy_with(username=None, password=None)
self.url = url
self.auth = auth
self.headers = headers
self.ssl_context = ssl_context
@property
def raw_auth(self) -> tuple[bytes, bytes] | None:
# The proxy authentication as raw bytes.
return (
None
if self.auth is None
else (self.auth[0].encode("utf-8"), self.auth[1].encode("utf-8"))
)
def __repr__(self) -> str:
# The authentication is represented with the password component masked.
auth = (self.auth[0], "********") if self.auth else None
# Build a nice concise representation.
url_str = f"{str(self.url)!r}"
auth_str = f", auth={auth!r}" if auth else ""
headers_str = f", headers={dict(self.headers)!r}" if self.headers else ""
return f"Proxy({url_str}{auth_str}{headers_str})"
DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0)
DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20)
DEFAULT_MAX_REDIRECTS = 20

View File: httpx/_content.py

@@ -0,0 +1,240 @@
from __future__ import annotations
import inspect
import warnings
from json import dumps as json_dumps
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Iterable,
Iterator,
Mapping,
)
from urllib.parse import urlencode
from ._exceptions import StreamClosed, StreamConsumed
from ._multipart import MultipartStream
from ._types import (
AsyncByteStream,
RequestContent,
RequestData,
RequestFiles,
ResponseContent,
SyncByteStream,
)
from ._utils import peek_filelike_length, primitive_value_to_str
__all__ = ["ByteStream"]
class ByteStream(AsyncByteStream, SyncByteStream):
def __init__(self, stream: bytes) -> None:
self._stream = stream
def __iter__(self) -> Iterator[bytes]:
yield self._stream
async def __aiter__(self) -> AsyncIterator[bytes]:
yield self._stream
class IteratorByteStream(SyncByteStream):
CHUNK_SIZE = 65_536
def __init__(self, stream: Iterable[bytes]) -> None:
self._stream = stream
self._is_stream_consumed = False
self._is_generator = inspect.isgenerator(stream)
def __iter__(self) -> Iterator[bytes]:
if self._is_stream_consumed and self._is_generator:
raise StreamConsumed()
self._is_stream_consumed = True
if hasattr(self._stream, "read"):
# File-like interfaces should use 'read' directly.
chunk = self._stream.read(self.CHUNK_SIZE)
while chunk:
yield chunk
chunk = self._stream.read(self.CHUNK_SIZE)
else:
# Otherwise iterate.
for part in self._stream:
yield part
class AsyncIteratorByteStream(AsyncByteStream):
CHUNK_SIZE = 65_536
def __init__(self, stream: AsyncIterable[bytes]) -> None:
self._stream = stream
self._is_stream_consumed = False
self._is_generator = inspect.isasyncgen(stream)
async def __aiter__(self) -> AsyncIterator[bytes]:
if self._is_stream_consumed and self._is_generator:
raise StreamConsumed()
self._is_stream_consumed = True
if hasattr(self._stream, "aread"):
# File-like interfaces should use 'aread' directly.
chunk = await self._stream.aread(self.CHUNK_SIZE)
while chunk:
yield chunk
chunk = await self._stream.aread(self.CHUNK_SIZE)
else:
# Otherwise iterate.
async for part in self._stream:
yield part
class UnattachedStream(AsyncByteStream, SyncByteStream):
"""
If a request or response is serialized using pickle, then it is no longer
attached to a stream for I/O purposes. Any stream operations should result
in `httpx.StreamClosed`.
"""
def __iter__(self) -> Iterator[bytes]:
raise StreamClosed()
async def __aiter__(self) -> AsyncIterator[bytes]:
raise StreamClosed()
yield b"" # pragma: no cover
def encode_content(
content: str | bytes | Iterable[bytes] | AsyncIterable[bytes],
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
if isinstance(content, (bytes, str)):
body = content.encode("utf-8") if isinstance(content, str) else content
content_length = len(body)
headers = {"Content-Length": str(content_length)} if body else {}
return headers, ByteStream(body)
elif isinstance(content, Iterable) and not isinstance(content, dict):
# `not isinstance(content, dict)` is a bit oddly specific, but it
# catches a case that's easy for users to make in error, and would
# otherwise pass through here, like any other bytes-iterable,
# because `dict` happens to be iterable. See issue #2491.
content_length_or_none = peek_filelike_length(content)
if content_length_or_none is None:
headers = {"Transfer-Encoding": "chunked"}
else:
headers = {"Content-Length": str(content_length_or_none)}
return headers, IteratorByteStream(content) # type: ignore
elif isinstance(content, AsyncIterable):
headers = {"Transfer-Encoding": "chunked"}
return headers, AsyncIteratorByteStream(content)
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
def encode_urlencoded_data(
data: RequestData,
) -> tuple[dict[str, str], ByteStream]:
plain_data = []
for key, value in data.items():
if isinstance(value, (list, tuple)):
plain_data.extend([(key, primitive_value_to_str(item)) for item in value])
else:
plain_data.append((key, primitive_value_to_str(value)))
body = urlencode(plain_data, doseq=True).encode("utf-8")
content_length = str(len(body))
content_type = "application/x-www-form-urlencoded"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, ByteStream(body)
def encode_multipart_data(
data: RequestData, files: RequestFiles, boundary: bytes | None
) -> tuple[dict[str, str], MultipartStream]:
multipart = MultipartStream(data=data, files=files, boundary=boundary)
headers = multipart.get_headers()
return headers, multipart
def encode_text(text: str) -> tuple[dict[str, str], ByteStream]:
body = text.encode("utf-8")
content_length = str(len(body))
content_type = "text/plain; charset=utf-8"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, ByteStream(body)
def encode_html(html: str) -> tuple[dict[str, str], ByteStream]:
body = html.encode("utf-8")
content_length = str(len(body))
content_type = "text/html; charset=utf-8"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, ByteStream(body)
def encode_json(json: Any) -> tuple[dict[str, str], ByteStream]:
body = json_dumps(
json, ensure_ascii=False, separators=(",", ":"), allow_nan=False
).encode("utf-8")
content_length = str(len(body))
content_type = "application/json"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, ByteStream(body)
def encode_request(
content: RequestContent | None = None,
data: RequestData | None = None,
files: RequestFiles | None = None,
json: Any | None = None,
boundary: bytes | None = None,
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
returning a two-tuple of (<headers>, <stream>).
"""
if data is not None and not isinstance(data, Mapping):
# We prefer to separate `content=<bytes|str|byte iterator|bytes aiterator>`
# for raw request content, and `data=<form data>` for url encoded or
# multipart form content.
#
# However for compat with requests, we *do* still support
# `data=<bytes...>` usages. We deal with that case here, treating it
# as if `content=<...>` had been supplied instead.
message = "Use 'content=<...>' to upload raw bytes/text content."
warnings.warn(message, DeprecationWarning, stacklevel=2)
return encode_content(data)
if content is not None:
return encode_content(content)
elif files:
return encode_multipart_data(data or {}, files, boundary)
elif data:
return encode_urlencoded_data(data)
elif json is not None:
return encode_json(json)
return {}, ByteStream(b"")
def encode_response(
content: ResponseContent | None = None,
text: str | None = None,
html: str | None = None,
json: Any | None = None,
) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]:
"""
Handles encoding the given `content`, returning a two-tuple of
(<headers>, <stream>).
"""
if content is not None:
return encode_content(content)
elif text is not None:
return encode_text(text)
elif html is not None:
return encode_html(html)
elif json is not None:
return encode_json(json)
return {}, ByteStream(b"")

View File: httpx/_decoders.py

@@ -0,0 +1,393 @@
"""
Handlers for Content-Encoding.
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
"""
from __future__ import annotations
import codecs
import io
import typing
import zlib
from ._exceptions import DecodingError
# Brotli support is optional
try:
# The C bindings in `brotli` are recommended for CPython.
import brotli
except ImportError: # pragma: no cover
try:
# The CFFI bindings in `brotlicffi` are recommended for PyPy
# and other environments.
import brotlicffi as brotli
except ImportError:
brotli = None
# Zstandard support is optional
try:
import zstandard
except ImportError: # pragma: no cover
zstandard = None # type: ignore
class ContentDecoder:
def decode(self, data: bytes) -> bytes:
raise NotImplementedError() # pragma: no cover
def flush(self) -> bytes:
raise NotImplementedError() # pragma: no cover
class IdentityDecoder(ContentDecoder):
"""
Handle unencoded data.
"""
def decode(self, data: bytes) -> bytes:
return data
def flush(self) -> bytes:
return b""
class DeflateDecoder(ContentDecoder):
"""
Handle 'deflate' decoding.
See: https://stackoverflow.com/questions/1838699
"""
def __init__(self) -> None:
self.first_attempt = True
self.decompressor = zlib.decompressobj()
def decode(self, data: bytes) -> bytes:
was_first_attempt = self.first_attempt
self.first_attempt = False
try:
return self.decompressor.decompress(data)
except zlib.error as exc:
if was_first_attempt:
self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
return self.decode(data)
raise DecodingError(str(exc)) from exc
def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: no cover
raise DecodingError(str(exc)) from exc
class GZipDecoder(ContentDecoder):
"""
Handle 'gzip' decoding.
See: https://stackoverflow.com/questions/1838699
"""
def __init__(self) -> None:
self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)
def decode(self, data: bytes) -> bytes:
try:
return self.decompressor.decompress(data)
except zlib.error as exc:
raise DecodingError(str(exc)) from exc
def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: no cover
raise DecodingError(str(exc)) from exc
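
A quick self-check sketch for the decoder above, using zlib's gzip container mode (`wbits=MAX_WBITS | 16`) to produce input:

```
import zlib

compressor = zlib.compressobj(wbits=zlib.MAX_WBITS | 16)
payload = compressor.compress(b"hello, world") + compressor.flush()

decoder = GZipDecoder()
assert decoder.decode(payload) + decoder.flush() == b"hello, world"
```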
class BrotliDecoder(ContentDecoder):
"""
Handle 'brotli' decoding.
Requires `pip install brotlipy` (see https://brotlipy.readthedocs.io/)
or `pip install brotli` (see https://github.com/google/brotli).
Supports both the 'brotlipy' and 'Brotli' packages, since they share an
import name. The top branches handle 'brotlipy'; the bottom branches handle 'Brotli'.
"""
def __init__(self) -> None:
if brotli is None: # pragma: no cover
raise ImportError(
"Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' "
"packages have been installed. "
"Make sure to install httpx using `pip install httpx[brotli]`."
) from None
self.decompressor = brotli.Decompressor()
self.seen_data = False
self._decompress: typing.Callable[[bytes], bytes]
if hasattr(self.decompressor, "decompress"):
# The 'brotlicffi' package.
self._decompress = self.decompressor.decompress # pragma: no cover
else:
# The 'brotli' package.
self._decompress = self.decompressor.process # pragma: no cover
def decode(self, data: bytes) -> bytes:
if not data:
return b""
self.seen_data = True
try:
return self._decompress(data)
except brotli.error as exc:
raise DecodingError(str(exc)) from exc
def flush(self) -> bytes:
if not self.seen_data:
return b""
try:
if hasattr(self.decompressor, "finish"):
# Only available in the 'brotlicffi' package.
# As the decompressor decompresses eagerly, this
# will never actually emit any data. However, it will potentially throw
# errors if a truncated or damaged data stream has been used.
self.decompressor.finish() # pragma: no cover
return b""
except brotli.error as exc: # pragma: no cover
raise DecodingError(str(exc)) from exc
class ZStandardDecoder(ContentDecoder):
"""
Handle 'zstd' RFC 8878 decoding.
Requires `pip install zstandard`.
Can be installed as a dependency of httpx using `pip install httpx[zstd]`.
"""
# inspired by the ZstdDecoder implementation in urllib3
def __init__(self) -> None:
if zstandard is None: # pragma: no cover
raise ImportError(
"Using 'ZStandardDecoder', ..."
"Make sure to install httpx using `pip install httpx[zstd]`."
) from None
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
self.seen_data = False
def decode(self, data: bytes) -> bytes:
assert zstandard is not None
self.seen_data = True
output = io.BytesIO()
try:
output.write(self.decompressor.decompress(data))
while self.decompressor.eof and self.decompressor.unused_data:
unused_data = self.decompressor.unused_data
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
output.write(self.decompressor.decompress(unused_data))
except zstandard.ZstdError as exc:
raise DecodingError(str(exc)) from exc
return output.getvalue()
def flush(self) -> bytes:
if not self.seen_data:
return b""
ret = self.decompressor.flush() # note: this is a no-op
if not self.decompressor.eof:
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
return bytes(ret)
class MultiDecoder(ContentDecoder):
"""
Handle the case where multiple encodings have been applied.
"""
def __init__(self, children: typing.Sequence[ContentDecoder]) -> None:
"""
'children' should be a sequence of decoders in the order in which
each was applied.
"""
# Note that we reverse the order for decoding.
self.children = list(reversed(children))
def decode(self, data: bytes) -> bytes:
for child in self.children:
data = child.decode(data)
return data
def flush(self) -> bytes:
data = b""
for child in self.children:
data = child.decode(data) + child.flush()
return data
class ByteChunker:
"""
Handles returning byte content in fixed-size chunks.
"""
def __init__(self, chunk_size: int | None = None) -> None:
self._buffer = io.BytesIO()
self._chunk_size = chunk_size
def decode(self, content: bytes) -> list[bytes]:
if self._chunk_size is None:
return [content] if content else []
self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
value = self._buffer.getvalue()
chunks = [
value[i : i + self._chunk_size]
for i in range(0, len(value), self._chunk_size)
]
if len(chunks[-1]) == self._chunk_size:
self._buffer.seek(0)
self._buffer.truncate()
return chunks
else:
self._buffer.seek(0)
self._buffer.write(chunks[-1])
self._buffer.truncate()
return chunks[:-1]
else:
return []
def flush(self) -> list[bytes]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
return [value] if value else []
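
Behaviour sketch: full chunks are emitted eagerly, and any remainder stays buffered until `flush()`:

```
chunker = ByteChunker(chunk_size=4)
assert chunker.decode(b"abcdefghij") == [b"abcd", b"efgh"]  # b"ij" stays buffered
assert chunker.flush() == [b"ij"]
```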
class TextChunker:
"""
Handles returning text content in fixed-size chunks.
"""
def __init__(self, chunk_size: int | None = None) -> None:
self._buffer = io.StringIO()
self._chunk_size = chunk_size
def decode(self, content: str) -> list[str]:
if self._chunk_size is None:
return [content] if content else []
self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
value = self._buffer.getvalue()
chunks = [
value[i : i + self._chunk_size]
for i in range(0, len(value), self._chunk_size)
]
if len(chunks[-1]) == self._chunk_size:
self._buffer.seek(0)
self._buffer.truncate()
return chunks
else:
self._buffer.seek(0)
self._buffer.write(chunks[-1])
self._buffer.truncate()
return chunks[:-1]
else:
return []
def flush(self) -> list[str]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
return [value] if value else []
class TextDecoder:
"""
Handles incrementally decoding bytes into text.
"""
def __init__(self, encoding: str = "utf-8") -> None:
self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
def decode(self, data: bytes) -> str:
return self.decoder.decode(data)
def flush(self) -> str:
return self.decoder.decode(b"", True)
class LineDecoder:
"""
Handles incrementally reading lines from text.
Has the same behaviour as the stdlib `str.splitlines()`,
but handles the input iteratively.
"""
def __init__(self) -> None:
self.buffer: list[str] = []
self.trailing_cr: bool = False
def decode(self, text: str) -> list[str]:
# See https://docs.python.org/3/library/stdtypes.html#str.splitlines
NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029"
# We always push a trailing `\r` into the next decode iteration.
if self.trailing_cr:
text = "\r" + text
self.trailing_cr = False
if text.endswith("\r"):
self.trailing_cr = True
text = text[:-1]
if not text:
# NOTE: the edge case input of empty text doesn't occur in practice,
# because other httpx internals filter out this value
return [] # pragma: no cover
trailing_newline = text[-1] in NEWLINE_CHARS
lines = text.splitlines()
if len(lines) == 1 and not trailing_newline:
# No new lines, buffer the input and continue.
self.buffer.append(lines[0])
return []
if self.buffer:
# Include any existing buffer in the first portion of the
# splitlines result.
lines = ["".join(self.buffer) + lines[0]] + lines[1:]
self.buffer = []
if not trailing_newline:
# If the last segment of splitlines is not newline terminated,
# then drop it from our output and start a new buffer.
self.buffer = [lines.pop()]
return lines
def flush(self) -> list[str]:
if not self.buffer and not self.trailing_cr:
return []
lines = ["".join(self.buffer)]
self.buffer = []
self.trailing_cr = False
return lines
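
Behaviour sketch: partial lines are buffered across `decode()` calls, and a trailing `\r` is held back in case the next chunk begins with `\n`:

```
decoder = LineDecoder()
lines: list[str] = []
for chunk in ["first li", "ne\nsecond", " line\r\n", "tail"]:
    lines.extend(decoder.decode(chunk))
lines.extend(decoder.flush())
assert lines == ["first line", "second line", "tail"]
```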
SUPPORTED_DECODERS = {
"identity": IdentityDecoder,
"gzip": GZipDecoder,
"deflate": DeflateDecoder,
"br": BrotliDecoder,
"zstd": ZStandardDecoder,
}
if brotli is None:
SUPPORTED_DECODERS.pop("br") # pragma: no cover
if zstandard is None:
SUPPORTED_DECODERS.pop("zstd") # pragma: no cover

View File: httpx/_exceptions.py

@@ -0,0 +1,379 @@
"""
Our exception hierarchy:
* HTTPError
  x RequestError
    + TransportError
      - TimeoutException
        · ConnectTimeout
        · ReadTimeout
        · WriteTimeout
        · PoolTimeout
      - NetworkError
        · ConnectError
        · ReadError
        · WriteError
        · CloseError
      - ProtocolError
        · LocalProtocolError
        · RemoteProtocolError
      - ProxyError
      - UnsupportedProtocol
    + DecodingError
    + TooManyRedirects
  x HTTPStatusError
* InvalidURL
* CookieConflict
* StreamError
  x StreamConsumed
  x StreamClosed
  x ResponseNotRead
  x RequestNotRead
"""
from __future__ import annotations
import contextlib
import typing
if typing.TYPE_CHECKING:
from ._models import Request, Response # pragma: no cover
__all__ = [
"CloseError",
"ConnectError",
"ConnectTimeout",
"CookieConflict",
"DecodingError",
"HTTPError",
"HTTPStatusError",
"InvalidURL",
"LocalProtocolError",
"NetworkError",
"PoolTimeout",
"ProtocolError",
"ProxyError",
"ReadError",
"ReadTimeout",
"RemoteProtocolError",
"RequestError",
"RequestNotRead",
"ResponseNotRead",
"StreamClosed",
"StreamConsumed",
"StreamError",
"TimeoutException",
"TooManyRedirects",
"TransportError",
"UnsupportedProtocol",
"WriteError",
"WriteTimeout",
]
class HTTPError(Exception):
"""
Base class for `RequestError` and `HTTPStatusError`.
Useful for `try...except` blocks when issuing a request,
and then calling `.raise_for_status()`.
For example:
```
try:
response = httpx.get("https://www.example.com")
response.raise_for_status()
except httpx.HTTPError as exc:
print(f"HTTP Exception for {exc.request.url} - {exc}")
```
"""
def __init__(self, message: str) -> None:
super().__init__(message)
self._request: Request | None = None
@property
def request(self) -> Request:
if self._request is None:
raise RuntimeError("The .request property has not been set.")
return self._request
@request.setter
def request(self, request: Request) -> None:
self._request = request
class RequestError(HTTPError):
"""
Base class for all exceptions that may occur when issuing a `.request()`.
"""
def __init__(self, message: str, *, request: Request | None = None) -> None:
super().__init__(message)
# At the point an exception is raised we won't typically have a request
# instance to associate it with.
#
# The 'request_context' context manager is used within the Client and
# Response methods in order to ensure that any raised exceptions
# have a `.request` property set on them.
self._request = request
class TransportError(RequestError):
"""
Base class for all exceptions that occur at the level of the Transport API.
"""
# Timeout exceptions...
class TimeoutException(TransportError):
"""
The base class for timeout errors.
An operation has timed out.
"""
class ConnectTimeout(TimeoutException):
"""
Timed out while connecting to the host.
"""
class ReadTimeout(TimeoutException):
"""
Timed out while receiving data from the host.
"""
class WriteTimeout(TimeoutException):
"""
Timed out while sending data to the host.
"""
class PoolTimeout(TimeoutException):
"""
Timed out waiting to acquire a connection from the pool.
"""
# Core networking exceptions...
class NetworkError(TransportError):
"""
The base class for network-related errors.
An error occurred while interacting with the network.
"""
class ReadError(NetworkError):
"""
Failed to receive data from the network.
"""
class WriteError(NetworkError):
"""
Failed to send data through the network.
"""
class ConnectError(NetworkError):
"""
Failed to establish a connection.
"""
class CloseError(NetworkError):
"""
Failed to close a connection.
"""
# Other transport exceptions...
class ProxyError(TransportError):
"""
An error occurred while establishing a proxy connection.
"""
class UnsupportedProtocol(TransportError):
"""
Attempted to make a request to an unsupported protocol.
For example issuing a request to `ftp://www.example.com`.
"""
class ProtocolError(TransportError):
"""
The protocol was violated.
"""
class LocalProtocolError(ProtocolError):
"""
A protocol was violated by the client.
For example if the user instantiated a `Request` instance explicitly,
failed to include the mandatory `Host:` header, and then issued it directly
using `client.send()`.
"""
class RemoteProtocolError(ProtocolError):
"""
The protocol was violated by the server.
For example, returning malformed HTTP.
"""
# Other request exceptions...
class DecodingError(RequestError):
"""
Decoding of the response failed, due to a malformed encoding.
"""
class TooManyRedirects(RequestError):
"""
Too many redirects.
"""
# Client errors
class HTTPStatusError(HTTPError):
"""
The response had an error HTTP status of 4xx or 5xx.
May be raised when calling `response.raise_for_status()`
"""
def __init__(self, message: str, *, request: Request, response: Response) -> None:
super().__init__(message)
self.request = request
self.response = response
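
A sketch of handling the two branches of the hierarchy separately (the URL is illustrative):

```
import httpx

try:
    response = httpx.get("https://www.example.com/")
    response.raise_for_status()
except httpx.RequestError as exc:
    print(f"An error occurred while requesting {exc.request.url!r}.")
except httpx.HTTPStatusError as exc:
    print(f"Error response {exc.response.status_code} while requesting {exc.request.url!r}.")
```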
class InvalidURL(Exception):
"""
URL is improperly formed or cannot be parsed.
"""
def __init__(self, message: str) -> None:
super().__init__(message)
class CookieConflict(Exception):
"""
Attempted to lookup a cookie by name, but multiple cookies existed.
Can occur when calling `response.cookies.get(...)`.
"""
def __init__(self, message: str) -> None:
super().__init__(message)
# Stream exceptions...
# These may occur as the result of a programming error, by accessing
# the request/response stream in an invalid manner.
class StreamError(RuntimeError):
"""
The base class for stream exceptions.
The developer made an error in accessing the request stream in
an invalid way.
"""
def __init__(self, message: str) -> None:
super().__init__(message)
class StreamConsumed(StreamError):
"""
Attempted to read or stream content, but the content has already
been streamed.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream some content, but the content has "
"already been streamed. For requests, this could be due to passing "
"a generator as request content, and then receiving a redirect "
"response or a secondary request as part of an authentication flow."
"For responses, this could be due to attempting to stream the response "
"content more than once."
)
super().__init__(message)
class StreamClosed(StreamError):
"""
Attempted to read or stream response content, but the request has been
closed.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream content, but the stream has " "been closed."
)
super().__init__(message)
class ResponseNotRead(StreamError):
"""
Attempted to access streaming response content, without having called `read()`.
"""
def __init__(self) -> None:
message = (
"Attempted to access streaming response content,"
" without having called `read()`."
)
super().__init__(message)
class RequestNotRead(StreamError):
"""
Attempted to access streaming request content, without having called `read()`.
"""
def __init__(self) -> None:
message = (
"Attempted to access streaming request content,"
" without having called `read()`."
)
super().__init__(message)
@contextlib.contextmanager
def request_context(
request: Request | None = None,
) -> typing.Iterator[None]:
"""
A context manager that can be used to attach the given request context
to any `RequestError` exceptions that are raised within the block.
"""
try:
yield
except RequestError as exc:
if request is not None:
exc.request = request
raise exc

View File: httpx/_main.py

@@ -0,0 +1,506 @@
from __future__ import annotations
import functools
import json
import sys
import typing
import click
import pygments.lexers
import pygments.util
import rich.console
import rich.markup
import rich.progress
import rich.syntax
import rich.table
from ._client import Client
from ._exceptions import RequestError
from ._models import Response
from ._status_codes import codes
if typing.TYPE_CHECKING:
import httpcore # pragma: no cover
def print_help() -> None:
console = rich.console.Console()
console.print("[bold]HTTPX :butterfly:", justify="center")
console.print()
console.print("A next generation HTTP client.", justify="center")
console.print()
console.print(
"Usage: [bold]httpx[/bold] [cyan]<URL> [OPTIONS][/cyan] ", justify="left"
)
console.print()
table = rich.table.Table.grid(padding=1, pad_edge=True)
table.add_column("Parameter", no_wrap=True, justify="left", style="bold")
table.add_column("Description")
table.add_row(
"-m, --method [cyan]METHOD",
"Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD.\n"
"[Default: GET, or POST if a request body is included]",
)
table.add_row(
"-p, --params [cyan]<NAME VALUE> ...",
"Query parameters to include in the request URL.",
)
table.add_row(
"-c, --content [cyan]TEXT", "Byte content to include in the request body."
)
table.add_row(
"-d, --data [cyan]<NAME VALUE> ...", "Form data to include in the request body."
)
table.add_row(
"-f, --files [cyan]<NAME FILENAME> ...",
"Form files to include in the request body.",
)
table.add_row("-j, --json [cyan]TEXT", "JSON data to include in the request body.")
table.add_row(
"-h, --headers [cyan]<NAME VALUE> ...",
"Include additional HTTP headers in the request.",
)
table.add_row(
"--cookies [cyan]<NAME VALUE> ...", "Cookies to include in the request."
)
table.add_row(
"--auth [cyan]<USER PASS>",
"Username and password to include in the request. Specify '-' for the password"
" to use a password prompt. Note that using --verbose/-v will expose"
" the Authorization header, including the password encoding"
" in a trivially reversible format.",
)
table.add_row(
"--proxy [cyan]URL",
"Send the request via a proxy. Should be the URL giving the proxy address.",
)
table.add_row(
"--timeout [cyan]FLOAT",
"Timeout value to use for network operations, such as establishing the"
" connection, reading some data, etc... [Default: 5.0]",
)
table.add_row("--follow-redirects", "Automatically follow redirects.")
table.add_row("--no-verify", "Disable SSL verification.")
table.add_row(
"--http2", "Send the request using HTTP/2, if the remote server supports it."
)
table.add_row(
"--download [cyan]FILE",
"Save the response content as a file, rather than displaying it.",
)
table.add_row("-v, --verbose", "Verbose output. Show request as well as response.")
table.add_row("--help", "Show this message and exit.")
console.print(table)
def get_lexer_for_response(response: Response) -> str:
content_type = response.headers.get("Content-Type")
if content_type is not None:
mime_type, _, _ = content_type.partition(";")
try:
return typing.cast(
str, pygments.lexers.get_lexer_for_mimetype(mime_type.strip()).name
)
except pygments.util.ClassNotFound: # pragma: no cover
pass
return "" # pragma: no cover
def format_request_headers(request: httpcore.Request, http2: bool = False) -> str:
version = "HTTP/2" if http2 else "HTTP/1.1"
headers = [
(name.lower() if http2 else name, value) for name, value in request.headers
]
method = request.method.decode("ascii")
target = request.url.target.decode("ascii")
lines = [f"{method} {target} {version}"] + [
f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers
]
return "\n".join(lines)
def format_response_headers(
http_version: bytes,
status: int,
reason_phrase: bytes | None,
headers: list[tuple[bytes, bytes]],
) -> str:
version = http_version.decode("ascii")
reason = (
codes.get_reason_phrase(status)
if reason_phrase is None
else reason_phrase.decode("ascii")
)
lines = [f"{version} {status} {reason}"] + [
f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers
]
return "\n".join(lines)
def print_request_headers(request: httpcore.Request, http2: bool = False) -> None:
console = rich.console.Console()
http_text = format_request_headers(request, http2=http2)
syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True)
console.print(syntax)
syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True)
console.print(syntax)
def print_response_headers(
http_version: bytes,
status: int,
reason_phrase: bytes | None,
headers: list[tuple[bytes, bytes]],
) -> None:
console = rich.console.Console()
http_text = format_response_headers(http_version, status, reason_phrase, headers)
syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True)
console.print(syntax)
syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True)
console.print(syntax)
def print_response(response: Response) -> None:
console = rich.console.Console()
lexer_name = get_lexer_for_response(response)
if lexer_name:
if lexer_name.lower() == "json":
try:
data = response.json()
text = json.dumps(data, indent=4)
except ValueError: # pragma: no cover
text = response.text
else:
text = response.text
syntax = rich.syntax.Syntax(text, lexer_name, theme="ansi_dark", word_wrap=True)
console.print(syntax)
else:
console.print(f"<{len(response.content)} bytes of binary data>")
_PCTRTT = typing.Tuple[typing.Tuple[str, str], ...]
_PCTRTTT = typing.Tuple[_PCTRTT, ...]
_PeerCertRetDictType = typing.Dict[str, typing.Union[str, _PCTRTTT, _PCTRTT]]
def format_certificate(cert: _PeerCertRetDictType) -> str: # pragma: no cover
lines = []
for key, value in cert.items():
if isinstance(value, (list, tuple)):
lines.append(f"* {key}:")
for item in value:
if key in ("subject", "issuer"):
for sub_item in item:
lines.append(f"* {sub_item[0]}: {sub_item[1]!r}")
elif isinstance(item, tuple) and len(item) == 2:
lines.append(f"* {item[0]}: {item[1]!r}")
else:
lines.append(f"* {item!r}")
else:
lines.append(f"* {key}: {value!r}")
return "\n".join(lines)
def trace(
name: str, info: typing.Mapping[str, typing.Any], verbose: bool = False
) -> None:
console = rich.console.Console()
if name == "connection.connect_tcp.started" and verbose:
host = info["host"]
console.print(f"* Connecting to {host!r}")
elif name == "connection.connect_tcp.complete" and verbose:
stream = info["return_value"]
server_addr = stream.get_extra_info("server_addr")
console.print(f"* Connected to {server_addr[0]!r} on port {server_addr[1]}")
elif name == "connection.start_tls.complete" and verbose: # pragma: no cover
stream = info["return_value"]
ssl_object = stream.get_extra_info("ssl_object")
version = ssl_object.version()
cipher = ssl_object.cipher()
server_cert = ssl_object.getpeercert()
alpn = ssl_object.selected_alpn_protocol()
console.print(f"* SSL established using {version!r} / {cipher[0]!r}")
console.print(f"* Selected ALPN protocol: {alpn!r}")
if server_cert:
console.print("* Server certificate:")
console.print(format_certificate(server_cert))
elif name == "http11.send_request_headers.started" and verbose:
request = info["request"]
print_request_headers(request, http2=False)
elif name == "http2.send_request_headers.started" and verbose: # pragma: no cover
request = info["request"]
print_request_headers(request, http2=True)
elif name == "http11.receive_response_headers.complete":
http_version, status, reason_phrase, headers = info["return_value"]
print_response_headers(http_version, status, reason_phrase, headers)
elif name == "http2.receive_response_headers.complete": # pragma: no cover
status, headers = info["return_value"]
http_version = b"HTTP/2"
reason_phrase = None
print_response_headers(http_version, status, reason_phrase, headers)
def download_response(response: Response, download: typing.BinaryIO) -> None:
console = rich.console.Console()
console.print()
content_length = response.headers.get("Content-Length")
with rich.progress.Progress(
"[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%",
rich.progress.BarColumn(bar_width=None),
rich.progress.DownloadColumn(),
rich.progress.TransferSpeedColumn(),
) as progress:
description = f"Downloading [bold]{rich.markup.escape(download.name)}"
download_task = progress.add_task(
description,
total=int(content_length or 0),
start=content_length is not None,
)
for chunk in response.iter_bytes():
download.write(chunk)
progress.update(download_task, completed=response.num_bytes_downloaded)
def validate_json(
ctx: click.Context,
param: click.Option | click.Parameter,
value: typing.Any,
) -> typing.Any:
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError: # pragma: no cover
raise click.BadParameter("Not valid JSON")
def validate_auth(
ctx: click.Context,
param: click.Option | click.Parameter,
value: typing.Any,
) -> typing.Any:
if value == (None, None):
return None
username, password = value
if password == "-": # pragma: no cover
password = click.prompt("Password", hide_input=True)
return (username, password)
def handle_help(
ctx: click.Context,
param: click.Option | click.Parameter,
value: typing.Any,
) -> None:
if not value or ctx.resilient_parsing:
return
print_help()
ctx.exit()
@click.command(add_help_option=False)
@click.argument("url", type=str)
@click.option(
"--method",
"-m",
"method",
type=str,
help=(
"Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD. "
"[Default: GET, or POST if a request body is included]"
),
)
@click.option(
"--params",
"-p",
"params",
type=(str, str),
multiple=True,
help="Query parameters to include in the request URL.",
)
@click.option(
"--content",
"-c",
"content",
type=str,
help="Byte content to include in the request body.",
)
@click.option(
"--data",
"-d",
"data",
type=(str, str),
multiple=True,
help="Form data to include in the request body.",
)
@click.option(
"--files",
"-f",
"files",
type=(str, click.File(mode="rb")),
multiple=True,
help="Form files to include in the request body.",
)
@click.option(
"--json",
"-j",
"json",
type=str,
callback=validate_json,
help="JSON data to include in the request body.",
)
@click.option(
"--headers",
"-h",
"headers",
type=(str, str),
multiple=True,
help="Include additional HTTP headers in the request.",
)
@click.option(
"--cookies",
"cookies",
type=(str, str),
multiple=True,
help="Cookies to include in the request.",
)
@click.option(
"--auth",
"auth",
type=(str, str),
default=(None, None),
callback=validate_auth,
help=(
"Username and password to include in the request. "
"Specify '-' for the password to use a password prompt. "
"Note that using --verbose/-v will expose the Authorization header, "
"including the password encoding in a trivially reversible format."
),
)
@click.option(
"--proxy",
"proxy",
type=str,
default=None,
help="Send the request via a proxy. Should be the URL giving the proxy address.",
)
@click.option(
"--timeout",
"timeout",
type=float,
default=5.0,
help=(
"Timeout value to use for network operations, such as establishing the "
"connection, reading some data, etc... [Default: 5.0]"
),
)
@click.option(
"--follow-redirects",
"follow_redirects",
is_flag=True,
default=False,
help="Automatically follow redirects.",
)
@click.option(
"--no-verify",
"verify",
is_flag=True,
default=True,
help="Disable SSL verification.",
)
@click.option(
"--http2",
"http2",
type=bool,
is_flag=True,
default=False,
help="Send the request using HTTP/2, if the remote server supports it.",
)
@click.option(
"--download",
type=click.File("wb"),
help="Save the response content as a file, rather than displaying it.",
)
@click.option(
"--verbose",
"-v",
type=bool,
is_flag=True,
default=False,
help="Verbose. Show request as well as response.",
)
@click.option(
"--help",
is_flag=True,
is_eager=True,
expose_value=False,
callback=handle_help,
help="Show this message and exit.",
)
def main(
url: str,
method: str,
params: list[tuple[str, str]],
content: str,
data: list[tuple[str, str]],
files: list[tuple[str, click.File]],
json: str,
headers: list[tuple[str, str]],
cookies: list[tuple[str, str]],
auth: tuple[str, str] | None,
proxy: str,
timeout: float,
follow_redirects: bool,
verify: bool,
http2: bool,
download: typing.BinaryIO | None,
verbose: bool,
) -> None:
"""
An HTTP command line client.
Sends a request and displays the response.
"""
if not method:
method = "POST" if content or data or files or json else "GET"
try:
with Client(proxy=proxy, timeout=timeout, http2=http2, verify=verify) as client:
with client.stream(
method,
url,
params=list(params),
content=content,
data=dict(data),
files=files, # type: ignore
json=json,
headers=headers,
cookies=dict(cookies),
auth=auth,
follow_redirects=follow_redirects,
extensions={"trace": functools.partial(trace, verbose=verbose)},
) as response:
if download is not None:
download_response(response, download)
else:
response.read()
if response.content:
print_response(response)
except RequestError as exc:
console = rich.console.Console()
console.print(f"[red]{type(exc).__name__}[/red]: {exc}")
sys.exit(1)
sys.exit(0 if response.is_success else 1)
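
The command above can be exercised in-process with click's test runner; a minimal sketch, assuming the `httpx[cli]` extras are installed (the target URL is illustrative, and this still performs a real network request):

```python
from click.testing import CliRunner

from httpx._main import main  # available when installed as 'httpx[cli]'

runner = CliRunner()
result = runner.invoke(main, ["https://www.example.org", "--follow-redirects"])
print(result.exit_code)  # 0 on a successful (2xx) response
print(result.output)
```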

File diff suppressed because it is too large

View File

@ -0,0 +1,300 @@
from __future__ import annotations
import io
import mimetypes
import os
import re
import typing
from pathlib import Path
from ._types import (
AsyncByteStream,
FileContent,
FileTypes,
RequestData,
RequestFiles,
SyncByteStream,
)
from ._utils import (
peek_filelike_length,
primitive_value_to_str,
to_bytes,
)
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
)
_HTML5_FORM_ENCODING_RE = re.compile(
r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)
def _format_form_param(name: str, value: str) -> bytes:
"""
Encode a name/value pair within a multipart form.
"""
def replacer(match: typing.Match[str]) -> str:
return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
return f'{name}="{value}"'.encode()
def _guess_content_type(filename: str | None) -> str | None:
"""
Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
Returns `None` if `filename` is `None` or empty.
"""
if filename:
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
return None
def get_multipart_boundary_from_content_type(
content_type: bytes | None,
) -> bytes | None:
if not content_type or not content_type.startswith(b"multipart/form-data"):
return None
# parse boundary according to
# https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
if b";" in content_type:
for section in content_type.split(b";"):
if section.strip().lower().startswith(b"boundary="):
return section.strip()[len(b"boundary=") :].strip(b'"')
return None
class DataField:
"""
A single form field item, within a multipart form field.
"""
def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
if not isinstance(name, str):
raise TypeError(
f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
)
if value is not None and not isinstance(value, (str, bytes, int, float)):
raise TypeError(
"Invalid type for value. Expected primitive type,"
f" got {type(value)}: {value!r}"
)
self.name = name
self.value: str | bytes = (
value if isinstance(value, bytes) else primitive_value_to_str(value)
)
def render_headers(self) -> bytes:
if not hasattr(self, "_headers"):
name = _format_form_param("name", self.name)
self._headers = b"".join(
[b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
)
return self._headers
def render_data(self) -> bytes:
if not hasattr(self, "_data"):
self._data = to_bytes(self.value)
return self._data
def get_length(self) -> int:
headers = self.render_headers()
data = self.render_data()
return len(headers) + len(data)
def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
yield self.render_data()
class FileField:
"""
A single file field item, within a multipart form field.
"""
CHUNK_SIZE = 64 * 1024
def __init__(self, name: str, value: FileTypes) -> None:
self.name = name
fileobj: FileContent
headers: dict[str, str] = {}
content_type: str | None = None
        # This large tuple-based API largely mirrors requests' API.
        # It would be good to think of better APIs for this that we could
        # include in httpx 2.0, since variable-length tuples (especially of
        # 4 elements) are quite unwieldy.
if isinstance(value, tuple):
if len(value) == 2:
# neither the 3rd parameter (content_type) nor the 4th (headers)
# was included
filename, fileobj = value
elif len(value) == 3:
filename, fileobj, content_type = value
else:
# all 4 parameters included
filename, fileobj, content_type, headers = value # type: ignore
else:
filename = Path(str(getattr(value, "name", "upload"))).name
fileobj = value
if content_type is None:
content_type = _guess_content_type(filename)
has_content_type_header = any("content-type" in key.lower() for key in headers)
if content_type is not None and not has_content_type_header:
            # Note that unlike requests, we ignore the content_type provided in
            # the 3rd tuple element if it is also included in the headers;
            # requests does the opposite (it overwrites the header with the 3rd
            # tuple element).
headers["Content-Type"] = content_type
if isinstance(fileobj, io.StringIO):
raise TypeError(
"Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
)
if isinstance(fileobj, io.TextIOBase):
raise TypeError(
"Multipart file uploads must be opened in binary mode, not text mode."
)
self.filename = filename
self.file = fileobj
self.headers = headers
def get_length(self) -> int | None:
headers = self.render_headers()
if isinstance(self.file, (str, bytes)):
return len(headers) + len(to_bytes(self.file))
file_length = peek_filelike_length(self.file)
# If we can't determine the filesize without reading it into memory,
# then return `None` here, to indicate an unknown file length.
if file_length is None:
return None
return len(headers) + file_length
def render_headers(self) -> bytes:
if not hasattr(self, "_headers"):
parts = [
b"Content-Disposition: form-data; ",
_format_form_param("name", self.name),
]
if self.filename:
filename = _format_form_param("filename", self.filename)
parts.extend([b"; ", filename])
for header_name, header_value in self.headers.items():
key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
parts.extend([key, val])
parts.append(b"\r\n\r\n")
self._headers = b"".join(parts)
return self._headers
def render_data(self) -> typing.Iterator[bytes]:
if isinstance(self.file, (str, bytes)):
yield to_bytes(self.file)
return
if hasattr(self.file, "seek"):
try:
self.file.seek(0)
except io.UnsupportedOperation:
pass
chunk = self.file.read(self.CHUNK_SIZE)
while chunk:
yield to_bytes(chunk)
chunk = self.file.read(self.CHUNK_SIZE)
def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
yield from self.render_data()
class MultipartStream(SyncByteStream, AsyncByteStream):
"""
Request content as streaming multipart encoded form data.
"""
def __init__(
self,
data: RequestData,
files: RequestFiles,
boundary: bytes | None = None,
) -> None:
if boundary is None:
boundary = os.urandom(16).hex().encode("ascii")
self.boundary = boundary
self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
"ascii"
)
self.fields = list(self._iter_fields(data, files))
def _iter_fields(
self, data: RequestData, files: RequestFiles
) -> typing.Iterator[FileField | DataField]:
for name, value in data.items():
if isinstance(value, (tuple, list)):
for item in value:
yield DataField(name=name, value=item)
else:
yield DataField(name=name, value=value)
file_items = files.items() if isinstance(files, typing.Mapping) else files
for name, value in file_items:
yield FileField(name=name, value=value)
def iter_chunks(self) -> typing.Iterator[bytes]:
for field in self.fields:
yield b"--%s\r\n" % self.boundary
yield from field.render()
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary
def get_content_length(self) -> int | None:
"""
Return the length of the multipart encoded content, or `None` if
any of the files have a length that cannot be determined upfront.
"""
boundary_length = len(self.boundary)
length = 0
for field in self.fields:
field_length = field.get_length()
if field_length is None:
return None
length += 2 + boundary_length + 2 # b"--{boundary}\r\n"
length += field_length
length += 2 # b"\r\n"
length += 2 + boundary_length + 4 # b"--{boundary}--\r\n"
return length
# Content stream interface.
def get_headers(self) -> dict[str, str]:
content_length = self.get_content_length()
content_type = self.content_type
if content_length is None:
return {"Transfer-Encoding": "chunked", "Content-Type": content_type}
return {"Content-Length": str(content_length), "Content-Type": content_type}
def __iter__(self) -> typing.Iterator[bytes]:
for chunk in self.iter_chunks():
yield chunk
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
for chunk in self.iter_chunks():
yield chunk
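
To see the wire format these classes produce, `MultipartStream` can be driven directly with a fixed boundary; a minimal sketch (field names and contents are arbitrary, and the private-module import reflects this file's location):

```python
from httpx._multipart import MultipartStream  # private module, per this file

stream = MultipartStream(
    data={"greeting": "hello"},
    files={"upload": ("note.txt", b"file content", "text/plain")},
    boundary=b"example-boundary",  # fixed so the output is deterministic
)
print(stream.get_headers())  # Content-Length plus the multipart Content-Type
print(b"".join(stream.iter_chunks()).decode())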

View File

@ -0,0 +1,162 @@
from __future__ import annotations
from enum import IntEnum
__all__ = ["codes"]
class codes(IntEnum):
"""HTTP status codes and reason phrases
Status codes from the following RFCs are all observed:
* RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616
* RFC 6585: Additional HTTP Status Codes
* RFC 3229: Delta encoding in HTTP
* RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518
* RFC 5842: Binding Extensions to WebDAV
* RFC 7238: Permanent Redirect
* RFC 2295: Transparent Content Negotiation in HTTP
* RFC 2774: An HTTP Extension Framework
* RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2)
* RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0)
* RFC 7725: An HTTP Status Code to Report Legal Obstacles
* RFC 8297: An HTTP Status Code for Indicating Hints
* RFC 8470: Using Early Data in HTTP
"""
def __new__(cls, value: int, phrase: str = "") -> codes:
obj = int.__new__(cls, value)
obj._value_ = value
obj.phrase = phrase # type: ignore[attr-defined]
return obj
def __str__(self) -> str:
return str(self.value)
@classmethod
def get_reason_phrase(cls, value: int) -> str:
try:
return codes(value).phrase # type: ignore
except ValueError:
return ""
@classmethod
def is_informational(cls, value: int) -> bool:
"""
Returns `True` for 1xx status codes, `False` otherwise.
"""
return 100 <= value <= 199
@classmethod
def is_success(cls, value: int) -> bool:
"""
Returns `True` for 2xx status codes, `False` otherwise.
"""
return 200 <= value <= 299
@classmethod
def is_redirect(cls, value: int) -> bool:
"""
Returns `True` for 3xx status codes, `False` otherwise.
"""
return 300 <= value <= 399
@classmethod
def is_client_error(cls, value: int) -> bool:
"""
Returns `True` for 4xx status codes, `False` otherwise.
"""
return 400 <= value <= 499
@classmethod
def is_server_error(cls, value: int) -> bool:
"""
Returns `True` for 5xx status codes, `False` otherwise.
"""
return 500 <= value <= 599
@classmethod
def is_error(cls, value: int) -> bool:
"""
Returns `True` for 4xx or 5xx status codes, `False` otherwise.
"""
return 400 <= value <= 599
# informational
CONTINUE = 100, "Continue"
SWITCHING_PROTOCOLS = 101, "Switching Protocols"
PROCESSING = 102, "Processing"
EARLY_HINTS = 103, "Early Hints"
# success
OK = 200, "OK"
CREATED = 201, "Created"
ACCEPTED = 202, "Accepted"
NON_AUTHORITATIVE_INFORMATION = 203, "Non-Authoritative Information"
NO_CONTENT = 204, "No Content"
RESET_CONTENT = 205, "Reset Content"
PARTIAL_CONTENT = 206, "Partial Content"
MULTI_STATUS = 207, "Multi-Status"
ALREADY_REPORTED = 208, "Already Reported"
IM_USED = 226, "IM Used"
# redirection
MULTIPLE_CHOICES = 300, "Multiple Choices"
MOVED_PERMANENTLY = 301, "Moved Permanently"
FOUND = 302, "Found"
SEE_OTHER = 303, "See Other"
NOT_MODIFIED = 304, "Not Modified"
USE_PROXY = 305, "Use Proxy"
TEMPORARY_REDIRECT = 307, "Temporary Redirect"
PERMANENT_REDIRECT = 308, "Permanent Redirect"
# client error
BAD_REQUEST = 400, "Bad Request"
UNAUTHORIZED = 401, "Unauthorized"
PAYMENT_REQUIRED = 402, "Payment Required"
FORBIDDEN = 403, "Forbidden"
NOT_FOUND = 404, "Not Found"
METHOD_NOT_ALLOWED = 405, "Method Not Allowed"
NOT_ACCEPTABLE = 406, "Not Acceptable"
PROXY_AUTHENTICATION_REQUIRED = 407, "Proxy Authentication Required"
REQUEST_TIMEOUT = 408, "Request Timeout"
CONFLICT = 409, "Conflict"
GONE = 410, "Gone"
LENGTH_REQUIRED = 411, "Length Required"
PRECONDITION_FAILED = 412, "Precondition Failed"
REQUEST_ENTITY_TOO_LARGE = 413, "Request Entity Too Large"
REQUEST_URI_TOO_LONG = 414, "Request-URI Too Long"
UNSUPPORTED_MEDIA_TYPE = 415, "Unsupported Media Type"
REQUESTED_RANGE_NOT_SATISFIABLE = 416, "Requested Range Not Satisfiable"
EXPECTATION_FAILED = 417, "Expectation Failed"
IM_A_TEAPOT = 418, "I'm a teapot"
MISDIRECTED_REQUEST = 421, "Misdirected Request"
UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity"
LOCKED = 423, "Locked"
FAILED_DEPENDENCY = 424, "Failed Dependency"
TOO_EARLY = 425, "Too Early"
UPGRADE_REQUIRED = 426, "Upgrade Required"
PRECONDITION_REQUIRED = 428, "Precondition Required"
TOO_MANY_REQUESTS = 429, "Too Many Requests"
REQUEST_HEADER_FIELDS_TOO_LARGE = 431, "Request Header Fields Too Large"
UNAVAILABLE_FOR_LEGAL_REASONS = 451, "Unavailable For Legal Reasons"
# server errors
INTERNAL_SERVER_ERROR = 500, "Internal Server Error"
NOT_IMPLEMENTED = 501, "Not Implemented"
BAD_GATEWAY = 502, "Bad Gateway"
SERVICE_UNAVAILABLE = 503, "Service Unavailable"
GATEWAY_TIMEOUT = 504, "Gateway Timeout"
HTTP_VERSION_NOT_SUPPORTED = 505, "HTTP Version Not Supported"
VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates"
INSUFFICIENT_STORAGE = 507, "Insufficient Storage"
LOOP_DETECTED = 508, "Loop Detected"
NOT_EXTENDED = 510, "Not Extended"
NETWORK_AUTHENTICATION_REQUIRED = 511, "Network Authentication Required"
# Include lower-case styles for `requests` compatibility.
for code in codes:
setattr(codes, code._name_.lower(), int(code))
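
Since `codes` is an `IntEnum`, members compare equal to plain integers; a quick sketch:

```python
from httpx import codes

assert codes.NOT_FOUND == 404
assert codes.get_reason_phrase(404) == "Not Found"
assert codes.is_client_error(404) and not codes.is_success(404)
assert codes.not_found == 404  # lower-case alias, kept for `requests` compatibility
```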

View File

@ -0,0 +1,15 @@
from .asgi import *
from .base import *
from .default import *
from .mock import *
from .wsgi import *
__all__ = [
"ASGITransport",
"AsyncBaseTransport",
"BaseTransport",
"AsyncHTTPTransport",
"HTTPTransport",
"MockTransport",
"WSGITransport",
]

View File

@ -0,0 +1,187 @@
from __future__ import annotations
import typing
from .._models import Request, Response
from .._types import AsyncByteStream
from .base import AsyncBaseTransport
if typing.TYPE_CHECKING: # pragma: no cover
import asyncio
import trio
Event = typing.Union[asyncio.Event, trio.Event]
_Message = typing.MutableMapping[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[
[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]
]
_ASGIApp = typing.Callable[
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
]
__all__ = ["ASGITransport"]
def is_running_trio() -> bool:
try:
# sniffio is a dependency of trio.
# See https://github.com/python-trio/trio/issues/2802
import sniffio
if sniffio.current_async_library() == "trio":
return True
    except ImportError:  # pragma: no cover
pass
return False
def create_event() -> Event:
if is_running_trio():
import trio
return trio.Event()
import asyncio
return asyncio.Event()
class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: list[bytes]) -> None:
self._body = body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b"".join(self._body)
class ASGITransport(AsyncBaseTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
```python
transport = httpx.ASGITransport(
app=app,
root_path="/submount",
client=("1.2.3.4", 123)
)
client = httpx.AsyncClient(transport=transport)
```
Arguments:
* `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
      should be raised. Defaults to `True`. Can be set to `False` for use cases
      such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.
"""
def __init__(
self,
app: _ASGIApp,
raise_app_exceptions: bool = True,
root_path: str = "",
client: tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
async def handle_async_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)
# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path.split(b"?")[0],
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
"root_path": self.root_path,
}
# Request.
request_body_chunks = request.stream.__aiter__()
request_complete = False
# Response.
status_code = None
response_headers = None
body_parts = []
response_started = False
response_complete = create_event()
# ASGI callables.
async def receive() -> dict[str, typing.Any]:
nonlocal request_complete
if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}
try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started
if message["type"] == "http.response.start":
assert not response_started
status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body and request.method != "HEAD":
body_parts.append(body)
if not more_body:
response_complete.set()
try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
if self.raise_app_exceptions:
raise
response_complete.set()
if status_code is None:
status_code = 500
if response_headers is None:
response_headers = {}
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None
stream = ASGIResponseStream(body_parts)
return Response(status_code, headers=response_headers, stream=stream)
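
A minimal end-to-end sketch of the transport above, using an illustrative ASGI app defined inline:

```python
import asyncio

import httpx

async def app(scope, receive, send):
    # Illustrative ASGI app: answer every request with a plain-text body.
    assert scope["type"] == "http"
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/plain")],
        }
    )
    await send({"type": "http.response.body", "body": b"Hello, world!"})

async def run() -> None:
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://testserver"
    ) as client:
        response = await client.get("/")
        print(response.status_code, response.text)

asyncio.run(run())
```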

View File

@ -0,0 +1,86 @@
from __future__ import annotations
import typing
from types import TracebackType
from .._models import Request, Response
T = typing.TypeVar("T", bound="BaseTransport")
A = typing.TypeVar("A", bound="AsyncBaseTransport")
__all__ = ["AsyncBaseTransport", "BaseTransport"]
class BaseTransport:
def __enter__(self: T) -> T:
return self
def __exit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
self.close()
def handle_request(self, request: Request) -> Response:
"""
Send a single HTTP request and return a response.
        Developers shouldn't typically need to call into this API directly,
        since the Client class provides all the higher-level user-facing API
        niceties.
        In order to properly release any network resources, the response
        stream should *either* be consumed immediately, with a call to
        `response.stream.read()`, or else the `handle_request` call should
        be followed with a try/finally block to ensure the stream is
        always closed.
Example usage:
with httpx.HTTPTransport() as transport:
req = httpx.Request(
method=b"GET",
url=(b"https", b"www.example.com", 443, b"/"),
headers=[(b"Host", b"www.example.com")],
)
resp = transport.handle_request(req)
body = resp.stream.read()
print(resp.status_code, resp.headers, body)
Takes a `Request` instance as the only argument.
Returns a `Response` instance.
"""
raise NotImplementedError(
"The 'handle_request' method must be implemented."
) # pragma: no cover
def close(self) -> None:
pass
class AsyncBaseTransport:
async def __aenter__(self: A) -> A:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self.aclose()
async def handle_async_request(
self,
request: Request,
) -> Response:
raise NotImplementedError(
"The 'handle_async_request' method must be implemented."
) # pragma: no cover
async def aclose(self) -> None:
pass
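
As a sketch of the interface the docstring above describes, a custom sync transport only needs `handle_request`; the echo transport below is illustrative, not part of httpx:

```python
import httpx

class EchoTransport(httpx.BaseTransport):
    # Illustrative custom transport: answer without touching the network.
    def handle_request(self, request: httpx.Request) -> httpx.Response:
        return httpx.Response(200, text=f"{request.method} {request.url}")

with httpx.Client(transport=EchoTransport()) as client:
    print(client.get("https://www.example.com/").text)
    # "GET https://www.example.com/"
```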

View File

@ -0,0 +1,406 @@
"""
Custom transports, with nicely configured defaults.
The following additional keyword arguments are currently supported by httpcore...
* uds: str
* local_address: str
* retries: int
Example usages...
# Disable HTTP/2 on a single specific domain.
mounts = {
"all://": httpx.HTTPTransport(http2=True),
"all://*example.org": httpx.HTTPTransport()
}
# Using advanced httpcore configuration, with connection retries.
transport = httpx.HTTPTransport(retries=1)
client = httpx.Client(transport=transport)
# Using advanced httpcore configuration, with unix domain sockets.
transport = httpx.HTTPTransport(uds="socket.uds")
client = httpx.Client(transport=transport)
"""
from __future__ import annotations
import contextlib
import typing
from types import TracebackType
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
import httpx # pragma: no cover
from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
from .._exceptions import (
ConnectError,
ConnectTimeout,
LocalProtocolError,
NetworkError,
PoolTimeout,
ProtocolError,
ProxyError,
ReadError,
ReadTimeout,
RemoteProtocolError,
TimeoutException,
UnsupportedProtocol,
WriteError,
WriteTimeout,
)
from .._models import Request, Response
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
from .._urls import URL
from .base import AsyncBaseTransport, BaseTransport
T = typing.TypeVar("T", bound="HTTPTransport")
A = typing.TypeVar("A", bound="AsyncHTTPTransport")
SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
typing.Tuple[int, int, None, int],
]
__all__ = ["AsyncHTTPTransport", "HTTPTransport"]
HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {}
def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]:
import httpcore
return {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
@contextlib.contextmanager
def map_httpcore_exceptions() -> typing.Iterator[None]:
global HTTPCORE_EXC_MAP
if len(HTTPCORE_EXC_MAP) == 0:
HTTPCORE_EXC_MAP = _load_httpcore_exceptions()
try:
yield
except Exception as exc:
mapped_exc = None
for from_exc, to_exc in HTTPCORE_EXC_MAP.items():
if not isinstance(exc, from_exc):
continue
# We want to map to the most specific exception we can find.
# Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
# `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
if mapped_exc is None or issubclass(to_exc, mapped_exc):
mapped_exc = to_exc
if mapped_exc is None: # pragma: no cover
raise
message = str(exc)
raise mapped_exc(message) from exc
class ResponseStream(SyncByteStream):
def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
def __iter__(self) -> typing.Iterator[bytes]:
with map_httpcore_exceptions():
for part in self._httpcore_stream:
yield part
def close(self) -> None:
if hasattr(self._httpcore_stream, "close"):
self._httpcore_stream.close()
class HTTPTransport(BaseTransport):
def __init__(
self,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http1: bool = True,
http2: bool = False,
limits: Limits = DEFAULT_LIMITS,
proxy: ProxyTypes | None = None,
uds: str | None = None,
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
import httpcore
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
if proxy is None:
self._pool = httpcore.ConnectionPool(
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.HTTPProxy(
proxy_url=httpcore.URL(
scheme=proxy.url.raw_scheme,
host=proxy.url.raw_host,
port=proxy.url.port,
target=proxy.url.raw_path,
),
proxy_auth=proxy.raw_auth,
proxy_headers=proxy.headers.raw,
ssl_context=ssl_context,
proxy_ssl_context=proxy.ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
)
elif proxy.url.scheme in ("socks5", "socks5h"):
try:
import socksio # noqa
except ImportError: # pragma: no cover
raise ImportError(
"Using SOCKS proxy, but the 'socksio' package is not installed. "
"Make sure to install httpx using `pip install httpx[socks]`."
) from None
self._pool = httpcore.SOCKSProxy(
proxy_url=httpcore.URL(
scheme=proxy.url.raw_scheme,
host=proxy.url.raw_host,
port=proxy.url.port,
target=proxy.url.raw_path,
),
proxy_auth=proxy.raw_auth,
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
)
else: # pragma: no cover
raise ValueError(
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
f" but got {proxy.url.scheme!r}."
)
def __enter__(self: T) -> T: # Use generics for subclass support.
self._pool.__enter__()
return self
def __exit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
self._pool.__exit__(exc_type, exc_value, traceback)
def handle_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, SyncByteStream)
import httpcore
req = httpcore.Request(
method=request.method,
url=httpcore.URL(
scheme=request.url.raw_scheme,
host=request.url.raw_host,
port=request.url.port,
target=request.url.raw_path,
),
headers=request.headers.raw,
content=request.stream,
extensions=request.extensions,
)
with map_httpcore_exceptions():
resp = self._pool.handle_request(req)
assert isinstance(resp.stream, typing.Iterable)
return Response(
status_code=resp.status,
headers=resp.headers,
stream=ResponseStream(resp.stream),
extensions=resp.extensions,
)
def close(self) -> None:
self._pool.close()
class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
with map_httpcore_exceptions():
async for part in self._httpcore_stream:
yield part
async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):
await self._httpcore_stream.aclose()
class AsyncHTTPTransport(AsyncBaseTransport):
def __init__(
self,
verify: ssl.SSLContext | str | bool = True,
cert: CertTypes | None = None,
trust_env: bool = True,
http1: bool = True,
http2: bool = False,
limits: Limits = DEFAULT_LIMITS,
proxy: ProxyTypes | None = None,
uds: str | None = None,
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> None:
import httpcore
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
if proxy is None:
self._pool = httpcore.AsyncConnectionPool(
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.AsyncHTTPProxy(
proxy_url=httpcore.URL(
scheme=proxy.url.raw_scheme,
host=proxy.url.raw_host,
port=proxy.url.port,
target=proxy.url.raw_path,
),
proxy_auth=proxy.raw_auth,
proxy_headers=proxy.headers.raw,
proxy_ssl_context=proxy.ssl_context,
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
)
elif proxy.url.scheme in ("socks5", "socks5h"):
try:
import socksio # noqa
except ImportError: # pragma: no cover
raise ImportError(
"Using SOCKS proxy, but the 'socksio' package is not installed. "
"Make sure to install httpx using `pip install httpx[socks]`."
) from None
self._pool = httpcore.AsyncSOCKSProxy(
proxy_url=httpcore.URL(
scheme=proxy.url.raw_scheme,
host=proxy.url.raw_host,
port=proxy.url.port,
target=proxy.url.raw_path,
),
proxy_auth=proxy.raw_auth,
ssl_context=ssl_context,
max_connections=limits.max_connections,
max_keepalive_connections=limits.max_keepalive_connections,
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
)
else: # pragma: no cover
raise ValueError(
"Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h',"
" but got {proxy.url.scheme!r}."
)
async def __aenter__(self: A) -> A: # Use generics for subclass support.
await self._pool.__aenter__()
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
await self._pool.__aexit__(exc_type, exc_value, traceback)
async def handle_async_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)
import httpcore
req = httpcore.Request(
method=request.method,
url=httpcore.URL(
scheme=request.url.raw_scheme,
host=request.url.raw_host,
port=request.url.port,
target=request.url.raw_path,
),
headers=request.headers.raw,
content=request.stream,
extensions=request.extensions,
)
with map_httpcore_exceptions():
resp = await self._pool.handle_async_request(req)
assert isinstance(resp.stream, typing.AsyncIterable)
return Response(
status_code=resp.status,
headers=resp.headers,
stream=AsyncResponseStream(resp.stream),
extensions=resp.extensions,
)
async def aclose(self) -> None:
await self._pool.aclose()
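
As the module docstring notes, httpcore-level options such as `retries` attach to a transport instance, and per-pattern transports can be mounted on a client; a short sketch (the URLs are illustrative):

```python
import httpx

transport = httpx.HTTPTransport(retries=1)  # retry connection failures once
mounts = {"all://*example.org": httpx.HTTPTransport(http2=False)}
client = httpx.Client(transport=transport, mounts=mounts)
```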

View File

@ -0,0 +1,43 @@
from __future__ import annotations
import typing
from .._models import Request, Response
from .base import AsyncBaseTransport, BaseTransport
SyncHandler = typing.Callable[[Request], Response]
AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]]
__all__ = ["MockTransport"]
class MockTransport(AsyncBaseTransport, BaseTransport):
def __init__(self, handler: SyncHandler | AsyncHandler) -> None:
self.handler = handler
def handle_request(
self,
request: Request,
) -> Response:
request.read()
response = self.handler(request)
if not isinstance(response, Response): # pragma: no cover
raise TypeError("Cannot use an async handler in a sync Client")
return response
async def handle_async_request(
self,
request: Request,
) -> Response:
await request.aread()
response = self.handler(request)
# Allow handler to *optionally* be an `async` function.
        # If it is, then the `response` variable needs to be awaited to
        # actually return the result.
if not isinstance(response, Response):
response = await response
return response
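
A short sketch of the handler style this transport expects (the endpoint is illustrative):

```python
import httpx

def handler(request: httpx.Request) -> httpx.Response:
    # Illustrative canned handler: echo the requested path back as JSON.
    return httpx.Response(200, json={"path": request.url.path})

client = httpx.Client(transport=httpx.MockTransport(handler))
assert client.get("https://testserver/ping").json() == {"path": "/ping"}
```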

View File

@ -0,0 +1,149 @@
from __future__ import annotations
import io
import itertools
import sys
import typing
from .._models import Request, Response
from .._types import SyncByteStream
from .base import BaseTransport
if typing.TYPE_CHECKING:
from _typeshed import OptExcInfo # pragma: no cover
from _typeshed.wsgi import WSGIApplication # pragma: no cover
_T = typing.TypeVar("_T")
__all__ = ["WSGITransport"]
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
body = iter(body)
for chunk in body:
if chunk:
return itertools.chain([chunk], body)
return []
class WSGIByteStream(SyncByteStream):
def __init__(self, result: typing.Iterable[bytes]) -> None:
self._close = getattr(result, "close", None)
self._result = _skip_leading_empty_chunks(result)
def __iter__(self) -> typing.Iterator[bytes]:
for part in self._result:
yield part
def close(self) -> None:
if self._close is not None:
self._close()
class WSGITransport(BaseTransport):
"""
    A custom transport that handles sending requests directly to a WSGI app.
    The simplest way to use this functionality is via the `app` argument.
    ```
    transport = httpx.WSGITransport(app=app)
    client = httpx.Client(transport=transport)
    ```
    Alternatively, you can set up the transport instance explicitly.
This allows you to include any additional configuration arguments specific
to the WSGITransport class:
```
transport = httpx.WSGITransport(
app=app,
script_name="/submount",
remote_addr="1.2.3.4"
)
client = httpx.Client(transport=transport)
```
Arguments:
* `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
      should be raised. Defaults to `True`. Can be set to `False` for use cases
      such as testing the content of a client 500 response.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
"""
def __init__(
self,
app: WSGIApplication,
raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
wsgi_errors: typing.TextIO | None = None,
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.script_name = script_name
self.remote_addr = remote_addr
self.wsgi_errors = wsgi_errors
def handle_request(self, request: Request) -> Response:
request.read()
wsgi_input = io.BytesIO(request.content)
port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
environ = {
"wsgi.version": (1, 0),
"wsgi.url_scheme": request.url.scheme,
"wsgi.input": wsgi_input,
"wsgi.errors": self.wsgi_errors or sys.stderr,
"wsgi.multithread": True,
"wsgi.multiprocess": False,
"wsgi.run_once": False,
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": self.script_name,
"PATH_INFO": request.url.path,
"QUERY_STRING": request.url.query.decode("ascii"),
"SERVER_NAME": request.url.host,
"SERVER_PORT": str(port),
"SERVER_PROTOCOL": "HTTP/1.1",
"REMOTE_ADDR": self.remote_addr,
}
for header_key, header_value in request.headers.raw:
key = header_key.decode("ascii").upper().replace("-", "_")
if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
key = "HTTP_" + key
environ[key] = header_value.decode("ascii")
seen_status = None
seen_response_headers = None
seen_exc_info = None
def start_response(
status: str,
response_headers: list[tuple[str, str]],
exc_info: OptExcInfo | None = None,
) -> typing.Callable[[bytes], typing.Any]:
nonlocal seen_status, seen_response_headers, seen_exc_info
seen_status = status
seen_response_headers = response_headers
seen_exc_info = exc_info
return lambda _: None
result = self.app(environ, start_response)
stream = WSGIByteStream(result)
assert seen_status is not None
assert seen_response_headers is not None
if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
raise seen_exc_info[1]
status_code = int(seen_status.split()[0])
headers = [
(key.encode("ascii"), value.encode("ascii"))
for key, value in seen_response_headers
]
return Response(status_code, headers=headers, stream=stream)
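
A minimal end-to-end sketch of the transport above, with an illustrative WSGI app defined inline:

```python
import httpx

def app(environ, start_response):
    # Illustrative WSGI app: a fixed plain-text response.
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello, world!"]

transport = httpx.WSGITransport(app=app)
with httpx.Client(transport=transport, base_url="http://testserver") as client:
    print(client.get("/").text)  # "Hello, world!"
```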

View File

@ -0,0 +1,114 @@
"""
Type definitions for type checking purposes.
"""
from http.cookiejar import CookieJar
from typing import (
IO,
TYPE_CHECKING,
Any,
AsyncIterable,
AsyncIterator,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
if TYPE_CHECKING: # pragma: no cover
from ._auth import Auth # noqa: F401
from ._config import Proxy, Timeout # noqa: F401
from ._models import Cookies, Headers, Request # noqa: F401
from ._urls import URL, QueryParams # noqa: F401
PrimitiveData = Optional[Union[str, int, float, bool]]
URLTypes = Union["URL", str]
QueryParamTypes = Union[
"QueryParams",
Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]],
List[Tuple[str, PrimitiveData]],
Tuple[Tuple[str, PrimitiveData], ...],
str,
bytes,
]
HeaderTypes = Union[
"Headers",
Mapping[str, str],
Mapping[bytes, bytes],
Sequence[Tuple[str, str]],
Sequence[Tuple[bytes, bytes]],
]
CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
TimeoutTypes = Union[
Optional[float],
Tuple[Optional[float], Optional[float], Optional[float], Optional[float]],
"Timeout",
]
ProxyTypes = Union["URL", str, "Proxy"]
CertTypes = Union[str, Tuple[str, str], Tuple[str, str, str]]
AuthTypes = Union[
Tuple[Union[str, bytes], Union[str, bytes]],
Callable[["Request"], "Request"],
"Auth",
]
RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
ResponseExtensions = Mapping[str, Any]
RequestData = Mapping[str, Any]
FileContent = Union[IO[bytes], bytes, str]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
RequestExtensions = Mapping[str, Any]
__all__ = ["AsyncByteStream", "SyncByteStream"]
class SyncByteStream:
def __iter__(self) -> Iterator[bytes]:
raise NotImplementedError(
"The '__iter__' method must be implemented."
) # pragma: no cover
yield b"" # pragma: no cover
def close(self) -> None:
"""
Subclasses can override this method to release any network resources
after a request/response cycle is complete.
"""
class AsyncByteStream:
async def __aiter__(self) -> AsyncIterator[bytes]:
raise NotImplementedError(
"The '__aiter__' method must be implemented."
) # pragma: no cover
yield b"" # pragma: no cover
async def aclose(self) -> None:
pass
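
The `FileTypes` union above admits four shapes per upload entry; a small illustrative sketch (names and contents are arbitrary):

```python
# Each entry below is a valid FileTypes value per the union defined above.
files = {
    "a": b"raw bytes",                              # bare file content
    "b": ("name.txt", b"raw bytes"),                # (filename, content)
    "c": ("name.txt", b"raw bytes", "text/plain"),  # ... plus content_type
    "d": ("name.txt", b"raw bytes", "text/plain", {"X-Extra": "header"}),  # ... plus headers
}
```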

View File

@ -0,0 +1,527 @@
"""
An implementation of `urlparse` that provides URL validation and normalization
as described by RFC3986.
We rely on this implementation rather than the one in Python's stdlib, because:
* It provides more complete URL validation.
* It properly differentiates between an empty querystring and an absent querystring,
to distinguish URLs with a trailing '?'.
* It handles scheme, hostname, port, and path normalization.
* It supports IDNA hostnames, normalizing them to their encoded form.
* The API supports passing individual components, as well as the complete URL string.
Previously we relied on the excellent `rfc3986` package to handle URL parsing and
validation, but this module provides a simpler alternative, with less indirection
required.
"""
from __future__ import annotations
import ipaddress
import re
import typing
import idna
from ._exceptions import InvalidURL
MAX_URL_LENGTH = 65536
# https://datatracker.ietf.org/doc/html/rfc3986.html#section-2.3
UNRESERVED_CHARACTERS = (
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
)
SUB_DELIMS = "!$&'()*+,;="
PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}")
# https://url.spec.whatwg.org/#percent-encoded-bytes
# The fragment percent-encode set is the C0 control percent-encode set
# and U+0020 SPACE, U+0022 ("), U+003C (<), U+003E (>), and U+0060 (`).
FRAG_SAFE = "".join(
[chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x3C, 0x3E, 0x60)]
)
# The query percent-encode set is the C0 control percent-encode set
# and U+0020 SPACE, U+0022 ("), U+0023 (#), U+003C (<), and U+003E (>).
QUERY_SAFE = "".join(
[chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E)]
)
# The path percent-encode set is the query percent-encode set
# and U+003F (?), U+0060 (`), U+007B ({), and U+007D (}).
PATH_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + (0x3F, 0x60, 0x7B, 0x7D)
]
)
# The userinfo percent-encode set is the path percent-encode set
# and U+002F (/), U+003A (:), U+003B (;), U+003D (=), U+0040 (@),
# U+005B ([) to U+005E (^), inclusive, and U+007C (|).
USERNAME_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
+ (0x3F, 0x60, 0x7B, 0x7D)
+ (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
]
)
PASSWORD_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
+ (0x3F, 0x60, 0x7B, 0x7D)
+ (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
]
)
# Note... The terminology 'userinfo' percent-encode set in the WHATWG document
# is used for the username and password quoting. For the joint userinfo component
# we keep U+003A (:) in the safe set, since it separates the username and password.
USERINFO_SAFE = "".join(
[
chr(i)
for i in range(0x20, 0x7F)
if i
not in (0x20, 0x22, 0x23, 0x3C, 0x3E)
+ (0x3F, 0x60, 0x7B, 0x7D)
+ (0x2F, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C)
]
)
# {scheme}: (optional)
# //{authority} (optional)
# {path}
# ?{query} (optional)
# #{fragment} (optional)
URL_REGEX = re.compile(
(
r"(?:(?P<scheme>{scheme}):)?"
r"(?://(?P<authority>{authority}))?"
r"(?P<path>{path})"
r"(?:\?(?P<query>{query}))?"
r"(?:#(?P<fragment>{fragment}))?"
).format(
scheme="([a-zA-Z][a-zA-Z0-9+.-]*)?",
authority="[^/?#]*",
path="[^?#]*",
query="[^#]*",
fragment=".*",
)
)
# {userinfo}@ (optional)
# {host}
# :{port} (optional)
AUTHORITY_REGEX = re.compile(
(
r"(?:(?P<userinfo>{userinfo})@)?" r"(?P<host>{host})" r":?(?P<port>{port})?"
).format(
userinfo=".*", # Any character sequence.
host="(\\[.*\\]|[^:@]*)", # Either any character sequence excluding ':' or '@',
# or an IPv6 address enclosed within square brackets.
port=".*", # Any character sequence.
)
)
# If we call urlparse with an individual component, then we need to regex
# validate that component individually.
# Note that we're duplicating the same strings as above. Shock! Horror!!
COMPONENT_REGEX = {
"scheme": re.compile("([a-zA-Z][a-zA-Z0-9+.-]*)?"),
"authority": re.compile("[^/?#]*"),
"path": re.compile("[^?#]*"),
"query": re.compile("[^#]*"),
"fragment": re.compile(".*"),
"userinfo": re.compile("[^@]*"),
"host": re.compile("(\\[.*\\]|[^:]*)"),
"port": re.compile(".*"),
}
# We use these simple regexes as a first pass before handing off to
# the stdlib 'ipaddress' module for IP address validation.
IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$")
IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$")
class ParseResult(typing.NamedTuple):
scheme: str
userinfo: str
host: str
port: int | None
path: str
query: str | None
fragment: str | None
@property
def authority(self) -> str:
return "".join(
[
f"{self.userinfo}@" if self.userinfo else "",
f"[{self.host}]" if ":" in self.host else self.host,
f":{self.port}" if self.port is not None else "",
]
)
@property
def netloc(self) -> str:
return "".join(
[
f"[{self.host}]" if ":" in self.host else self.host,
f":{self.port}" if self.port is not None else "",
]
)
def copy_with(self, **kwargs: str | None) -> ParseResult:
if not kwargs:
return self
defaults = {
"scheme": self.scheme,
"authority": self.authority,
"path": self.path,
"query": self.query,
"fragment": self.fragment,
}
defaults.update(kwargs)
return urlparse("", **defaults)
def __str__(self) -> str:
authority = self.authority
return "".join(
[
f"{self.scheme}:" if self.scheme else "",
f"//{authority}" if authority else "",
self.path,
f"?{self.query}" if self.query is not None else "",
f"#{self.fragment}" if self.fragment is not None else "",
]
)
def urlparse(url: str = "", **kwargs: str | None) -> ParseResult:
# Initial basic checks on allowable URLs.
# ---------------------------------------
# Hard limit the maximum allowable URL length.
if len(url) > MAX_URL_LENGTH:
raise InvalidURL("URL too long")
# If a URL includes any ASCII control characters including \t, \r, \n,
# then treat it as invalid.
if any(char.isascii() and not char.isprintable() for char in url):
char = next(char for char in url if char.isascii() and not char.isprintable())
idx = url.find(char)
error = (
f"Invalid non-printable ASCII character in URL, {char!r} at position {idx}."
)
raise InvalidURL(error)
# Some keyword arguments require special handling.
# ------------------------------------------------
# Coerce "port" to a string, if it is provided as an integer.
if "port" in kwargs:
port = kwargs["port"]
kwargs["port"] = str(port) if isinstance(port, int) else port
# Replace "netloc" with "host and "port".
if "netloc" in kwargs:
netloc = kwargs.pop("netloc") or ""
kwargs["host"], _, kwargs["port"] = netloc.partition(":")
# Replace "username" and/or "password" with "userinfo".
if "username" in kwargs or "password" in kwargs:
username = quote(kwargs.pop("username", "") or "", safe=USERNAME_SAFE)
password = quote(kwargs.pop("password", "") or "", safe=PASSWORD_SAFE)
kwargs["userinfo"] = f"{username}:{password}" if password else username
# Replace "raw_path" with "path" and "query".
if "raw_path" in kwargs:
raw_path = kwargs.pop("raw_path") or ""
kwargs["path"], seperator, kwargs["query"] = raw_path.partition("?")
if not seperator:
kwargs["query"] = None
# Ensure that IPv6 "host" addresses are always escaped with "[...]".
if "host" in kwargs:
host = kwargs.get("host") or ""
if ":" in host and not (host.startswith("[") and host.endswith("]")):
kwargs["host"] = f"[{host}]"
# If any keyword arguments are provided, ensure they are valid.
# -------------------------------------------------------------
for key, value in kwargs.items():
if value is not None:
if len(value) > MAX_URL_LENGTH:
raise InvalidURL(f"URL component '{key}' too long")
# If a component includes any ASCII control characters including \t, \r, \n,
# then treat it as invalid.
if any(char.isascii() and not char.isprintable() for char in value):
char = next(
char for char in value if char.isascii() and not char.isprintable()
)
idx = value.find(char)
error = (
f"Invalid non-printable ASCII character in URL {key} component, "
f"{char!r} at position {idx}."
)
raise InvalidURL(error)
# Ensure that keyword arguments match as a valid regex.
if not COMPONENT_REGEX[key].fullmatch(value):
raise InvalidURL(f"Invalid URL component '{key}'")
# The URL_REGEX will always match, but may have empty components.
url_match = URL_REGEX.match(url)
assert url_match is not None
url_dict = url_match.groupdict()
# * 'scheme', 'authority', and 'path' may be empty strings.
# * 'query' may be 'None', indicating no trailing "?" portion.
# Any string including the empty string, indicates a trailing "?".
# * 'fragment' may be 'None', indicating no trailing "#" portion.
# Any string including the empty string, indicates a trailing "#".
scheme = kwargs.get("scheme", url_dict["scheme"]) or ""
authority = kwargs.get("authority", url_dict["authority"]) or ""
path = kwargs.get("path", url_dict["path"]) or ""
query = kwargs.get("query", url_dict["query"])
frag = kwargs.get("fragment", url_dict["fragment"])
# The AUTHORITY_REGEX will always match, but may have empty components.
authority_match = AUTHORITY_REGEX.match(authority)
assert authority_match is not None
authority_dict = authority_match.groupdict()
# * 'userinfo' and 'host' may be empty strings.
# * 'port' may be 'None'.
userinfo = kwargs.get("userinfo", authority_dict["userinfo"]) or ""
host = kwargs.get("host", authority_dict["host"]) or ""
port = kwargs.get("port", authority_dict["port"])
# Normalize and validate each component.
# We end up with a parsed representation of the URL,
# with components that are plain ASCII bytestrings.
parsed_scheme: str = scheme.lower()
parsed_userinfo: str = quote(userinfo, safe=USERINFO_SAFE)
parsed_host: str = encode_host(host)
parsed_port: int | None = normalize_port(port, scheme)
has_scheme = parsed_scheme != ""
has_authority = (
parsed_userinfo != "" or parsed_host != "" or parsed_port is not None
)
validate_path(path, has_scheme=has_scheme, has_authority=has_authority)
if has_scheme or has_authority:
path = normalize_path(path)
parsed_path: str = quote(path, safe=PATH_SAFE)
parsed_query: str | None = None if query is None else quote(query, safe=QUERY_SAFE)
parsed_frag: str | None = None if frag is None else quote(frag, safe=FRAG_SAFE)
# The parsed ASCII bytestrings are our canonical form.
# All properties of the URL are derived from these.
return ParseResult(
parsed_scheme,
parsed_userinfo,
parsed_host,
parsed_port,
parsed_path,
parsed_query,
parsed_frag,
)
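# Example (illustrative, not part of the original module): the normalization
# above lowercases the scheme and host, drops a default port, and removes
# dot segments, e.g.
#
#     urlparse("https://EXAMPLE.com:443/a/./b/../c?x=1")
#     == ParseResult(scheme='https', userinfo='', host='example.com',
#                    port=None, path='/a/c', query='x=1', fragment=None)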
def encode_host(host: str) -> str:
if not host:
return ""
elif IPv4_STYLE_HOSTNAME.match(host):
# Validate IPv4 hostnames like #.#.#.#
#
# From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2
#
# IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet
try:
ipaddress.IPv4Address(host)
except ipaddress.AddressValueError:
raise InvalidURL(f"Invalid IPv4 address: {host!r}")
return host
elif IPv6_STYLE_HOSTNAME.match(host):
# Validate IPv6 hostnames like [...]
#
# From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2
#
# "A host identified by an Internet Protocol literal address, version 6
# [RFC3513] or later, is distinguished by enclosing the IP literal
# within square brackets ("[" and "]"). This is the only place where
# square bracket characters are allowed in the URI syntax."
try:
ipaddress.IPv6Address(host[1:-1])
except ipaddress.AddressValueError:
raise InvalidURL(f"Invalid IPv6 address: {host!r}")
return host[1:-1]
elif host.isascii():
# Regular ASCII hostnames
#
# From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2
#
# reg-name = *( unreserved / pct-encoded / sub-delims )
WHATWG_SAFE = '"`{}%|\\'
return quote(host.lower(), safe=SUB_DELIMS + WHATWG_SAFE)
# IDNA hostnames
try:
return idna.encode(host.lower()).decode("ascii")
except idna.IDNAError:
raise InvalidURL(f"Invalid IDNA hostname: {host!r}")
def normalize_port(port: str | int | None, scheme: str) -> int | None:
# From https://tools.ietf.org/html/rfc3986#section-3.2.3
#
# "A scheme may define a default port. For example, the "http" scheme
# defines a default port of "80", corresponding to its reserved TCP
# port number. The type of port designated by the port number (e.g.,
# TCP, UDP, SCTP) is defined by the URI scheme. URI producers and
# normalizers should omit the port component and its ":" delimiter if
# port is empty or if its value would be the same as that of the
# scheme's default."
if port is None or port == "":
return None
try:
port_as_int = int(port)
except ValueError:
raise InvalidURL(f"Invalid port: {port!r}")
# See https://url.spec.whatwg.org/#url-miscellaneous
default_port = {"ftp": 21, "http": 80, "https": 443, "ws": 80, "wss": 443}.get(
scheme
)
if port_as_int == default_port:
return None
return port_as_int
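# Illustrative sketch, not part of the original module: default ports are
# normalized away, per the WHATWG table used in normalize_port() above.
def _demo_normalize_port() -> None:  # pragma: no cover
    assert normalize_port("443", "https") is None  # default port is dropped
    assert normalize_port(8080, "https") == 8080  # explicit port is kept
    assert normalize_port("", "http") is None  # empty port treated as absent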
def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None:
"""
Path validation rules that depend on whether the URL contains
a scheme or authority component.
See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3
"""
if has_authority:
# If a URI contains an authority component, then the path component
# must either be empty or begin with a slash ("/") character.
if path and not path.startswith("/"):
raise InvalidURL("For absolute URLs, path must be empty or begin with '/'")
if not has_scheme and not has_authority:
# If a URI does not contain an authority component, then the path cannot begin
# with two slash characters ("//").
if path.startswith("//"):
raise InvalidURL("Relative URLs cannot have a path starting with '//'")
# In addition, a URI reference (Section 4.1) may be a relative-path reference,
# in which case the first path segment cannot contain a colon (":") character.
if path.startswith(":"):
raise InvalidURL("Relative URLs cannot have a path starting with ':'")
def normalize_path(path: str) -> str:
"""
Drop "." and ".." segments from a URL path.
For example:
normalize_path("/path/./to/somewhere/..") == "/path/to"
"""
# Fast return when no '.' characters in the path.
if "." not in path:
return path
components = path.split("/")
# Fast return when no '.' or '..' components in the path.
if "." not in components and ".." not in components:
return path
# https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
output: list[str] = []
for component in components:
if component == ".":
pass
elif component == "..":
if output and output != [""]:
output.pop()
else:
output.append(component)
return "/".join(output)
def PERCENT(string: str) -> str:
return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")])
def percent_encoded(string: str, safe: str) -> str:
"""
Use percent-encoding to quote a string.
"""
NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe
# Fast path for strings that don't need escaping.
if not string.rstrip(NON_ESCAPED_CHARS):
return string
return "".join(
[char if char in NON_ESCAPED_CHARS else PERCENT(char) for char in string]
)
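# Illustrative sketch, not part of the original module: PERCENT() expands a
# character into its UTF-8 "%XX" escape(s), and percent_encoded() applies it
# to every character outside the unreserved-plus-safe set.
def _demo_percent_encoded() -> None:  # pragma: no cover
    assert PERCENT(" ") == "%20"
    assert PERCENT("é") == "%C3%A9"  # multi-byte UTF-8 gives two escapes
    assert percent_encoded("a b", safe="") == "a%20b"
    assert percent_encoded("a b", safe=" ") == "a b"  # space marked as safe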
def quote(string: str, safe: str) -> str:
"""
Use percent-encoding to quote a string, omitting existing '%xx' escape sequences.
See: https://www.rfc-editor.org/rfc/rfc3986#section-2.1
* `string`: The string to be percent-escaped.
* `safe`: A string containing characters that may be treated as safe, and do not
need to be escaped. Unreserved characters are always treated as safe.
See: https://www.rfc-editor.org/rfc/rfc3986#section-2.3
"""
parts = []
current_position = 0
for match in re.finditer(PERCENT_ENCODED_REGEX, string):
start_position, end_position = match.start(), match.end()
matched_text = match.group(0)
# Add any text up to the '%xx' escape sequence.
if start_position != current_position:
leading_text = string[current_position:start_position]
parts.append(percent_encoded(leading_text, safe=safe))
# Add the '%xx' escape sequence.
parts.append(matched_text)
current_position = end_position
# Add any text after the final '%xx' escape sequence.
if current_position != len(string):
trailing_text = string[current_position:]
parts.append(percent_encoded(trailing_text, safe=safe))
return "".join(parts)

View File

@ -0,0 +1,641 @@
from __future__ import annotations
import typing
from urllib.parse import parse_qs, unquote, urlencode
import idna
from ._types import QueryParamTypes
from ._urlparse import urlparse
from ._utils import primitive_value_to_str
__all__ = ["URL", "QueryParams"]
class URL:
"""
url = httpx.URL("HTTPS://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink")
assert url.scheme == "https"
assert url.username == "jo@email.com"
assert url.password == "a secret"
assert url.userinfo == b"jo%40email.com:a%20secret"
assert url.host == "müller.de"
assert url.raw_host == b"xn--mller-kva.de"
assert url.port == 1234
assert url.netloc == b"xn--mller-kva.de:1234"
assert url.path == "/pa th"
assert url.query == b"search=ab"
assert url.raw_path == b"/pa%20th?search=ab"
assert url.fragment == "anchorlink"
The components of a URL are broken down like this:
   https://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink
[scheme]   [  username  ] [password] [ host ][port][ path ] [ query ] [fragment]
           [       userinfo        ] [   netloc   ][    raw_path    ]
Note that:
* `url.scheme` is normalized to always be lowercased.
* `url.host` is normalized to always be lowercased. Internationalized domain
names are represented in unicode, without IDNA encoding applied. For instance:
url = httpx.URL("http://中国.icom.museum")
assert url.host == "中国.icom.museum"
url = httpx.URL("http://xn--fiqs8s.icom.museum")
assert url.host == "中国.icom.museum"
* `url.raw_host` is normalized to always be lowercased, and is IDNA encoded.
url = httpx.URL("http://中国.icom.museum")
assert url.raw_host == b"xn--fiqs8s.icom.museum"
url = httpx.URL("http://xn--fiqs8s.icom.museum")
assert url.raw_host == b"xn--fiqs8s.icom.museum"
* `url.port` is either None or an integer. URLs that include the default port for
"http", "https", "ws", "wss", and "ftp" schemes have their port
normalized to `None`.
assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80")
assert httpx.URL("http://example.com").port is None
assert httpx.URL("http://example.com:80").port is None
* `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work
with `url.username` and `url.password` instead, which handle the URL escaping.
* `url.raw_path` is raw bytes of both the path and query, without URL escaping.
This portion is used as the target when constructing HTTP requests. Usually you'll
want to work with `url.path` instead.
* `url.query` is raw bytes, without URL escaping. A URL query string portion can
only be properly URL escaped when decoding the parameter names and values
themselves.
"""
def __init__(self, url: URL | str = "", **kwargs: typing.Any) -> None:
if kwargs:
allowed = {
"scheme": str,
"username": str,
"password": str,
"userinfo": bytes,
"host": str,
"port": int,
"netloc": bytes,
"path": str,
"query": bytes,
"raw_path": bytes,
"fragment": str,
"params": object,
}
# Perform type checking for all supported keyword arguments.
for key, value in kwargs.items():
if key not in allowed:
message = f"{key!r} is an invalid keyword argument for URL()"
raise TypeError(message)
if value is not None and not isinstance(value, allowed[key]):
expected = allowed[key].__name__
seen = type(value).__name__
message = f"Argument {key!r} must be {expected} but got {seen}"
raise TypeError(message)
if isinstance(value, bytes):
kwargs[key] = value.decode("ascii")
if "params" in kwargs:
# Replace any "params" keyword with the raw "query" instead.
#
# Ensure that empty params use `kwargs["query"] = None` rather
# than `kwargs["query"] = ""`, so that generated URLs do not
# include an empty trailing "?".
params = kwargs.pop("params")
kwargs["query"] = None if not params else str(QueryParams(params))
if isinstance(url, str):
self._uri_reference = urlparse(url, **kwargs)
elif isinstance(url, URL):
self._uri_reference = url._uri_reference.copy_with(**kwargs)
else:
raise TypeError(
"Invalid type for url. Expected str or httpx.URL,"
f" got {type(url)}: {url!r}"
)
@property
def scheme(self) -> str:
"""
The URL scheme, such as "http", "https".
Always normalised to lowercase.
"""
return self._uri_reference.scheme
@property
def raw_scheme(self) -> bytes:
"""
The raw bytes representation of the URL scheme, such as b"http", b"https".
Always normalised to lowercase.
"""
return self._uri_reference.scheme.encode("ascii")
@property
def userinfo(self) -> bytes:
"""
The URL userinfo as a raw bytestring.
For example: b"jo%40email.com:a%20secret".
"""
return self._uri_reference.userinfo.encode("ascii")
@property
def username(self) -> str:
"""
The URL username as a string, with URL decoding applied.
For example: "jo@email.com"
"""
userinfo = self._uri_reference.userinfo
return unquote(userinfo.partition(":")[0])
@property
def password(self) -> str:
"""
The URL password as a string, with URL decoding applied.
For example: "a secret"
"""
userinfo = self._uri_reference.userinfo
return unquote(userinfo.partition(":")[2])
@property
def host(self) -> str:
"""
The URL host as a string.
Always normalized to lowercase, with IDNA hosts decoded into unicode.
Examples:
url = httpx.URL("http://www.EXAMPLE.org")
assert url.host == "www.example.org"
url = httpx.URL("http://中国.icom.museum")
assert url.host == "中国.icom.museum"
url = httpx.URL("http://xn--fiqs8s.icom.museum")
assert url.host == "中国.icom.museum"
url = httpx.URL("https://[::ffff:192.168.0.1]")
assert url.host == "::ffff:192.168.0.1"
"""
host: str = self._uri_reference.host
if host.startswith("xn--"):
host = idna.decode(host)
return host
@property
def raw_host(self) -> bytes:
"""
The raw bytes representation of the URL host.
Always normalized to lowercase, and IDNA encoded.
Examples:
url = httpx.URL("http://www.EXAMPLE.org")
assert url.raw_host == b"www.example.org"
url = httpx.URL("http://中国.icom.museum")
assert url.raw_host == b"xn--fiqs8s.icom.museum"
url = httpx.URL("http://xn--fiqs8s.icom.museum")
assert url.raw_host == b"xn--fiqs8s.icom.museum"
url = httpx.URL("https://[::ffff:192.168.0.1]")
assert url.raw_host == b"::ffff:192.168.0.1"
"""
return self._uri_reference.host.encode("ascii")
@property
def port(self) -> int | None:
"""
The URL port as an integer.
Note that the URL class performs port normalization as per the WHATWG spec.
Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always
treated as `None`.
For example:
assert httpx.URL("http://www.example.com") == httpx.URL("http://www.example.com:80")
assert httpx.URL("http://www.example.com:80").port is None
"""
return self._uri_reference.port
@property
def netloc(self) -> bytes:
"""
Either `<host>` or `<host>:<port>` as bytes.
Always normalized to lowercase, and IDNA encoded.
This property may be used for generating the value of a request
"Host" header.
"""
return self._uri_reference.netloc.encode("ascii")
@property
def path(self) -> str:
"""
The URL path as a string, excluding the query string, with URL decoding applied.
For example:
url = httpx.URL("https://example.com/pa%20th")
assert url.path == "/pa th"
"""
path = self._uri_reference.path or "/"
return unquote(path)
@property
def query(self) -> bytes:
"""
The URL query string, as raw bytes, excluding the leading b"?".
This is necessarily a bytewise interface, because we cannot
perform URL decoding of this representation until we've parsed
the keys and values into a QueryParams instance.
For example:
url = httpx.URL("https://example.com/?filter=some%20search%20terms")
assert url.query == b"filter=some%20search%20terms"
"""
query = self._uri_reference.query or ""
return query.encode("ascii")
@property
def params(self) -> QueryParams:
"""
The URL query parameters, neatly parsed and packaged into an immutable
multidict representation.
"""
return QueryParams(self._uri_reference.query)
@property
def raw_path(self) -> bytes:
"""
The complete URL path and query string as raw bytes.
Used as the target when constructing HTTP requests.
For example:
GET /users?search=some%20text HTTP/1.1
Host: www.example.org
Connection: close
"""
path = self._uri_reference.path or "/"
if self._uri_reference.query is not None:
path += "?" + self._uri_reference.query
return path.encode("ascii")
@property
def fragment(self) -> str:
"""
The URL fragment, as used in HTML anchors.
As a string, without the leading '#'.
"""
return unquote(self._uri_reference.fragment or "")
@property
def is_absolute_url(self) -> bool:
"""
Return `True` for absolute URLs such as 'http://example.com/path',
and `False` for relative URLs such as '/path'.
"""
# We don't use `.is_absolute` from `rfc3986` because it treats
# URLs with a fragment portion as not absolute.
# What we actually care about is whether the URL provides
# a scheme and hostname to which connections should be made.
return bool(self._uri_reference.scheme and self._uri_reference.host)
@property
def is_relative_url(self) -> bool:
"""
Return `False` for absolute URLs such as 'http://example.com/path',
and `True` for relative URLs such as '/path'.
"""
return not self.is_absolute_url
def copy_with(self, **kwargs: typing.Any) -> URL:
"""
Copy this URL, returning a new URL with some components altered.
Accepts the same set of parameters as the components that are made
available via properties on the `URL` class.
For example:
url = httpx.URL("https://www.example.com").copy_with(
username="jo@gmail.com", password="a secret"
)
assert url == "https://jo%40email.com:a%20secret@www.example.com"
"""
return URL(self, **kwargs)
def copy_set_param(self, key: str, value: typing.Any = None) -> URL:
return self.copy_with(params=self.params.set(key, value))
def copy_add_param(self, key: str, value: typing.Any = None) -> URL:
return self.copy_with(params=self.params.add(key, value))
def copy_remove_param(self, key: str) -> URL:
return self.copy_with(params=self.params.remove(key))
def copy_merge_params(self, params: QueryParamTypes) -> URL:
return self.copy_with(params=self.params.merge(params))
def join(self, url: URL | str) -> URL:
"""
Return an absolute URL, using this URL as the base.
For example:
url = httpx.URL("https://www.example.com/test")
url = url.join("/new/path")
assert url == "https://www.example.com/new/path"
"""
from urllib.parse import urljoin
return URL(urljoin(str(self), str(URL(url))))
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, (URL, str)) and str(self) == str(URL(other))
def __str__(self) -> str:
return str(self._uri_reference)
def __repr__(self) -> str:
scheme, userinfo, host, port, path, query, fragment = self._uri_reference
if ":" in userinfo:
# Mask any password component.
userinfo = f'{userinfo.split(":")[0]}:[secure]'
authority = "".join(
[
f"{userinfo}@" if userinfo else "",
f"[{host}]" if ":" in host else host,
f":{port}" if port is not None else "",
]
)
url = "".join(
[
f"{self.scheme}:" if scheme else "",
f"//{authority}" if authority else "",
path,
f"?{query}" if query is not None else "",
f"#{fragment}" if fragment is not None else "",
]
)
return f"{self.__class__.__name__}({url!r})"
@property
def raw(self) -> tuple[bytes, bytes, int, bytes]: # pragma: nocover
import collections
import warnings
warnings.warn("URL.raw is deprecated.")
RawURL = collections.namedtuple(
"RawURL", ["raw_scheme", "raw_host", "port", "raw_path"]
)
return RawURL(
raw_scheme=self.raw_scheme,
raw_host=self.raw_host,
port=self.port,
raw_path=self.raw_path,
)
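# Illustrative sketch, not part of the original module: a URL decomposed into
# its properties, then rebuilt via copy_with() and the copy_*_param helpers.
def _demo_url() -> None:  # pragma: no cover
    url = URL("https://www.example.org/path?a=1")
    assert url.scheme == "https"
    assert url.host == "www.example.org"
    assert url.raw_path == b"/path?a=1"
    assert url.copy_with(scheme="http") == "http://www.example.org/path?a=1"
    assert str(url.copy_set_param("a", "2")) == "https://www.example.org/path?a=2"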
class QueryParams(typing.Mapping[str, str]):
"""
URL query parameters, as a multi-dict.
"""
def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
assert len(args) < 2, "Too many arguments."
assert not (args and kwargs), "Cannot mix named and unnamed arguments."
value = args[0] if args else kwargs
if value is None or isinstance(value, (str, bytes)):
value = value.decode("ascii") if isinstance(value, bytes) else value
self._dict = parse_qs(value, keep_blank_values=True)
elif isinstance(value, QueryParams):
self._dict = {k: list(v) for k, v in value._dict.items()}
else:
dict_value: dict[typing.Any, list[typing.Any]] = {}
if isinstance(value, (list, tuple)):
# Convert list inputs like:
# [("a", "123"), ("a", "456"), ("b", "789")]
# To a dict representation, like:
# {"a": ["123", "456"], "b": ["789"]}
for item in value:
dict_value.setdefault(item[0], []).append(item[1])
else:
# Convert dict inputs like:
# {"a": "123", "b": ["456", "789"]}
# To dict inputs where values are always lists, like:
# {"a": ["123"], "b": ["456", "789"]}
dict_value = {
k: list(v) if isinstance(v, (list, tuple)) else [v]
for k, v in value.items()
}
# Ensure that keys and values are neatly coerced to strings.
# We coerce values `True` and `False` to JSON-like "true" and "false"
# representations, and coerce `None` values to the empty string.
self._dict = {
str(k): [primitive_value_to_str(item) for item in v]
for k, v in dict_value.items()
}
def keys(self) -> typing.KeysView[str]:
"""
Return all the keys in the query params.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.keys()) == ["a", "b"]
"""
return self._dict.keys()
def values(self) -> typing.ValuesView[str]:
"""
Return all the values in the query params. If a key occurs more than once,
only the first item for that key is returned.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.values()) == ["123", "789"]
"""
return {k: v[0] for k, v in self._dict.items()}.values()
def items(self) -> typing.ItemsView[str, str]:
"""
Return all items in the query params. If a key occurs more than once,
only the first item for that key is returned.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.items()) == [("a", "123"), ("b", "789")]
"""
return {k: v[0] for k, v in self._dict.items()}.items()
def multi_items(self) -> list[tuple[str, str]]:
"""
Return all items in the query params. Allow duplicate keys to occur.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
"""
multi_items: list[tuple[str, str]] = []
for k, v in self._dict.items():
multi_items.extend([(k, i) for i in v])
return multi_items
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
"""
Get a value from the query param for a given key. If the key occurs
more than once, then only the first value is returned.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert q.get("a") == "123"
"""
if key in self._dict:
return self._dict[str(key)][0]
return default
def get_list(self, key: str) -> list[str]:
"""
Get all values from the query param for a given key.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert q.get_list("a") == ["123", "456"]
"""
return list(self._dict.get(str(key), []))
def set(self, key: str, value: typing.Any = None) -> QueryParams:
"""
Return a new QueryParams instance, setting the value of a key.
Usage:
q = httpx.QueryParams("a=123")
q = q.set("a", "456")
assert q == httpx.QueryParams("a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = [primitive_value_to_str(value)]
return q
def add(self, key: str, value: typing.Any = None) -> QueryParams:
"""
Return a new QueryParams instance, setting or appending the value of a key.
Usage:
q = httpx.QueryParams("a=123")
q = q.add("a", "456")
assert q == httpx.QueryParams("a=123&a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
return q
def remove(self, key: str) -> QueryParams:
"""
Return a new QueryParams instance, removing the value of a key.
Usage:
q = httpx.QueryParams("a=123")
q = q.remove("a")
assert q == httpx.QueryParams("")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict.pop(str(key), None)
return q
def merge(self, params: QueryParamTypes | None = None) -> QueryParams:
"""
Return a new QueryParams instance, updated with the given params.
Usage:
q = httpx.QueryParams("a=123")
q = q.merge({"b": "456"})
assert q == httpx.QueryParams("a=123&b=456")
q = httpx.QueryParams("a=123")
q = q.merge({"a": "456", "b": "789"})
assert q == httpx.QueryParams("a=456&b=789")
"""
q = QueryParams(params)
q._dict = {**self._dict, **q._dict}
return q
def __getitem__(self, key: typing.Any) -> str:
return self._dict[key][0]
def __contains__(self, key: typing.Any) -> bool:
return key in self._dict
def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._dict)
def __bool__(self) -> bool:
return bool(self._dict)
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, self.__class__):
return False
return sorted(self.multi_items()) == sorted(other.multi_items())
def __str__(self) -> str:
return urlencode(self.multi_items())
def __repr__(self) -> str:
class_name = self.__class__.__name__
query_string = str(self)
return f"{class_name}({query_string!r})"
def update(self, params: QueryParamTypes | None = None) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.merge(...)` to create an updated copy."
)
def __setitem__(self, key: str, value: str) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.set(key, value)` to create an updated copy."
)
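# Illustrative sketch, not part of the original module: QueryParams is an
# immutable multi-dict, so every mutation returns a new instance.
def _demo_query_params() -> None:  # pragma: no cover
    q = QueryParams("a=123&a=456&b=789")
    assert q.get_list("a") == ["123", "456"]
    q2 = q.set("b", "000")
    assert q2 == QueryParams("a=123&a=456&b=000")
    assert q == QueryParams("a=123&a=456&b=789")  # the original is unchanged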

View File

@ -0,0 +1,242 @@
from __future__ import annotations
import ipaddress
import os
import re
import typing
from urllib.request import getproxies
from ._types import PrimitiveData
if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL
def primitive_value_to_str(value: PrimitiveData) -> str:
"""
Coerce a primitive data type into a string value.
Note that we prefer JSON-style 'true'/'false' for boolean values here.
"""
if value is True:
return "true"
elif value is False:
return "false"
elif value is None:
return ""
return str(value)
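# Illustrative sketch, not part of the original module: the JSON-style
# coercions applied by primitive_value_to_str() above.
def _demo_primitive_value_to_str() -> None:  # pragma: no cover
    assert primitive_value_to_str(True) == "true"
    assert primitive_value_to_str(False) == "false"
    assert primitive_value_to_str(None) == ""
    assert primitive_value_to_str(3.14) == "3.14"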
def get_environment_proxies() -> dict[str, str | None]:
"""Gets proxy information from the environment"""
# urllib.request.getproxies() falls back on System
# Registry and Config for proxies on Windows and macOS.
# We don't want to propagate non-HTTP proxies into
# our configuration such as 'TRAVIS_APT_PROXY'.
proxy_info = getproxies()
mounts: dict[str, str | None] = {}
for scheme in ("http", "https", "all"):
if proxy_info.get(scheme):
hostname = proxy_info[scheme]
mounts[f"{scheme}://"] = (
hostname if "://" in hostname else f"http://{hostname}"
)
no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")]
for hostname in no_proxy_hosts:
# See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details
# on how names in `NO_PROXY` are handled.
if hostname == "*":
# If NO_PROXY=* is used or if "*" occurs as any one of the comma
# separated hostnames, then we should just bypass any information
# from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore
# proxies.
return {}
elif hostname:
# NO_PROXY=.google.com is marked as "all://*.google.com",
# which disables "www.google.com" but not "google.com"
# NO_PROXY=google.com is marked as "all://*google.com",
# which disables "www.google.com" and "google.com".
# (But not "wwwgoogle.com")
# NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost"
# NO_PROXY=example.com,::1,localhost,192.168.0.0/16
if "://" in hostname:
mounts[hostname] = None
elif is_ipv4_hostname(hostname):
mounts[f"all://{hostname}"] = None
elif is_ipv6_hostname(hostname):
mounts[f"all://[{hostname}]"] = None
elif hostname.lower() == "localhost":
mounts[f"all://{hostname}"] = None
else:
mounts[f"all://*{hostname}"] = None
return mounts
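# Illustrative sketch, not part of the original module: the mounts mapping
# produced from typical proxy environment variables. Assumes no other *_PROXY
# variables are set; mutating os.environ here is for demonstration only.
def _demo_get_environment_proxies() -> None:  # pragma: no cover
    os.environ["HTTP_PROXY"] = "http://proxy.internal:3128"
    os.environ["NO_PROXY"] = "localhost,.example.com"
    mounts = get_environment_proxies()
    assert mounts["http://"] == "http://proxy.internal:3128"
    assert mounts["all://localhost"] is None  # bypassed entirely
    assert mounts["all://*.example.com"] is None  # subdomains bypassed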
def to_bytes(value: str | bytes, encoding: str = "utf-8") -> bytes:
return value.encode(encoding) if isinstance(value, str) else value
def to_str(value: str | bytes, encoding: str = "utf-8") -> str:
return value if isinstance(value, str) else value.decode(encoding)
def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
return value if isinstance(match_type_of, str) else value.encode()
def unquote(value: str) -> str:
return value[1:-1] if value[0] == value[-1] == '"' else value
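# Illustrative sketch, not part of the original module: the small coercion
# helpers above.
def _demo_coercions() -> None:  # pragma: no cover
    assert to_bytes("abc") == b"abc"
    assert to_str(b"abc") == "abc"
    assert to_bytes_or_str("x", match_type_of=b"") == b"x"
    assert unquote('"quoted"') == "quoted"  # surrounding '"' pair removed
    assert unquote("plain") == "plain"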
def peek_filelike_length(stream: typing.Any) -> int | None:
"""
Given a file-like stream object, return its length in number of bytes
without reading it into memory.
"""
try:
# Is it an actual file?
fd = stream.fileno()
# Yup, seems to be an actual file.
length = os.fstat(fd).st_size
except (AttributeError, OSError):
# No... Maybe it's something that supports random access, like `io.BytesIO`?
try:
# Assuming so, go to end of stream to figure out its length,
# then put it back in place.
offset = stream.tell()
length = stream.seek(0, os.SEEK_END)
stream.seek(offset)
except (AttributeError, OSError):
# Not even that? Sorry, we're doomed...
return None
return length
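# Illustrative sketch, not part of the original module: length probing for a
# seekable in-memory stream, with the read position left untouched.
def _demo_peek_filelike_length() -> None:  # pragma: no cover
    import io

    stream = io.BytesIO(b"hello")
    stream.read(2)  # move the position away from the start
    assert peek_filelike_length(stream) == 5
    assert stream.tell() == 2  # position was restored
    assert peek_filelike_length(object()) is None  # not file-like at all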
class URLPattern:
"""
A utility class currently used for making lookups against proxy keys...
# Wildcard matching...
>>> pattern = URLPattern("all://")
>>> pattern.matches(httpx.URL("http://example.com"))
True
# With scheme matching...
>>> pattern = URLPattern("https://")
>>> pattern.matches(httpx.URL("https://example.com"))
True
>>> pattern.matches(httpx.URL("http://example.com"))
False
# With domain matching...
>>> pattern = URLPattern("https://example.com")
>>> pattern.matches(httpx.URL("https://example.com"))
True
>>> pattern.matches(httpx.URL("http://example.com"))
False
>>> pattern.matches(httpx.URL("https://other.com"))
False
# Wildcard scheme, with domain matching...
>>> pattern = URLPattern("all://example.com")
>>> pattern.matches(httpx.URL("https://example.com"))
True
>>> pattern.matches(httpx.URL("http://example.com"))
True
>>> pattern.matches(httpx.URL("https://other.com"))
False
# With port matching...
>>> pattern = URLPattern("https://example.com:1234")
>>> pattern.matches(httpx.URL("https://example.com:1234"))
True
>>> pattern.matches(httpx.URL("https://example.com"))
False
"""
def __init__(self, pattern: str) -> None:
from ._urls import URL
if pattern and ":" not in pattern:
raise ValueError(
f"Proxy keys should use proper URL forms rather "
f"than plain scheme strings. "
f'Instead of "{pattern}", use "{pattern}://"'
)
url = URL(pattern)
self.pattern = pattern
self.scheme = "" if url.scheme == "all" else url.scheme
self.host = "" if url.host == "*" else url.host
self.port = url.port
if not url.host or url.host == "*":
self.host_regex: typing.Pattern[str] | None = None
elif url.host.startswith("*."):
# *.example.com should match "www.example.com", but not "example.com"
domain = re.escape(url.host[2:])
self.host_regex = re.compile(f"^.+\\.{domain}$")
elif url.host.startswith("*"):
# *example.com should match "www.example.com" and "example.com"
domain = re.escape(url.host[1:])
self.host_regex = re.compile(f"^(.+\\.)?{domain}$")
else:
# example.com should match "example.com" but not "www.example.com"
domain = re.escape(url.host)
self.host_regex = re.compile(f"^{domain}$")
def matches(self, other: URL) -> bool:
if self.scheme and self.scheme != other.scheme:
return False
if (
self.host
and self.host_regex is not None
and not self.host_regex.match(other.host)
):
return False
if self.port is not None and self.port != other.port:
return False
return True
@property
def priority(self) -> tuple[int, int, int]:
"""
The priority allows URLPattern instances to be sortable, so that
we can match from most specific to least specific.
"""
# URLs with a port should take priority over URLs without a port.
port_priority = 0 if self.port is not None else 1
# Longer hostnames should match first.
host_priority = -len(self.host)
# Longer schemes should match first.
scheme_priority = -len(self.scheme)
return (port_priority, host_priority, scheme_priority)
def __hash__(self) -> int:
return hash(self.pattern)
def __lt__(self, other: URLPattern) -> bool:
return self.priority < other.priority
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, URLPattern) and self.pattern == other.pattern
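# Illustrative sketch, not part of the original module: sorting patterns by
# priority puts the most specific proxy key first.
def _demo_url_pattern() -> None:  # pragma: no cover
    patterns = sorted(
        [
            URLPattern("all://"),
            URLPattern("all://*.example.com"),
            URLPattern("https://example.com:1234"),
        ]
    )
    assert [p.pattern for p in patterns] == [
        "https://example.com:1234",  # an explicit port sorts first
        "all://*.example.com",  # then the longer hostname
        "all://",  # the catch-all wildcard sorts last
    ]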
def is_ipv4_hostname(hostname: str) -> bool:
try:
ipaddress.IPv4Address(hostname.split("/")[0])
except Exception:
return False
return True
def is_ipv6_hostname(hostname: str) -> bool:
try:
ipaddress.IPv6Address(hostname.split("/")[0])
except Exception:
return False
return True
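# Illustrative sketch, not part of the original module: the loose IP checks
# above also accept CIDR-style values, since anything after "/" is dropped.
def _demo_ip_hostnames() -> None:  # pragma: no cover
    assert is_ipv4_hostname("192.168.0.1")
    assert is_ipv4_hostname("192.168.0.0/16")  # the network suffix is ignored
    assert is_ipv6_hostname("::1")
    assert not is_ipv4_hostname("example.com")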

View File