second commit
This commit is contained in:
41
env/lib/python3.11/site-packages/starlette/middleware/__init__.py
vendored
Normal file
41
env/lib/python3.11/site-packages/starlette/middleware/__init__.py
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any, Iterator, Protocol
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class _MiddlewareFactory(Protocol[P]):
|
||||
def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(
|
||||
self,
|
||||
cls: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
self.cls = cls
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
as_tuple = (self.cls, self.args, self.kwargs)
|
||||
return iter(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
args_strings = [f"{value!r}" for value in self.args]
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
|
||||
name = getattr(self.cls, "__name__", "")
|
||||
args_repr = ", ".join([name] + args_strings + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/authentication.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/authentication.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/base.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/base.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/cors.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/cors.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/errors.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/errors.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/exceptions.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/exceptions.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/gzip.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/gzip.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/httpsredirect.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/httpsredirect.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/sessions.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/sessions.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/trustedhost.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/trustedhost.cpython-311.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/wsgi.cpython-311.pyc
vendored
Normal file
BIN
env/lib/python3.11/site-packages/starlette/middleware/__pycache__/wsgi.cpython-311.pyc
vendored
Normal file
Binary file not shown.
52
env/lib/python3.11/site-packages/starlette/middleware/authentication.py
vendored
Normal file
52
env/lib/python3.11/site-packages/starlette/middleware/authentication.py
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
AuthenticationBackend,
|
||||
AuthenticationError,
|
||||
UnauthenticatedUser,
|
||||
)
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
conn = HTTPConnection(scope)
|
||||
try:
|
||||
auth_result = await self.backend.authenticate(conn)
|
||||
except AuthenticationError as exc:
|
||||
response = self.on_error(conn, exc)
|
||||
if scope["type"] == "websocket":
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
else:
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
if auth_result is None:
|
||||
auth_result = AuthCredentials(), UnauthenticatedUser()
|
||||
scope["auth"], scope["user"] = auth_result
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
@staticmethod
|
||||
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
|
||||
return PlainTextResponse(str(exc), status_code=400)
|
228
env/lib/python3.11/site-packages/starlette/middleware/base.py
vendored
Normal file
228
env/lib/python3.11/site-packages/starlette/middleware/base.py
vendored
Normal file
@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.requests import ClientDisconnect, Request
|
||||
from starlette.responses import AsyncContentStream, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
class _CachedRequest(Request):
|
||||
"""
|
||||
If the user calls Request.body() from their dispatch function
|
||||
we cache the entire request body in memory and pass that to downstream middlewares,
|
||||
but if they call Request.stream() then all we do is send an
|
||||
empty body so that downstream things don't hang forever.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive):
|
||||
super().__init__(scope, receive)
|
||||
self._wrapped_rcv_disconnected = False
|
||||
self._wrapped_rcv_consumed = False
|
||||
self._wrapped_rc_stream = self.stream()
|
||||
|
||||
async def wrapped_receive(self) -> Message:
|
||||
# wrapped_rcv state 1: disconnected
|
||||
if self._wrapped_rcv_disconnected:
|
||||
# we've already sent a disconnect to the downstream app
|
||||
# we don't need to wait to get another one
|
||||
# (although most ASGI servers will just keep sending it)
|
||||
return {"type": "http.disconnect"}
|
||||
# wrapped_rcv state 1: consumed but not yet disconnected
|
||||
if self._wrapped_rcv_consumed:
|
||||
# since the downstream app has consumed us all that is left
|
||||
# is to send it a disconnect
|
||||
if self._is_disconnected:
|
||||
# the middleware has already seen the disconnect
|
||||
# since we know the client is disconnected no need to wait
|
||||
# for the message
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
# we don't know yet if the client is disconnected or not
|
||||
# so we'll wait until we get that message
|
||||
msg = await self.receive()
|
||||
if msg["type"] != "http.disconnect": # pragma: no cover
|
||||
# at this point a disconnect is all that we should be receiving
|
||||
# if we get something else, things went wrong somewhere
|
||||
raise RuntimeError(f"Unexpected message received: {msg['type']}")
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return msg
|
||||
|
||||
# wrapped_rcv state 3: not yet consumed
|
||||
if getattr(self, "_body", None) is not None:
|
||||
# body() was called, we return it even if the client disconnected
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": self._body,
|
||||
"more_body": False,
|
||||
}
|
||||
elif self._stream_consumed:
|
||||
# stream() was called to completion
|
||||
# return an empty body so that downstream apps don't hang
|
||||
# waiting for a disconnect
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
else:
|
||||
# body() was never called and stream() wasn't consumed
|
||||
try:
|
||||
stream = self.stream()
|
||||
chunk = await stream.__anext__()
|
||||
self._wrapped_rcv_consumed = self._stream_consumed
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": not self._stream_consumed,
|
||||
}
|
||||
except ClientDisconnect:
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = _CachedRequest(scope, receive)
|
||||
wrapped_receive = request.wrapped_receive
|
||||
response_sent = anyio.Event()
|
||||
|
||||
async def call_next(request: Request) -> Response:
|
||||
app_exc: Exception | None = None
|
||||
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
|
||||
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
|
||||
send_stream, recv_stream = anyio.create_memory_object_stream()
|
||||
|
||||
async def receive_or_disconnect() -> Message:
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
|
||||
result = await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
return result
|
||||
|
||||
task_group.start_soon(wrap, response_sent.wait)
|
||||
message = await wrap(wrapped_receive)
|
||||
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
return message
|
||||
|
||||
async def close_recv_stream_on_response_sent() -> None:
|
||||
await response_sent.wait()
|
||||
recv_stream.close()
|
||||
|
||||
async def send_no_error(message: Message) -> None:
|
||||
try:
|
||||
await send_stream.send(message)
|
||||
except anyio.BrokenResourceError:
|
||||
# recv_stream has been closed, i.e. response_sent has been set.
|
||||
return
|
||||
|
||||
async def coro() -> None:
|
||||
nonlocal app_exc
|
||||
|
||||
async with send_stream:
|
||||
try:
|
||||
await self.app(scope, receive_or_disconnect, send_no_error)
|
||||
except Exception as exc:
|
||||
app_exc = exc
|
||||
|
||||
task_group.start_soon(close_recv_stream_on_response_sent)
|
||||
task_group.start_soon(coro)
|
||||
|
||||
try:
|
||||
message = await recv_stream.receive()
|
||||
info = message.get("info", None)
|
||||
if message["type"] == "http.response.debug" and info is not None:
|
||||
message = await recv_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
raise RuntimeError("No response returned.")
|
||||
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
async with recv_stream:
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
|
||||
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
with collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, wrapped_receive, send)
|
||||
response_sent.set()
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class _StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: AsyncContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
info: typing.Mapping[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
self.info = info
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
self.background = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.info is not None:
|
||||
await send({"type": "http.response.debug", "info": self.info})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
|
||||
async for chunk in self.body_iterator:
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
if self.background:
|
||||
await self.background()
|
172
env/lib/python3.11/site-packages/starlette/middleware/cors.py
vendored
Normal file
172
env/lib/python3.11/site-packages/starlette/middleware/cors.py
vendored
Normal file
@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
|
||||
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
|
||||
|
||||
|
||||
class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: str | None = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
if "*" in allow_methods:
|
||||
allow_methods = ALL_METHODS
|
||||
|
||||
compiled_allow_origin_regex = None
|
||||
if allow_origin_regex is not None:
|
||||
compiled_allow_origin_regex = re.compile(allow_origin_regex)
|
||||
|
||||
allow_all_origins = "*" in allow_origins
|
||||
allow_all_headers = "*" in allow_headers
|
||||
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
|
||||
|
||||
simple_headers = {}
|
||||
if allow_all_origins:
|
||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||
if allow_credentials:
|
||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
if expose_headers:
|
||||
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
|
||||
|
||||
preflight_headers = {}
|
||||
if preflight_explicit_allow_origin:
|
||||
# The origin value will be set in preflight_response() if it is allowed.
|
||||
preflight_headers["Vary"] = "Origin"
|
||||
else:
|
||||
preflight_headers["Access-Control-Allow-Origin"] = "*"
|
||||
preflight_headers.update(
|
||||
{
|
||||
"Access-Control-Allow-Methods": ", ".join(allow_methods),
|
||||
"Access-Control-Max-Age": str(max_age),
|
||||
}
|
||||
)
|
||||
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
|
||||
if allow_headers and not allow_all_headers:
|
||||
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
|
||||
if allow_credentials:
|
||||
preflight_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
self.app = app
|
||||
self.allow_origins = allow_origins
|
||||
self.allow_methods = allow_methods
|
||||
self.allow_headers = [h.lower() for h in allow_headers]
|
||||
self.allow_all_origins = allow_all_origins
|
||||
self.allow_all_headers = allow_all_headers
|
||||
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
|
||||
self.allow_origin_regex = compiled_allow_origin_regex
|
||||
self.simple_headers = simple_headers
|
||||
self.preflight_headers = preflight_headers
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
method = scope["method"]
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin")
|
||||
|
||||
if origin is None:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
if method == "OPTIONS" and "access-control-request-method" in headers:
|
||||
response = self.preflight_response(request_headers=headers)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.simple_response(scope, receive, send, request_headers=headers)
|
||||
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
|
||||
def preflight_response(self, request_headers: Headers) -> Response:
|
||||
requested_origin = request_headers["origin"]
|
||||
requested_method = request_headers["access-control-request-method"]
|
||||
requested_headers = request_headers.get("access-control-request-headers")
|
||||
|
||||
headers = dict(self.preflight_headers)
|
||||
failures = []
|
||||
|
||||
if self.is_allowed_origin(origin=requested_origin):
|
||||
if self.preflight_explicit_allow_origin:
|
||||
# The "else" case is already accounted for in self.preflight_headers
|
||||
# and the value would be "*".
|
||||
headers["Access-Control-Allow-Origin"] = requested_origin
|
||||
else:
|
||||
failures.append("origin")
|
||||
|
||||
if requested_method not in self.allow_methods:
|
||||
failures.append("method")
|
||||
|
||||
# If we allow all headers, then we have to mirror back any requested
|
||||
# headers in the response.
|
||||
if self.allow_all_headers and requested_headers is not None:
|
||||
headers["Access-Control-Allow-Headers"] = requested_headers
|
||||
elif requested_headers is not None:
|
||||
for header in [h.lower() for h in requested_headers.split(",")]:
|
||||
if header.strip() not in self.allow_headers:
|
||||
failures.append("headers")
|
||||
break
|
||||
|
||||
# We don't strictly need to use 400 responses here, since its up to
|
||||
# the browser to enforce the CORS policy, but its more informative
|
||||
# if we do.
|
||||
if failures:
|
||||
failure_text = "Disallowed CORS " + ", ".join(failures)
|
||||
return PlainTextResponse(failure_text, status_code=400, headers=headers)
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
headers.update(self.simple_headers)
|
||||
origin = request_headers["Origin"]
|
||||
has_cookie = "cookie" in request_headers
|
||||
|
||||
# If request includes any cookie headers, then we must respond
|
||||
# with the specific origin instead of '*'.
|
||||
if self.allow_all_origins and has_cookie:
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
|
||||
# If we only allow specific origins, then we have to mirror back
|
||||
# the Origin header in the response.
|
||||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
|
||||
await send(message)
|
||||
|
||||
@staticmethod
|
||||
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers.add_vary_header("Origin")
|
260
env/lib/python3.11/site-packages/starlette/middleware/errors.py
vendored
Normal file
260
env/lib/python3.11/site-packages/starlette/middleware/errors.py
vendored
Normal file
@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
STYLES = """
|
||||
p {
|
||||
color: #211c1c;
|
||||
}
|
||||
.traceback-container {
|
||||
border: 1px solid #038BB8;
|
||||
}
|
||||
.traceback-title {
|
||||
background-color: #038BB8;
|
||||
color: lemonchiffon;
|
||||
padding: 12px;
|
||||
font-size: 20px;
|
||||
margin-top: 0px;
|
||||
}
|
||||
.frame-line {
|
||||
padding-left: 10px;
|
||||
font-family: monospace;
|
||||
}
|
||||
.frame-filename {
|
||||
font-family: monospace;
|
||||
}
|
||||
.center-line {
|
||||
background-color: #038BB8;
|
||||
color: #f9f6e1;
|
||||
padding: 5px 0px 5px 5px;
|
||||
}
|
||||
.lineno {
|
||||
margin-right: 5px;
|
||||
}
|
||||
.frame-title {
|
||||
font-weight: unset;
|
||||
padding: 10px 10px 10px 10px;
|
||||
background-color: #E4F4FD;
|
||||
margin-right: 10px;
|
||||
color: #191f21;
|
||||
font-size: 17px;
|
||||
border: 1px solid #c7dce8;
|
||||
}
|
||||
.collapse-btn {
|
||||
float: right;
|
||||
padding: 0px 5px 1px 5px;
|
||||
border: solid 1px #96aebb;
|
||||
cursor: pointer;
|
||||
}
|
||||
.collapsed {
|
||||
display: none;
|
||||
}
|
||||
.source-code {
|
||||
font-family: courier;
|
||||
font-size: small;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
"""
|
||||
|
||||
JS = """
|
||||
<script type="text/javascript">
|
||||
function collapse(element){
|
||||
const frameId = element.getAttribute("data-frame-id");
|
||||
const frame = document.getElementById(frameId);
|
||||
|
||||
if (frame.classList.contains("collapsed")){
|
||||
element.innerHTML = "‒";
|
||||
frame.classList.remove("collapsed");
|
||||
} else {
|
||||
element.innerHTML = "+";
|
||||
frame.classList.add("collapsed");
|
||||
}
|
||||
}
|
||||
</script>
|
||||
"""
|
||||
|
||||
TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<style type='text/css'>
|
||||
{styles}
|
||||
</style>
|
||||
<title>Starlette Debugger</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>500 Server Error</h1>
|
||||
<h2>{error}</h2>
|
||||
<div class="traceback-container">
|
||||
<p class="traceback-title">Traceback</p>
|
||||
<div>{exc_html}</div>
|
||||
</div>
|
||||
{js}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
FRAME_TEMPLATE = """
|
||||
<div>
|
||||
<p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
|
||||
line <i>{frame_lineno}</i>,
|
||||
in <b>{frame_name}</b>
|
||||
<span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
|
||||
</p>
|
||||
<div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
|
||||
</div>
|
||||
""" # noqa: E501
|
||||
|
||||
LINE = """
|
||||
<p><span class="frame-line">
|
||||
<span class="lineno">{lineno}.</span> {line}</span></p>
|
||||
"""
|
||||
|
||||
CENTER_LINE = """
|
||||
<p class="center-line"><span class="frame-line center-line">
|
||||
<span class="lineno">{lineno}.</span> {line}</span></p>
|
||||
"""
|
||||
|
||||
|
||||
class ServerErrorMiddleware:
|
||||
"""
|
||||
Handles returning 500 responses when a server error occurs.
|
||||
|
||||
If 'debug' is set, then traceback responses will be returned,
|
||||
otherwise the designated 'handler' will be called.
|
||||
|
||||
This middleware class should generally be used to wrap *everything*
|
||||
else up, so that unhandled exceptions anywhere in the stack
|
||||
always result in an appropriate 500 response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.handler = handler
|
||||
self.debug = debug
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def _send(message: Message) -> None:
|
||||
nonlocal response_started, send
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, _send)
|
||||
except Exception as exc:
|
||||
request = Request(scope)
|
||||
if self.debug:
|
||||
# In debug mode, return traceback responses.
|
||||
response = self.debug_response(request, exc)
|
||||
elif self.handler is None:
|
||||
# Use our default 500 error handler.
|
||||
response = self.error_response(request, exc)
|
||||
else:
|
||||
# Use an installed 500 error handler.
|
||||
if is_async_callable(self.handler):
|
||||
response = await self.handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(self.handler, request, exc)
|
||||
|
||||
if not response_started:
|
||||
await response(scope, receive, send)
|
||||
|
||||
# We always continue to raise the exception.
|
||||
# This allows servers to log the error, or allows test clients
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc
|
||||
|
||||
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
"lineno": (frame_lineno - frame_index) + index,
|
||||
}
|
||||
|
||||
if index != frame_index:
|
||||
return LINE.format(**values)
|
||||
return CENTER_LINE.format(**values)
|
||||
|
||||
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
|
||||
code_context = "".join(
|
||||
self.format_line(
|
||||
index,
|
||||
line,
|
||||
frame.lineno,
|
||||
frame.index, # type: ignore[arg-type]
|
||||
)
|
||||
for index, line in enumerate(frame.code_context or [])
|
||||
)
|
||||
|
||||
values = {
|
||||
# HTML escape - filename could contain < or >, especially if it's a virtual
|
||||
# file e.g. <stdin> in the REPL
|
||||
"frame_filename": html.escape(frame.filename),
|
||||
"frame_lineno": frame.lineno,
|
||||
# HTML escape - if you try very hard it's possible to name a function with <
|
||||
# or >
|
||||
"frame_name": html.escape(frame.function),
|
||||
"code_context": code_context,
|
||||
"collapsed": "collapsed" if is_collapsed else "",
|
||||
"collapse_button": "+" if is_collapsed else "‒",
|
||||
}
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
exc_traceback = exc.__traceback__
|
||||
if exc_traceback is not None:
|
||||
frames = inspect.getinnerframes(exc_traceback, limit)
|
||||
for frame in reversed(frames):
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type_str
|
||||
else: # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type.__name__
|
||||
|
||||
# escape error class and text
|
||||
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
def generate_plain_text(self, exc: Exception) -> str:
|
||||
return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||||
|
||||
def debug_response(self, request: Request, exc: Exception) -> Response:
|
||||
accept = request.headers.get("accept", "")
|
||||
|
||||
if "text/html" in accept:
|
||||
content = self.generate_html(exc)
|
||||
return HTMLResponse(content, status_code=500)
|
||||
content = self.generate_plain_text(exc)
|
||||
return PlainTextResponse(content, status_code=500)
|
||||
|
||||
def error_response(self, request: Request, exc: Exception) -> Response:
|
||||
return PlainTextResponse("Internal Server Error", status_code=500)
|
72
env/lib/python3.11/site-packages/starlette/middleware/exceptions.py
vendored
Normal file
72
env/lib/python3.11/site-packages/starlette/middleware/exceptions.py
vendored
Normal file
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette._exception_handler import (
|
||||
ExceptionHandlers,
|
||||
StatusHandlers,
|
||||
wrap_app_handling_exceptions,
|
||||
)
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers: StatusHandlers = {}
|
||||
self._exception_handlers: ExceptionHandlers = {
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None:
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
else:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
scope["starlette.exception_handlers"] = (
|
||||
self._exception_handlers,
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
conn = WebSocket(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
|
||||
|
||||
def http_exception(self, request: Request, exc: Exception) -> Response:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
|
108
env/lib/python3.11/site-packages/starlette/middleware/gzip.py
vendored
Normal file
108
env/lib/python3.11/site-packages/starlette/middleware/gzip.py
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
import gzip
|
||||
import io
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.compresslevel = compresslevel
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "http":
|
||||
headers = Headers(scope=scope)
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
|
||||
await responder(scope, receive, send)
|
||||
return
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class GZipResponder:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.send: Send = unattached_send
|
||||
self.initial_message: Message = {}
|
||||
self.started = False
|
||||
self.content_encoding_set = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
with self.gzip_buffer, self.gzip_file:
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
# modify the outgoing headers correctly.
|
||||
self.initial_message = message
|
||||
headers = Headers(raw=self.initial_message["headers"])
|
||||
self.content_encoding_set = "content-encoding" in headers
|
||||
elif message_type == "http.response.body" and self.content_encoding_set:
|
||||
if not self.started:
|
||||
self.started = True
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif message_type == "http.response.body" and not self.started:
|
||||
self.started = True
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply GZip to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard GZip response.
|
||||
self.gzip_file.write(body)
|
||||
self.gzip_file.close()
|
||||
body = self.gzip_buffer.getvalue()
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers["Content-Length"] = str(len(body))
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming GZip response.
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
del headers["Content-Length"]
|
||||
|
||||
self.gzip_file.write(body)
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
# Remaining body in streaming GZip response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
await self.send(message)
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> typing.NoReturn:
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
19
env/lib/python3.11/site-packages/starlette/middleware/httpsredirect.py
vendored
Normal file
19
env/lib/python3.11/site-packages/starlette/middleware/httpsredirect.py
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
from starlette.datastructures import URL
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class HTTPSRedirectMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
|
||||
url = URL(scope=scope)
|
||||
redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme]
|
||||
netloc = url.hostname if url.port in (80, 443) else url.netloc
|
||||
url = url.replace(scheme=redirect_scheme, netloc=netloc)
|
||||
response = RedirectResponse(url, status_code=307)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
85
env/lib/python3.11/site-packages/starlette/middleware/sessions.py
vendored
Normal file
85
env/lib/python3.11/site-packages/starlette/middleware/sessions.py
vendored
Normal file
@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
import itsdangerous
|
||||
from itsdangerous.exc import BadSignature
|
||||
|
||||
from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: str | Secret,
|
||||
session_cookie: str = "session",
|
||||
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
path: str = "/",
|
||||
same_site: typing.Literal["lax", "strict", "none"] = "lax",
|
||||
https_only: bool = False,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
self.session_cookie = session_cookie
|
||||
self.max_age = max_age
|
||||
self.path = path
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
if domain is not None:
|
||||
self.security_flags += f"; domain={domain}"
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
connection = HTTPConnection(scope)
|
||||
initial_session_was_empty = True
|
||||
|
||||
if self.session_cookie in connection.cookies:
|
||||
data = connection.cookies[self.session_cookie].encode("utf-8")
|
||||
try:
|
||||
data = self.signer.unsign(data, max_age=self.max_age)
|
||||
scope["session"] = json.loads(b64decode(data))
|
||||
initial_session_was_empty = False
|
||||
except BadSignature:
|
||||
scope["session"] = {}
|
||||
else:
|
||||
scope["session"] = {}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
if scope["session"]:
|
||||
# We have session data to persist.
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data=data.decode("utf-8"),
|
||||
path=self.path,
|
||||
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
|
||||
security_flags=self.security_flags,
|
||||
)
|
||||
headers.append("Set-Cookie", header_value)
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data="null",
|
||||
path=self.path,
|
||||
expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
|
||||
security_flags=self.security_flags,
|
||||
)
|
||||
headers.append("Set-Cookie", header_value)
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
60
env/lib/python3.11/site-packages/starlette/middleware/trustedhost.py
vendored
Normal file
60
env/lib/python3.11/site-packages/starlette/middleware/trustedhost.py
vendored
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
||||
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Sequence[str] | None = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
allowed_hosts = ["*"]
|
||||
|
||||
for pattern in allowed_hosts:
|
||||
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
||||
if pattern.startswith("*") and pattern != "*":
|
||||
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
||||
self.app = app
|
||||
self.allowed_hosts = list(allowed_hosts)
|
||||
self.allow_any = "*" in allowed_hosts
|
||||
self.www_redirect = www_redirect
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.allow_any or scope["type"] not in (
|
||||
"http",
|
||||
"websocket",
|
||||
): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
found_www_redirect = True
|
||||
|
||||
if is_valid_host:
|
||||
await self.app(scope, receive, send)
|
||||
else:
|
||||
response: Response
|
||||
if found_www_redirect and self.www_redirect:
|
||||
url = URL(scope=scope)
|
||||
redirect_url = url.replace(netloc="www." + url.netloc)
|
||||
response = RedirectResponse(url=str(redirect_url))
|
||||
else:
|
||||
response = PlainTextResponse("Invalid host header", status_code=400)
|
||||
await response(scope, receive, send)
|
152
env/lib/python3.11/site-packages/starlette/middleware/wsgi.py
vendored
Normal file
152
env/lib/python3.11/site-packages/starlette/middleware/wsgi.py
vendored
Normal file
@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
warnings.warn(
|
||||
"starlette.middleware.wsgi is deprecated and will be removed in a future release. "
|
||||
"Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": io.BytesIO(body),
|
||||
"wsgi.errors": sys.stdout,
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
# Get server name and port - required in WSGI, not in ASGI
|
||||
server = scope.get("server") or ("localhost", 80)
|
||||
environ["SERVER_NAME"] = server[0]
|
||||
environ["SERVER_PORT"] = server[1]
|
||||
|
||||
# Get client IP address
|
||||
if scope.get("client"):
|
||||
environ["REMOTE_ADDR"] = scope["client"][0]
|
||||
|
||||
# Go through headers and make them into environ entries
|
||||
for name, value in scope.get("headers", []):
|
||||
name = name.decode("latin1")
|
||||
if name == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = f"HTTP_{name}".upper().replace("-", "_")
|
||||
# HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
|
||||
# case
|
||||
value = value.decode("latin1")
|
||||
if corrected_name in environ:
|
||||
value = environ[corrected_name] + "," + value
|
||||
environ[corrected_name] = value
|
||||
return environ
|
||||
|
||||
|
||||
class WSGIMiddleware:
|
||||
def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "http"
|
||||
responder = WSGIResponder(self.app, scope)
|
||||
await responder(receive, send)
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
|
||||
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
|
||||
|
||||
def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
|
||||
self.response_started = False
|
||||
self.exc_info: typing.Any = None
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
body = b""
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
environ = build_environ(self.scope, body)
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
task_group.start_soon(self.sender, send)
|
||||
async with self.stream_send:
|
||||
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
|
||||
|
||||
async def sender(self, send: Send) -> None:
|
||||
async with self.stream_receive:
|
||||
async for message in self.stream_receive:
|
||||
await send(message)
|
||||
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
headers = [
|
||||
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
|
||||
for name, value in response_headers
|
||||
]
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
def wsgi(
|
||||
self,
|
||||
environ: dict[str, typing.Any],
|
||||
start_response: typing.Callable[..., typing.Any],
|
||||
) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True},
|
||||
)
|
||||
|
||||
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|
Reference in New Issue
Block a user