6

我有这样的中间件

class RequestContext(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        request_id = request_ctx.set(str(uuid4()))  # generate uuid to request
        body = await request.body()
        if body:
            logger.info(...)  # log request with body
        else:
            logger.info(...)  # log request without body
 
        response = await call_next(request)
        response.headers['X-Request-ID'] = request_ctx.get()
        logger.info("%s" % (response.status_code))
        request_ctx.reset(request_id)

        return response

因此,该行body = await request.body()冻结了所有具有正文的请求,并且我从所有这些请求中获得了 504。在这种情况下,如何安全地阅读请求正文?我只想记录请求参数。

4

5 回答 5

4

我不会创建继承自 BaseHTTPMiddleware 的中间件,因为它有一些问题,FastAPI 让您有机会创建自己的路由器,根据我的经验,这种方法要好得多。

from fastapi import APIRouter, FastAPI, Request, Response, Body
from fastapi.routing import APIRoute

from typing import Callable, List
from uuid import uuid4


class ContextIncludedRoute(APIRoute):
    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            request_id = str(uuid4())
            response: Response = await original_route_handler(request)

            if await request.body():
                print(await request.body())

            response.headers["Request-ID"] = request_id
            return response

        return custom_route_handler


app = FastAPI()
router = APIRouter(route_class=ContextIncludedRoute)


@router.post("/context")
async def non_default_router(bod: List[str] = Body(...)):
    return bod


app.include_router(router)

按预期工作。

b'["string"]'
INFO:     127.0.0.1:49784 - "POST /context HTTP/1.1" 200 OK
于 2020-09-29T10:33:34.943 回答
3

如果您仍然想使用 BaseHTTP,我最近遇到了这个问题并想出了一个解决方案:

中间件代码

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import json
from .async_iterator_wrapper import async_iterator_wrapper as aiwrap

class some_middleware(BaseHTTPMiddleware):
   async def dispatch(self, request:Request, call_next:RequestResponseEndpoint):
      # --------------------------
      # DO WHATEVER YOU TO DO HERE
      #---------------------------
      
      response = await call_next(request)

      # Consuming FastAPI response and grabbing body here
      resp_body = [section async for section in response.__dict__['body_iterator']]
      # Repairing FastAPI response
      response.__setattr__('body_iterator', aiwrap(resp_body)

      # Formatting response body for logging
      try:
         resp_body = json.loads(resp_body[0].decode())
      except:
         resp_body = str(resp_body)

来自 Python 3 异步 for 循环的TypeError 的 async_iterator_wrapper 代码

class async_iterator_wrapper:
    def __init__(self, obj):
        self._it = iter(obj)
    def __aiter__(self):
        return self
    async def __anext__(self):
        try:
            value = next(self._it)
        except StopIteration:
            raise StopAsyncIteration
        return value

我真的希望这可以帮助别人!我发现这对记录非常有帮助。

非常感谢 @Eddified 的 aiwrap 课程

于 2020-10-14T17:11:32.787 回答
1

原来await request.json()每个请求周期只能调用一次。因此,如果您需要访问多个中间件中的请求主体以进行过滤或身份验证等,那么有一种解决方法是创建一个自定义中间件,该中间件将请求主体的内容复制到 request.state 中。应尽早加载中间件。然后链中的每个中间件或控制器可以从 request.state 访问请求正文,而不是await request.json()再次调用。这是一个例子:

class CopyRequestMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        request_body = await request.json()
        request.state.body = request_body

        response = await call_next(request)
        return response

class LogRequestMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Since it'll be loaded after CopyRequestMiddleware it can access request.state.body.
        request_body = request.state.body
        print(request_body)
    
        response = await call_next(request)
        return response

控制器也会从 request.state 访问请求体

request_body = request.state.body
于 2021-11-13T04:12:14.423 回答
0

只是因为尚未说明此类解决方案,但它对我有用:

from typing import Callable, Awaitable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.concurrency import iterate_in_threadpool

class LogStatsMiddleware(BaseHTTPMiddleware):
    async def dispatch(  # type: ignore
        self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]],
    ) -> Response:
        response = await call_next(request)
        response_body = [section async for section in response.body_iterator]
        response.body_iterator = iterate_in_threadpool(iter(response_body))
        logging.info(f"response_body={response_body[0].decode()}")
        return response

def init_app(app):
    app.add_middleware(LogStatsMiddleware)

iterate_in_threadpool实际上从迭代器对象异步迭代器

如果你看一下starlette.responses.StreamingResponse你会看到的实现,这个函数正是用于这个

于 2022-02-02T09:34:00.303 回答
0

如果您只想读取请求参数,我发现的最佳解决方案是实现“route_class”并在创建时将其添加为 arg ,这是因为在中间件中fastapi.APIRouter解析请求被认为是有问题 的路由处理程序背后的意图来自我理解是将异常处理逻辑附加到特定路由器,但由于它在每次路由调用之前被调用,您可以使用它来访问请求参数

Fastapi 文档

您可以执行以下操作:

class MyRequestLoggingRoute(APIRoute):
    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            body = await request.body()
            if body:
               logger.info(...)  # log request with body
            else:
               logger.info(...)  # log request without body
            try:

                return await original_route_handler(request)
            except RequestValidationError as exc:
               detail = {"errors": exc.errors(), "body": body.decode()}
               raise HTTPException(status_code=422, detail=detail)

        return custom_route_handler
于 2022-02-13T14:36:40.410 回答