Introduction

Middleware is used to extend the event framework: it adds custom functionality and provides important features that are unrelated to the main handler's logic — for example, retrying the handler after it returns an error, or recovering from a panic in the handler and capturing its stack trace.

The middleware function signature is defined as follows:

Full source code: github.com/ThreeDotsLabs/watermill/message/router.go

// ...
// HandlerMiddleware wraps a HandlerFunc and returns a new HandlerFunc, which lets
// you write decorators around a handler.
// The returned function can run code before calling the wrapped handler (e.g., modify
// the consumed message) and after it (e.g., modify the produced messages, ACK/NACK the
// consumed message, handle errors, log, etc.).
//
// Attach it to the router (applies to every handler) or to a single handler by using
// the `AddMiddleware` method.
//
// Example:
//
//	func ExampleMiddleware(h message.HandlerFunc) message.HandlerFunc {
//		return func(message *message.Message) ([]*message.Message, error) {
//			fmt.Println("Before executing the handler")
//			producedMessages, err := h(message)
//			fmt.Println("After executing the handler")
//
//			return producedMessages, err
//		}
//	}
type HandlerMiddleware func(h HandlerFunc) HandlerFunc
// ...

Usage

Middleware can be applied to all handlers in the router or only to specific handlers. Middleware added directly to the router is applied to every handler registered with that router. To apply a middleware only to a specific handler, add it to that handler (the value returned by `AddHandler`) instead.

Here's an example of usage:

Full Source Code: github.com/ThreeDotsLabs/watermill/_examples/basic/3-router/main.go

// ...
	router, err := message.NewRouter(message.RouterConfig{}, logger)
	if err != nil {
		panic(err)
	}

	// SignalsHandler gracefully closes the router when the process receives SIGTERM.
	// You can also close the router yourself by calling `router.Close()`.
	router.AddPlugin(plugin.SignalsHandler)

	// Router-level middleware is executed for every message handled by the router.
	router.AddMiddleware(
		// CorrelationID copies the correlation ID from the incoming message's metadata to the produced messages.
		middleware.CorrelationID,

		// If the handler returns an error, it is retried.
		// It is retried at most MaxRetries times, after which the message is Nacked and re-sent by the Pub/Sub.
		middleware.Retry{
			MaxRetries:      3,
			InitialInterval: time.Millisecond * 100,
			Logger:          logger,
		}.Middleware,

		// Recoverer handles panics in the handler.
		// In this case, it passes them as errors to the Retry middleware.
		middleware.Recoverer,
	)

	// For simplicity, we use the gochannel Pub/Sub here;
	// you can replace it with any Pub/Sub implementation and it will work the same.
	pubSub := gochannel.NewGoChannel(gochannel.Config{}, logger)

	// Publish some incoming messages in the background.
	go publishMessages(pubSub)

	// AddHandler returns a handler that can be used to add handler-level middleware
	// or to stop the handler.
	handler := router.AddHandler(
		"struct_handler",          // Handler name, must be unique
		"incoming_messages_topic", // Topic from which events will be read
		pubSub,
		"outgoing_messages_topic", // Topic to which events will be published
		pubSub,
		structHandler{}.Handler,
	)

	// Handler-level middleware is executed only for the handler it is added to.
	// It is added the same way as router-level middleware.
	handler.AddMiddleware(func(h message.HandlerFunc) message.HandlerFunc {
		return func(message *message.Message) ([]*message.Message, error) {
			log.Println("Executing handler-specific middleware for", message.UUID)

			return h(message)
		}
	})

	// For debugging purposes only, we print all messages received on `incoming_messages_topic`.
	router.AddNoPublisherHandler(
		"print_incoming_messages",
		"incoming_messages_topic",
		pubSub,
		printMessages,
	)

	// For debugging purposes only, we print all events sent to `outgoing_messages_topic`.
	router.AddNoPublisherHandler(
		"print_outgoing_messages",
		"outgoing_messages_topic",
		pubSub,
		printMessages,
	)

	// Now that all handlers have been registered, we can run the router.
	// Run will block until the router stops running.
// ...

Available Middleware

Below are the reusable middlewares provided by Watermill. You can also easily implement your own: for example, if you want to log every incoming message in a particular format, a custom middleware is the best way to do it.

Circuit Breaker

// CircuitBreaker is a middleware that wraps the handler in a circuit breaker.
// Based on the configuration, the circuit breaker fails fast if the handler keeps returning errors.
// This is useful for preventing cascading failures.
//
// The struct only holds a pointer to the breaker, so copies of a CircuitBreaker
// (including the value receiver of Middleware) share the same breaker state.
type CircuitBreaker struct {
    // cb is the underlying gobreaker breaker; every wrapped handler call runs through cb.Execute.
    cb *gobreaker.CircuitBreaker
}
// NewCircuitBreaker builds a CircuitBreaker middleware backed by a gobreaker
// circuit breaker configured with settings.
// See the gobreaker documentation for the available settings.
func NewCircuitBreaker(settings gobreaker.Settings) CircuitBreaker {
	breaker := gobreaker.NewCircuitBreaker(settings)
	return CircuitBreaker{cb: breaker}
}
// Middleware returns the CircuitBreaker middleware.
//
// Each handler call is executed through the circuit breaker. The messages the
// handler produced are returned together with the error reported by the
// breaker; when the breaker yields no result, no messages are returned.
func (c CircuitBreaker) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		run := func() (interface{}, error) {
			return h(msg)
		}

		out, err := c.cb.Execute(run)

		// out may be a nil interface (the original guarded for this); the
		// comma-ok form yields a nil slice in that case.
		produced, _ := out.([]*message.Message)
		return produced, err
	}
}

Correlation

// SetCorrelationID stores id as the correlation ID of msg, unless msg already
// carries a non-empty correlation ID, in which case msg is left unchanged.
//
// Call it when a message enters the system. A message created while serving a
// request (e.g. HTTP) should get the same correlation ID as that request.
func SetCorrelationID(id string, msg *message.Message) {
	if existing := MessageCorrelationID(msg); existing == "" {
		msg.Metadata.Set(CorrelationIDMetadataKey, id)
	}
}
// MessageCorrelationID reads the correlation ID stored in the message metadata
// under CorrelationIDMetadataKey. An empty result is treated by SetCorrelationID
// as "no correlation ID set".
func MessageCorrelationID(message *message.Message) string {
	md := message.Metadata
	return md.Get(CorrelationIDMetadataKey)
}
// CorrelationID propagates the correlation ID of the incoming message to every
// message produced by the handler.
//
// The ID is read from the incoming message's metadata (via MessageCorrelationID)
// after the handler returns; it is not derived from the message UUID. Produced
// messages that already carry a correlation ID keep it. Propagation happens
// even when the handler returns an error. If the incoming message has no
// correlation ID, an empty value is written to produced messages that lack one.
//
// For CorrelationID to be useful, SetCorrelationID must be called on messages
// when they enter the system.
func CorrelationID(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		produced, err := h(message)

		incomingID := MessageCorrelationID(message)
		for i := range produced {
			SetCorrelationID(incomingID, produced[i])
		}

		return produced, err
	}
}

Duplicator

// Duplicator runs the handler twice for every incoming message, which helps
// verify that the handler is idempotent (it does not by itself make it so).
//
// If either run fails, its error is returned with no messages. Otherwise the
// messages produced by the first run are followed by those of the second run.
func Duplicator(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		first, err := h(msg)
		if err != nil {
			return nil, err
		}

		second, err := h(msg)
		if err != nil {
			return nil, err
		}

		return append(first, second...), nil
	}
}

Ignore Errors

// IgnoreErrors provides a middleware that makes the handler ignore certain
// explicitly defined errors.
//
// Errors are matched by message text (the Middleware compares against the text
// of the error's root cause), so distinct error values with the same message
// are treated as the same error.
type IgnoreErrors struct {
	// ignoredErrors is a set keyed by error message text.
	ignoredErrors map[string]struct{}
}

// NewIgnoreErrors creates a new IgnoreErrors middleware that ignores every error
// in errs. Nil entries are skipped: they have no message to match and calling
// Error() on them would panic.
func NewIgnoreErrors(errs []error) IgnoreErrors {
	errsMap := make(map[string]struct{}, len(errs))

	for _, err := range errs {
		if err == nil {
			continue
		}
		errsMap[err.Error()] = struct{}{}
	}

	return IgnoreErrors{ignoredErrors: errsMap}
}
// Middleware returns the IgnoreErrors middleware.
//
// When the handler fails with an error whose root cause (errors.Cause) has the
// same message as one of the configured errors, the error is dropped and the
// produced messages are returned as if the handler had succeeded. Any other
// error is returned unchanged, together with the produced messages.
func (i IgnoreErrors) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		events, err := h(msg)
		if err == nil {
			return events, nil
		}

		if _, ignored := i.ignoredErrors[errors.Cause(err).Error()]; ignored {
			return events, nil
		}

		return events, err
	}
}

Instant Ack

// InstantAck makes the handler acknowledge the incoming message immediately,
// before the wrapped handler runs and regardless of any error it returns.
//
// It can increase throughput, at a cost:
//   - The message is acknowledged before it is processed, so if the handler
//     fails the message may not be redelivered, weakening delivery guarantees
//     (NOTE(review): the exact effect depends on the Pub/Sub in use).
//   - If you rely on ordered messages, the ordering may be broken.
func InstantAck(h message.HandlerFunc) message.HandlerFunc {
	ackThenHandle := func(msg *message.Message) ([]*message.Message, error) {
		msg.Ack()
		return h(msg)
	}
	return ackThenHandle
}

Poison

// PoisonQueue provides a middleware that handles messages that cannot be
// processed by publishing them to a separate topic; the main middleware chain
// then continues and processing proceeds as usual.
//
// Every handler error qualifies for the poison queue; use PoisonQueueWithFilter
// to choose which errors do. It returns ErrInvalidPoisonQueueTopic when topic
// is empty.
func PoisonQueue(pub message.Publisher, topic string) (message.HandlerMiddleware, error) {
	// An accept-all filter reproduces the unfiltered behavior.
	return PoisonQueueWithFilter(pub, topic, func(err error) bool { return true })
}

// PoisonQueueWithFilter is similar to PoisonQueue, but accepts a function that
// decides which errors qualify for the poison queue.
//
// A nil shouldGoToPoisonQueue is treated as "every error qualifies", which is
// what PoisonQueue does. It returns ErrInvalidPoisonQueueTopic when topic is
// empty.
func PoisonQueueWithFilter(pub message.Publisher, topic string, shouldGoToPoisonQueue func(err error) bool) (message.HandlerMiddleware, error) {
	if topic == "" {
		return nil, ErrInvalidPoisonQueueTopic
	}

	if shouldGoToPoisonQueue == nil {
		// NOTE(review): the filter is presumably invoked by pq.Middleware for each
		// handler error; a nil func would panic there, so default to accept-all.
		shouldGoToPoisonQueue = func(err error) bool { return true }
	}

	pq := poisonQueue{
		topic:                 topic,
		pub:                   pub,
		shouldGoToPoisonQueue: shouldGoToPoisonQueue,
	}

	return pq.Middleware, nil
}

Random Fail

// RandomFail returns a middleware that makes the handler fail at random: when
// shouldFail(errorProbability) reports true (presumably with probability
// errorProbability), an error is returned without calling the wrapped handler.
// errorProbability should be within the range (0, 1).
func RandomFail(errorProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			if !shouldFail(errorProbability) {
				return h(msg)
			}
			return nil, errors.New("a random error occurred")
		}
	}
}

// RandomPanic returns a middleware that makes the handler panic at random: when
// shouldFail(panicProbability) reports true (presumably with probability
// panicProbability), it panics instead of calling the wrapped handler.
// panicProbability should be within the range (0, 1).
func RandomPanic(panicProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			if !shouldFail(panicProbability) {
				return h(msg)
			}
			panic("a random panic occurred")
		}
	}
}

Recoverer

// RecoveredPanicError holds the value recovered from a handler panic and the
// stack trace captured when it was recovered. Recoverer returns it (wrapped with
// errors.WithStack) as the handler's error. Its Error method is not shown in this excerpt.
type RecoveredPanicError struct {
	// V is the value returned by recover(); it may be nil when the handler did not
	// return normally but recover() yielded nil.
	V          interface{}
	// Stacktrace is the debug.Stack() output captured inside Recoverer's deferred function.
	Stacktrace string
}

// Recoverer recovers from any panic raised by the handler and turns it into an
// error. In that case, the returned error is a RecoveredPanicError (wrapped with
// errors.WithStack) holding the recovered value and the stack trace captured at
// recovery time. When the handler returns normally, its messages and error are
// passed through unchanged.
//
// A completion flag is checked in addition to recover(). This ensures that a
// handler that does not return normally while recover() yields nil (e.g.
// panic(nil) on older Go versions, or runtime.Goexit) is still reported.
func Recoverer(h message.HandlerFunc) message.HandlerFunc {
	return func(event *message.Message) (events []*message.Message, err error) {
		completed := false

		defer func() {
			recovered := recover()
			if recovered == nil && completed {
				return
			}
			err = errors.WithStack(RecoveredPanicError{V: recovered, Stacktrace: string(debug.Stack())})
		}()

		events, err = h(event)
		completed = true
		return events, err
	}
}

Retry

// Retry provides a middleware that retries the handler if an error is returned.
// The retry behavior, exponential backoff, and maximum elapsed time can be configured.
//
// The backoff fields (InitialInterval, MaxInterval, Multiplier, MaxElapsedTime,
// RandomizationFactor) are copied verbatim into a fresh exponential backoff each
// time the first handler call fails; zero values are passed through as zero, not
// replaced by the backoff library's defaults. NOTE(review): confirm how the
// backoff library treats zero values for these fields.
type Retry struct {
	// MaxRetries is the maximum number of retries made after the initial handler
	// call fails. NOTE(review): the retry loop is cut off in this excerpt; confirm
	// whether the total number of handler calls is MaxRetries+1.
	MaxRetries int

	// InitialInterval is the initial interval between retries. Subsequent intervals will be scaled by the Multiplier.
	InitialInterval time.Duration
	// MaxInterval sets the upper limit for the exponential backoff of retries.
	MaxInterval time.Duration
	// Multiplier is the factor by which the wait interval between retries will be multiplied.
	Multiplier float64
	// MaxElapsedTime sets the maximum time limit for retries. If 0, it is disabled.
	// When positive, it also bounds (via context.WithTimeout) a context derived from
	// the message context for the duration of the retries.
	MaxElapsedTime time.Duration
	// RandomizationFactor randomly spreads the backoff time within the following range:
	// [currentInterval * (1 - randomization_factor), currentInterval * (1 + randomization_factor)].
	RandomizationFactor float64

	// OnRetryHook is an optional function to be executed on each retry attempt.
	// The current retry number is passed through retryNum, starting at 1; delay is
	// presumably the backoff wait before that attempt (confirm against the loop).
	OnRetryHook func(retryNum int, delay time.Duration)

	// Logger is the logger adapter used by the middleware (presumably to report
	// retries; its use is not visible in this excerpt).
	Logger watermill.LoggerAdapter
}
// Middleware returns the Retry middleware.
func (r Retry) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		producedMessages, err := h(msg)
		if err == nil {
			return producedMessages, nil
		}

		expBackoff := backoff.NewExponentialBackOff()
		expBackoff.InitialInterval = r.InitialInterval
		expBackoff.MaxInterval = r.MaxInterval
		expBackoff.Multiplier = r.Multiplier
		expBackoff.MaxElapsedTime = r.MaxElapsedTime
		expBackoff.RandomizationFactor = r.RandomizationFactor

		ctx := msg.Context()
		if r.MaxElapsedTime > 0 {
			var cancel func()
			ctx, cancel = context.WithTimeout(ctx, r.MaxElapsedTime)
			defer cancel()
		}

		retryNum := 1
		expBackoff.Reset()
	retryLoop:
		for {
			waitTime := expBackoff.NextBackOff()
			select {
			case 

Throttle

// Throttle provides a middleware that limits the number of messages processed
// within a given time period.
//
// It can be used to avoid overloading a handler that is consuming a long
// backlog of unprocessed messages.
type Throttle struct {
	ticker *time.Ticker
}

// NewThrottle creates a new Throttle middleware that allows count messages per
// duration; for example, NewThrottle(10, time.Second) allows 10 messages per
// second.
//
// It panics with a descriptive message if count or duration is not positive.
// Previously, count == 0 caused a division-by-zero panic and a negative value
// caused a time.NewTicker panic. When duration/count is shorter than one
// nanosecond, the interval is clamped to one nanosecond, because time.NewTicker
// rejects non-positive intervals.
//
// NOTE(review): the ticker is never stopped in the code visible here.
func NewThrottle(count int64, duration time.Duration) *Throttle {
	if count <= 0 {
		panic("throttle: count must be positive")
	}
	if duration <= 0 {
		panic("throttle: duration must be positive")
	}

	interval := duration / time.Duration(count)
	if interval <= 0 {
		interval = time.Nanosecond
	}

	return &Throttle{
		ticker: time.NewTicker(interval),
	}
}
// Middleware returns the Throttle middleware.
func (t Throttle) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		// Throttles shared by multiple handlers will wait for their "ticks".

Timeout

// Timeout returns a middleware that gives the incoming message a context that
// is canceled after timeout, or when the handler returns, whichever comes first.
// Timeout-sensitive handler code should watch msg.Context().Done() to know
// when to give up.
//
// The message keeps the derived (by then canceled) context after the handler
// returns; the previous context is not restored.
func Timeout(timeout time.Duration) func(message.HandlerFunc) message.HandlerFunc {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			ctx, cancel := context.WithTimeout(msg.Context(), timeout)
			defer cancel()

			msg.SetContext(ctx)
			return h(msg)
		}
	}
}