Serverless WebSockets

Mischa Spiegelmock

WebSockets, the standard for real-time bidirectional communication (typically between a browser and a server), is a fair attempt to supplant the hacky solutions that came before it, and its implementations continue to evolve.

The basic idea is to establish a channel over which a server can “push” events to a client, rather than the client “polling” every so often to see if there is new information. Until fairly recently this was a relatively obscure concept, but now any smartphone owner is well acquainted with push notifications. This real-time channel has been used not just for notifications but also for services like VoIP and gaming.

In the days before the WebSocket standard various semi-clever attempts to implement push notifications were devised. The first was using <iframe>s to load an HTML document using chunked encoding, where the server would write a script tag with some new data in the form of JavaScript commands when the data became available. When the browser encountered a closing script tag it would execute the JS immediately even though the document was still streaming.

The next scheme used XMLHttpRequest (aka XHR, aka AJAX) to do something similar but without needing an <iframe>. This was known as “long-polling”, or “comet.” It was still mostly a unidirectional channel and suffered from timeouts and reconnection issues with potential race conditions.

Now with WebSockets we have a much improved system and wide browser support. But what about the backend? What happens when a browser or other client connects to a WebSocket server?

Previously we’ve developed and hosted WebSocket servers written in Perl, Go, and Python, using PostgreSQL asynchronous events as the message passing system. Deploying WebSocket servers is not as straightforward as HTTP servers because of the long-lived connections and having to perform TCP load balancing. Depending on your hosting setup you may have to deal with internal timeouts or getting events from your message bus to the right backend via some subscription mechanism.

Architecture

Since I love not running servers I’ve been excited about the chance to use serverless WebSockets via AWS API Gateway. In this new scheme you define Lambda functions that react to events such as authentication, connect, disconnect, and user-defined events that can be read from JSON message bodies.

Infrastructure-wise the setup is extremely basic. All of the real work of handling authorization and events is done in code, which we will look at shortly. Let’s use a concrete example of a typical WebSocket use case: sending notifications from the server to the client to inform it of some data change, so that the client can update some information in real time or notify the user.

For my application I created an authorizer function that validates a JWT encoded in the WebSocket URL query parameters (there is no good way in a browser to set headers when opening a WebSocket connection). This function denies or grants access to proceed and saves the authenticated user ID in the principalId response field, which is passed along to subsequent event handlers.

Once the authorization check is successful the special $connect route is called if there is a handler defined. In this handler we have the user ID in the invocation event passed along from the authorizer response and we have a connectionId. We save this user ID and connection ID pair in our database so that we can know who is connected and have the ability to send them a notification later on using their connectionId.

The API Gateway makes a best-effort attempt to detect disconnections and invokes the special $disconnect route whereupon our handler removes the connection record from the database.

Putting all of these pieces together with actual working code required gathering a fair bit of information from different sources and working out the proper request fields and response formats, but it all worked out wonderfully in the end. I’d like to share the working code examples for the handlers and some sample client code as well.

The Code

To define your handlers and when they get invoked you need to configure API Gateway to register your authorizer handler and the assorted route handlers. Using the Serverless toolkit this is straightforward and nicely documented. My configuration looks something like:

functions:
  # websocket authorizer
  wsAuth:
    handler: notifier.ws.handler.authorizer

  # websocket $connect
  wsConnect:
    handler: notifier.ws.handler.connect
    events:
      - websocket:
          route: $connect
          authorizer:
            name: wsAuth
            identitySource:
              - route.request.querystring.token  # token query param

  # websocket $disconnect
  wsDisconnect:
    handler: notifier.ws.handler.disconnect
    events:
      - websocket:
          route: $disconnect

And the authorizer:

def authorizer(event, context):
    method_arn = event.get("methodArn")

    def deny(msg):
        # API Gateway expects a principalId in every authorizer response,
        # denials included; "unauthorized" is only a placeholder value
        return {
            "principalId": "unauthorized",
            "message": msg,
            "policyDocument": gen_policy(method_arn=method_arn, allow=False),
        }

    # get access token from query string
    query_params = event.get("queryStringParameters")
    if not query_params:
        return deny("missing queryStringParameters")
    token = query_params.get("token")
    if not token:
        return deny("missing or empty token in query string")

    # decode and verify JWT token
    # (decode_token and ExpiredSignatureError are defined elsewhere)
    try:
        decoded = decode_token(token)
    except ExpiredSignatureError:
        return deny("Expired token")

    identity = decoded.get("identity")
    if not identity:
        return deny("invalid JWT; missing identity")

    # allow access
    policy = gen_policy(method_arn=method_arn, allow=True)
    auth_context = {}  # can add more auth context info here if desired
    return {
        "principalId": identity,
        "policyDocument": policy,
        "context": auth_context,
    }

def gen_policy(method_arn: str, allow: bool):
    effect = "Allow" if allow else "Deny"
    return {
        "Version": "2012-10-17",
        "Statement": [{
            "Action": "execute-api:Invoke",
            "Effect": effect,
            "Resource": method_arn
        }],
    }

This looks for a JWT in the query string and attempts to parse and validate it. If successful then an IAM policy is returned along with the decoded identity ID. The details of the event and policy can be found in the Lambda REQUEST WebSocket authorizer documentation.
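
The decode_token helper is not shown above. As a rough illustration of what such a helper has to do, here is a minimal standard-library sketch that assumes HS256-signed tokens and a shared secret. The secret parameter, the InvalidTokenError class, and the encode_token helper are assumptions made for this example, and the ExpiredSignatureError class here is a local stand-in for whichever one the handler imports. A real deployment would more likely lean on an established JWT library than hand-roll this.

```python
import base64
import hashlib
import hmac
import json
import time


class ExpiredSignatureError(Exception):
    """Raised when the token's "exp" claim is in the past (local stand-in)."""


class InvalidTokenError(Exception):
    """Raised when the token is malformed or its signature does not match (hypothetical)."""


def _b64url_decode(segment: str) -> bytes:
    # JWT segments are base64url without padding; restore it before decoding
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _b64url_encode(raw: bytes) -> str:
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def encode_token(claims: dict, secret: str) -> str:
    """Build an HS256 token; included only so the sketch can be exercised."""
    header = _b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = _b64url_encode(json.dumps(claims).encode())
    sig = hmac.new(secret.encode(), f"{header}.{payload}".encode(), hashlib.sha256).digest()
    return f"{header}.{payload}.{_b64url_encode(sig)}"


def decode_token(token: str, secret: str = "change-me") -> dict:
    """Verify an HS256-signed JWT and return its claims."""
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        header = json.loads(_b64url_decode(header_b64))
        claims = json.loads(_b64url_decode(payload_b64))
        signature = _b64url_decode(sig_b64)
    except ValueError as err:  # wrong segment count, bad base64, or bad JSON
        raise InvalidTokenError("malformed token") from err

    # only accept the one algorithm we expect, never whatever the header asks for
    if not isinstance(header, dict) or header.get("alg") != "HS256":
        raise InvalidTokenError("unsupported algorithm")
    if not isinstance(claims, dict):
        raise InvalidTokenError("claims must be a JSON object")

    # recompute the signature and compare in constant time
    signing_input = f"{header_b64}.{payload_b64}".encode()
    expected = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    if not hmac.compare_digest(signature, expected):
        raise InvalidTokenError("bad signature")

    # reject tokens whose expiry time has passed
    exp = claims.get("exp")
    if isinstance(exp, (int, float)) and time.time() >= exp:
        raise ExpiredSignatureError("token has expired")
    return claims
```

With a helper shaped like this, the authorizer would also want to treat InvalidTokenError as a denial rather than letting it surface as an unhandled error.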

If the client is granted Invoke access to the execute-api service then API Gateway will call our $connect route next:

def connect(event, context):
    ctx = event.get("requestContext", {})
    # get user and connection id
    conn_id = ctx.get("connectionId")
    auth = ctx.get("authorizer", {})
    user_id = auth.get("principalId")

    if not user_id:
        return make_response(401, "Not authorized")

    if not conn_id:
        raise Exception("missing connectionId")

    # save the connection id/user id pair in DB
    WebsocketClient.save_connection(
        user_id=user_id,
        connection_id=conn_id,
        domain_name=ctx["domainName"],
        stage=ctx["stage"],
    )
    db.session.commit()

    return make_response(200, "ok")

def make_response(status_code, body):
    if not isinstance(body, str):
        body = json.dumps(body)
    return {"statusCode": status_code, "body": body}

The purpose of this route is to store the user ID and connection ID in the database along with the connection’s domain and stage. We will use this to send our notification to the client.

def send_ws(user_id, message):
    """Push a notification to the user if they have an active websocket connection."""
    connections = WebsocketClient \
        .query \
        .filter_by(user_id=user_id) \
        .all()

    for conn in connections:
        conn.send(message)

And conn.send():

import boto3
import json
from notifier.db import db, Model
from botocore.exceptions import ClientError

class WebsocketClient(Model):

    ...

    def send(self, message):
        """Send a message to an active connection.

        :param message: can be anything that is JSON-serializable."""
        # get APIGW management client
        apigw_mgmt_client = boto3.client(
            "apigatewaymanagementapi",
            endpoint_url=f"https://{self.domain_name}/{self.stage}",
        )
        try:
            # send message
            apigw_mgmt_client.post_to_connection(
                Data=json.dumps(message).encode("utf-8"),
                ConnectionId=self.connection_id,
            )
        except ClientError as err:
            # gracefully handle case where client is no longer connected;
            # a gone connection is reported with HTTP status 410 (GoneException)
            status = err.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
            if status == 410:
                # client gone, cleanup
                db.session.delete(self)
                db.session.commit()
                return
            raise

This is where the real action happens. When we want to send a message from the server to the client we do it with the PostToConnection call. We need to provide the API Gateway domain and stage so that Boto can construct the URL needed for the API call. Boto is simply making HTTP requests to interact with the WebSocket connection, as described in the API Gateway documentation, and you can use an HTTP client directly if you like to get connection info, send a message, or close the connection.
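
To make the "just HTTP requests" point concrete, the sketch below builds the @connections URL that those management calls target and maps each operation to its HTTP method. The function and dictionary names are made up for this illustration, and the example hostname is a placeholder. Note that a real request also has to be signed with AWS Signature Version 4, which Boto does for you and which this sketch leaves out.

```python
from urllib.parse import quote

# HTTP method used by each management operation on a connection
CONNECTION_OPERATIONS = {
    "send": "POST",     # PostToConnection: message goes in the request body
    "info": "GET",      # GetConnection: returns connection metadata
    "close": "DELETE",  # DeleteConnection: disconnects the client
}


def connection_url(domain_name: str, stage: str, connection_id: str) -> str:
    """Build the management URL for one WebSocket connection."""
    # connection IDs commonly end in "=", so percent-encode the path segment
    encoded_id = quote(connection_id, safe="")
    return f"https://{domain_name}/{stage}/@connections/{encoded_id}"


def describe_request(operation: str, domain_name: str, stage: str, connection_id: str):
    """Return the (HTTP method, URL) pair for a management operation."""
    method = CONNECTION_OPERATIONS[operation]
    return method, connection_url(domain_name, stage, connection_id)


if __name__ == "__main__":
    # placeholder domain, stage, and connection ID
    print(describe_request("send", "abc123.execute-api.us-east-1.amazonaws.com", "dev", "L0SM9cOFvHcCIhw="))
```

This is also why the connect handler stores the domain name and stage alongside the connection ID: together they are everything needed to address the connection later.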

For completeness let’s look at handling the $disconnect route:

def disconnect(event, context):
    # get connection ID
    ctx = event.get("requestContext", {})
    conn_id = ctx.get("connectionId")
    if not conn_id:
        raise Exception("no connection id found")

    # delete the connection record from our DB
    WebsocketClient.delete_connection(connection_id=conn_id)
    db.session.commit()
    return make_response(200, "ok")
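
The connect and disconnect handlers rely on save_connection and delete_connection, whose bodies are elided above. The following is a minimal sketch of the bookkeeping they imply, using sqlite3 from the standard library in place of the application's real database layer. The ConnectionStore class, the table name, and the connections_for method are assumptions for illustration only, not the actual model.

```python
import sqlite3


class ConnectionStore:
    """Tracks which users currently have open WebSocket connections."""

    def __init__(self, db_path: str = ":memory:"):
        self.db = sqlite3.connect(db_path)
        self.db.execute(
            """CREATE TABLE IF NOT EXISTS websocket_client (
                   connection_id TEXT PRIMARY KEY,
                   user_id       TEXT NOT NULL,
                   domain_name   TEXT NOT NULL,
                   stage         TEXT NOT NULL
               )"""
        )

    def save_connection(self, user_id, connection_id, domain_name, stage):
        # called from $connect; a user may hold several connections at once
        with self.db:  # commits on success, rolls back on error
            self.db.execute(
                "INSERT OR REPLACE INTO websocket_client "
                "(connection_id, user_id, domain_name, stage) VALUES (?, ?, ?, ?)",
                (connection_id, user_id, domain_name, stage),
            )

    def delete_connection(self, connection_id):
        # called from $disconnect; deleting an unknown ID is a harmless no-op
        with self.db:
            self.db.execute(
                "DELETE FROM websocket_client WHERE connection_id = ?",
                (connection_id,),
            )

    def connections_for(self, user_id):
        # used when pushing a notification to every connection a user has open
        rows = self.db.execute(
            "SELECT connection_id FROM websocket_client "
            "WHERE user_id = ? ORDER BY connection_id",
            (user_id,),
        )
        return [row[0] for row in rows]
```

Keying the table on the connection ID rather than the user ID matters because $disconnect only tells you which connection went away, and because one user can have several tabs or devices connected at the same time.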

Client ➞ Server Messages

But wait, there’s more!

Our application is now ready to send notifications to our client, but if we also want to receive messages from the client we can support that too. We can define custom routes that are matched based on a route key, as described in the API Gateway documentation. In practice this means that when API Gateway receives a JSON message it looks for the route name, by default in a field called "action", and decides which Lambda to call based on that value. You can also create a $default route to catch any message that no other route handles.
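
The routing rule can be mimicked in a few lines. The sketch below is a local simulation of the selection behaviour described here, not anything API Gateway exposes; the select_route function and the "sendMessage" route name are invented for the example.

```python
import json


def select_route(raw_message: str, routes: set, key_field: str = "action") -> str:
    """Pick the route a client message would be dispatched to."""
    try:
        body = json.loads(raw_message)
    except ValueError:
        # non-JSON messages cannot carry a route key
        return "$default"

    if isinstance(body, dict):
        key = body.get(key_field)
        # only a string that names a configured route selects that route
        if isinstance(key, str) and key in routes:
            return key
    return "$default"


if __name__ == "__main__":
    configured = {"sendMessage"}
    print(select_route('{"action": "sendMessage", "text": "hi"}', configured))
    print(select_route('{"action": "somethingElse"}', configured))
    print(select_route("plain text", configured))
```

On the client side this means a message intended for a custom route is simply a JSON object with the route name in its "action" field, sent over the open socket.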

Client Code

I implemented a basic WebSocket client in TypeScript using the standard WebSocket API. The only special thing it does is append your access token (managed with axios-jwt) to the WebSocket connection URL.

import { refreshTokenIfNeeded } from 'axios-jwt'

export const WEBSOCKET_EVENT = 'onwebsocketmessage'

export class WSEvent extends Event {
  message: object

  constructor(msg: object) {
    super(WEBSOCKET_EVENT)
    this.message = msg
  }
}

export type WSEventHandler = (ev: WSEvent) => void

export default class WSClient extends EventTarget {
  ws: WebSocket | undefined
  public isConnected: boolean = false
  reconnectTime: number = 1 // time in seconds before reconnect

  // connect
  public open = async () => {
    if (this.ws) {
      if (this.ws.readyState === WebSocket.CONNECTING || this.ws.readyState === WebSocket.OPEN)
        // already open/opening
        return

      // stale socket: detach its close handler and drop it so a new one is created
      this.ws.onclose = null
      this.ws.close()
      this.ws = undefined
    }

    // config from create-react-app+dotenv
    if (!process.env.REACT_APP_WS_URL) throw new Error('REACT_APP_WS_URL missing')
    const host = new URL(process.env.REACT_APP_WS_URL)

    // make sure auth token is fresh
    // requestRefresh defined elsewhere - see axios-jwt documentation
    const accessToken = await refreshTokenIfNeeded(requestRefresh)

    // another open() call may have connected while we were awaiting the token
    if (this.ws) return

    // add auth token to URL
    if (accessToken) host.searchParams.set('token', accessToken)

    // create new websocket client
    this.ws = new WebSocket(String(host))
    this.ws.onopen = this.handleOpen
    this.ws.onclose = this.handleClose
    this.ws.onmessage = this.handleMessage
  }

  // disconnect
  public close = () => {
    if (this.ws) this.ws.close()
  }

  public reconnect() {
    if (this.ws) this.ws.close()
    this.open()
  }

  // CALLBACKS

  protected handleOpen = (ev: Event) => {
    this.isConnected = true
    this.reconnectTime = 1 // reset reconnect timer

    const ws = this.ws
    if (!ws) return
  }

  protected handleClose = (ev: Event) => {
    this.isConnected = false

    // reconnect after a delay, doubling the delay each time (exponential backoff)
    const delay = this.reconnectTime * 1000
    this.reconnectTime *= 2
    setTimeout(() => this.open(), delay)
  }

  protected handleMessage = (ev: MessageEvent) => {
    // handle message received on WS
    const data = ev.data
    if (!data) return

    // try to parse as JSON; ignore messages that are not valid JSON
    let msg: object
    try {
      msg = JSON.parse(data)
    } catch {
      return
    }

    // create new websocket event and dispatch it to listeners
    const msgEvt = new WSEvent(msg)
    this.dispatchEvent(msgEvt)
  }
}

And as a bonus here’s a React hook that lets you register an event handler for WebSocket messages:

import * as React from 'react'
import WSClient, { WEBSOCKET_EVENT, WSEvent } from './api'

// singleton
let client: WSClient

interface IUseWebSocketClientArgs {
  onEvent?: (evt: WSEvent) => void
}

const useWebSocketClient = ({ onEvent }: IUseWebSocketClientArgs) => {
  React.useEffect(() => {
    if (!client) client = new WSClient()

    // listen for events
    if (onEvent) client.addEventListener(WEBSOCKET_EVENT, onEvent as EventListener)

    // ensure client is connected
    client.open()

    // cleanup handler
    return () => {
      if (onEvent) client.removeEventListener(WEBSOCKET_EVENT, onEvent as EventListener)
    }
  }, [onEvent])
  return { client }
}

export default useWebSocketClient

Conclusion

Like many other serverless technologies this approach is certainly not practical for every use case, but it is quite reasonable for a lot of common ones. Although API Gateway WebSockets offer some support for binary payloads, the serverless approach is probably best suited to your application if you’re passing occasional JSON messages around and dealing with relatively low throughput and volume.