tomodachi

Middleware functionality

Custom middleware for services

Middlewares can be used to add functionality to the service, for example to add logging, authentication, tracing, build more advanced logic for messaging, unpack request queries, modify HTTP responses, handle uncaught errors, add additional context to handlers, etc.

Custom middleware functions, or callable objects, are added to the service by listing them in the http_middleware or message_middleware attribute of the service class.

import tomodachi

from .middleware import logger_middleware

class Service(tomodachi.Service):
    name = "middleware-example"
    http_middleware = [logger_middleware]
    ...

Middlewares are invoked in the stacked order in which they are listed in http_middleware or message_middleware: the first callable in the list is called first (and therefore also returns last).
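To make the stacking order concrete, here is a minimal stand-alone sketch of how a list of middlewares can be folded around a handler. The `run_chain` helper and the `make_middleware` factory are hypothetical and only mimic the ordering described above; this is not tomodachi's internal implementation.

```python
import asyncio
import functools
from typing import Any, Awaitable, Callable, List

Handler = Callable[[], Awaitable[Any]]
Middleware = Callable[[Handler], Awaitable[Any]]


def make_middleware(name: str, log: List[str]) -> Middleware:
    # Hypothetical factory: each middleware records when it enters and when it returns
    async def middleware(func: Handler) -> Any:
        log.append(f"{name} before")
        result = await func()  # call the next middleware or the handler
        log.append(f"{name} after")
        return result

    return middleware


async def run_chain(middlewares: List[Middleware], handler: Handler) -> Any:
    # Wrap from the end of the list inwards, so the first middleware ends up outermost
    call_next: Handler = handler
    for middleware in reversed(middlewares):
        call_next = functools.partial(middleware, call_next)
    return await call_next()


log: List[str] = []


async def handler() -> str:
    log.append("handler")
    return "response"


result = asyncio.run(run_chain([make_middleware("first", log), make_middleware("second", log)], handler))
print(log)  # ['first before', 'second before', 'handler', 'second after', 'first after']
print(result)  # response
```

The first middleware in the list wraps everything else, which is why it both starts first and finishes last.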

Provided arguments to middleware functions

  1. The first unbound argument of a middleware function will receive the coroutine function to call next (either the handler function or a function for the next middleware in the chain).
    (recommended keyword argument name: func)
  2. (optional) The second unbound argument of a middleware function will receive the service class object.
    (recommended keyword argument name: service)
  3. (optional) The third unbound argument of a middleware function will receive the request object for HTTP middlewares, or the message (as parsed by the envelope) for message middlewares.
    (recommended keyword argument name: request or message)
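As an illustration of the three positional arguments, the sketch below calls a middleware by hand with stand-in objects. `FakeService`, `FakeRequest` and the direct call at the bottom are hypothetical substitutes for what the framework would supply; only the argument order follows the list above.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class FakeService:
    # Hypothetical stand-in for the service class object
    name: str


@dataclass
class FakeRequest:
    # Hypothetical stand-in for an HTTP request object
    path: str


async def describe_middleware(
    func: Callable[..., Awaitable[Any]],
    service: FakeService,
    request: FakeRequest,
) -> Any:
    # 1. func: the next callable in the chain
    # 2. service: the service object
    # 3. request: the incoming request (the parsed message for message middlewares)
    result = await func()
    return f"{service.name} handled {request.path}: {result}"


async def handler() -> str:
    return "ok"


output = asyncio.run(describe_middleware(handler, FakeService("middleware-example"), FakeRequest("/health")))
print(output)  # middleware-example handled /health: ok
```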

Use the recommended names to avoid collisions with keywords for transport-centric values, which are also passed to the middleware when matching keyword arguments are defined in the function signature.

Calling the handler or the next middleware in the chain

When calling the next function in the chain, the middleware should await it (await func()). For HTTP middlewares, the result of that call should in most cases be returned, since it is the response that travels back up the chain.
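For an HTTP middleware, leaving out the `return` means the caller receives `None` instead of the handler's response. The sketch below contrasts the two cases, using a hypothetical `Response` dataclass in place of a real aiohttp response so that it runs without a server.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict


@dataclass
class Response:
    # Hypothetical stand-in for an HTTP response object
    body: str
    headers: Dict[str, str] = field(default_factory=dict)


async def handler() -> Response:
    return Response(body="hello")


async def header_middleware(func: Callable[..., Awaitable[Response]]) -> Response:
    # Await the next callable, adjust its result, then return it up the chain
    response = await func()
    response.headers["X-Example"] = "1"
    return response


async def forgetful_middleware(func: Callable[..., Awaitable[Response]]) -> None:
    # Bug: the handler runs, but its response is never returned
    await func()


good = asyncio.run(header_middleware(handler))
bad = asyncio.run(forgetful_middleware(handler))
print(good.headers)  # {'X-Example': '1'}
print(bad)  # None
```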

Adding custom arguments passed on to the handler

The next function can be called with any number of custom keyword arguments, which are then passed on to each following middleware and to the handler itself. This pattern works somewhat like setting up contextvars, but can be useful for passing values and objects explicitly instead of keeping them in a global context.

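import uuid
from typing import Any, Awaitable, Callable

# Logger is a placeholder for the application's own logger class
async def logger_middleware(func: Callable[..., Awaitable], *, traceid: str = "") -> Any:
    if not traceid:
        # Generate a new trace id if no earlier middleware provided one
        traceid = uuid.uuid4().hex
    logger = Logger(traceid=traceid)

    # Passes the logger and traceid to following middlewares and to the handler
    return await func(logger=logger, traceid=traceid)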

A middleware can only add new keywords or modify the values of existing keyword arguments (by passing them on again with a new value). The exception is keywords for transport-centric values: these cannot be modified and always retain their original value.

While a middleware can modify the values of custom keyword arguments, there is no way for a middleware to completely remove any keyword that has been added by previous middlewares.
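The keyword rules above can be modelled with a small merge function. This is a hypothetical sketch of the described behaviour (custom keywords are added or overwritten, transport-centric keywords keep their original values, and nothing is ever removed), not tomodachi's actual code; the set of transport keys is an assumption chosen to match the tracing example.

```python
from typing import Any, Dict, FrozenSet

# Assumed example of transport-centric keywords (for illustration only)
TRANSPORT_KEYS: FrozenSet[str] = frozenset({"topic", "message_attributes", "sns_message_id"})


def merge_kwargs(current: Dict[str, Any], passed: Dict[str, Any]) -> Dict[str, Any]:
    # Start from everything already accumulated, so no keyword can be dropped
    merged = dict(current)
    for key, value in passed.items():
        if key in TRANSPORT_KEYS:
            # Transport-centric values retain their original value
            continue
        # New custom keywords are added, existing custom keywords are overwritten
        merged[key] = value
    return merged


kwargs = {"topic": "example-topic", "traceid": "abc"}
kwargs = merge_kwargs(kwargs, {"traceid": "def", "logger": "log", "topic": "other"})
print(kwargs)  # {'topic': 'example-topic', 'traceid': 'def', 'logger': 'log'}
```

Passing an empty set of keywords leaves the accumulated keywords unchanged, which reflects that a middleware cannot remove a keyword added earlier.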

Example of a middleware specified as a function that adds tracing to AWS SNS/SQS handlers

This example portrays a middleware function that wraps the handler in a trace span, with the trace context populated from a "traceparent" header value collected from an SNS message's message attributes. The topic name and SNS message identifier are also added as attributes to the trace span.

import logging
from typing import Awaitable, Callable

from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

tracer = trace.get_tracer(__name__)


async def trace_middleware(
    func: Callable[..., Awaitable],
    *,
    topic: str,
    message_attributes: dict,
    sns_message_id: str
) -> None:
    ctx: Context | None = None

    # Continue an existing trace if the publisher attached a traceparent value
    if carrier_traceparent := message_attributes.get("telemetry.carrier.traceparent"):
        carrier: dict[str, list[str] | str] = {"traceparent": carrier_traceparent}
        ctx = TraceContextTextMapPropagator().extract(carrier=carrier)

    with tracer.start_as_current_span(f"SNSSQS handler '{func.__name__}'", context=ctx) as span:
        span.set_attribute("messaging.system", "AmazonSQS")
        span.set_attribute("messaging.operation", "process")
        span.set_attribute("messaging.source.name", topic)
        span.set_attribute("messaging.message.id", sns_message_id)

        try:
            # Calls the handler function (or next middleware in the chain)
            await func()
        except BaseException as exc:
            logging.getLogger("exception").exception(exc)
            span.record_exception(exc, escaped=True)
            span.set_status(StatusCode.ERROR, f"{exc.__class__.__name__}: {exc}")
            raise

import tomodachi

from .envelope import Event, MessageEnvelope
from .middleware import trace_middleware

class Service(tomodachi.Service):
    name = "middleware-example"
    message_envelope = MessageEnvelope(key="event")
    message_middleware = [trace_middleware]

    @tomodachi.aws_sns_sqs("example-topic", queue_name="example-queue")
    async def handler(self, event: Event) -> None:
        ...

Example of a middleware specified as a class

A middleware can also be specified as an instance of a class, in which case the object's __call__ method is invoked as the middleware function. Note that self must be included in the signature, since __call__ is invoked as a regular bound method; the next function is then received as the first argument after self.

This class provides a simplistic HTTP Basic authentication implementation that validates the credentials in the Authorization header of HTTP requests to the service.

import base64
import logging
from typing import Awaitable, Callable

from aiohttp import web


class BasicAuthMiddleware:
    def __init__(self, username: str, password: str) -> None:
        # Precompute the expected base64-encoded "username:password" token
        self.valid_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()

    async def __call__(
        self,
        func: Callable[..., Awaitable[web.Response]],
        *,
        request: web.Request,
    ) -> web.Response:
        try:
            auth = request.headers.get("Authorization", "")
            encoded_credentials = auth.split()[-1] if auth.startswith("Basic ") else ""

            if encoded_credentials == self.valid_credentials:
                username = base64.b64decode(encoded_credentials).decode().split(":")[0]
                # Calls the handler function (or next middleware in the chain).
                # The handler (and following middlewares) can use username in their signature.
                return await func(username=username)
            elif auth:
                return web.json_response({"status": "bad credentials"}, status=401)

            return web.json_response({"status": "auth required"}, status=401)
        except Exception as exc:
            # Log unexpected errors and respond with a generic 500; cancellation
            # (a BaseException) is deliberately not swallowed here
            logging.getLogger("exception").exception(exc)
            return web.json_response({"status": "internal server error"}, status=500)

import tomodachi
from aiohttp import web

from .middleware import BasicAuthMiddleware

class Service(tomodachi.Service):
    name = "middleware-example"
    http_middleware = [BasicAuthMiddleware(username="example", password="example")]

    @tomodachi.http("GET", r"/")
    async def handler(self, request: web.Request, username: str) -> web.Response:
        ...
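To exercise this middleware, a client has to send an `Authorization: Basic <credentials>` header, where `<credentials>` is the base64 encoding of `username:password`. The stand-alone sketch below builds such a header and applies the same token comparison and username extraction as `BasicAuthMiddleware`; the helper names `build_basic_auth_header` and `extract_username` are hypothetical and only meant for local testing.

```python
import base64
from typing import Optional


def build_basic_auth_header(username: str, password: str) -> str:
    # Encode "username:password" as base64, as HTTP Basic auth expects
    token = base64.b64encode(f"{username}:{password}".encode()).decode()
    return f"Basic {token}"


def extract_username(auth: str, valid_credentials: str) -> Optional[str]:
    # Mirrors the middleware: take the token after "Basic " and compare it as-is
    encoded_credentials = auth.split()[-1] if auth.startswith("Basic ") else ""
    if encoded_credentials == valid_credentials:
        return base64.b64decode(encoded_credentials).decode().split(":")[0]
    return None


valid = base64.b64encode(b"example:example").decode()
header = build_basic_auth_header("example", "example")
print(header)  # Basic ZXhhbXBsZTpleGFtcGxl
print(extract_username(header, valid))  # example
print(extract_username(build_basic_auth_header("example", "wrong"), valid))  # None
```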