Skip to content

Gateway

This file outlines the interaction between naff and Discord's Gateway API.

class GatewayClient (WebsocketClient)

Abstraction over one gateway connection.

Multiple WebsocketClient instances can be used to implement same-process sharding.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| sequence | | The sequence of this connection |
| session_id | | The session ID of this connection |

Source code in naff/api/gateway/gateway.py
class GatewayClient(WebsocketClient):
    """
    Abstraction over one gateway connection.

    Multiple `WebsocketClient` instances can be used to implement same-process sharding.

    Attributes:
        sequence: The sequence of this connection
        session_id: The session ID of this connection

    """

    def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
        super().__init__(state)

        self.shard = shard

        self.chunk_cache = {}

        self._trace = []
        self.sequence = None
        self.session_id = None

        self.ws_url = state.gateway_url
        self.ws_resume_url = MISSING

        # This lock needs to be held to send something over the gateway, but is also held when
        # reconnecting. That way there's no race conditions between sending and reconnecting.
        self._race_lock = asyncio.Lock()
        # Then this event is used so that receive() can wait for the reconnecting to complete.
        self._closed = asyncio.Event()

        self._keep_alive = None
        self._kill_bee_gees = asyncio.Event()
        self._last_heartbeat = 0
        self._acknowledged = asyncio.Event()
        self._acknowledged.set()  # Initialize it as set

        self._ready = asyncio.Event()
        self._close_gateway = asyncio.Event()

        # Sanity check, it is extremely important that an instance isn't reused.
        self._entered = False

    async def __aenter__(self: SELF) -> SELF:
        if self._entered:
            raise RuntimeError("An instance of 'WebsocketClient' cannot be re-used!")

        self._entered = True
        self._zlib = zlib.decompressobj()

        self.ws = await self.state.client.http.websocket_connect(self.state.gateway_url)

        hello = await self.receive(force=True)
        self.heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
        self._closed.set()

        self._keep_alive = asyncio.create_task(self.run_bee_gees())

        await self._identify()

        return self

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None
    ) -> None:
        # Technically should not be possible in any way, but might as well be safe worst-case.
        self._close_gateway.set()

        try:
            if self._keep_alive is not None:
                self._kill_bee_gees.set()
                try:
                    # Even if we get cancelled that is fine, because then the keep-alive
                    # handler will also be cancelled since we're waiting on it.
                    await self._keep_alive  # Wait for the keep-alive handler to finish
                finally:
                    self._keep_alive = None
        finally:
            if self.ws is not None:
                # We could be cancelled here, it is extremely important that we close the
                # WebSocket either way, hence the try/except.
                try:
                    await self.ws.close(code=1000)
                finally:
                    self.ws = None

    @property
    def average_latency(self) -> float:
        """Get the average latency of the connection."""
        if self.latency:
            return sum(self.latency) / len(self.latency)
        else:
            return float("inf")

    async def run(self) -> None:
        """Start receiving events from the websocket."""
        while True:
            if self._stopping is None:
                self._stopping = asyncio.create_task(self._close_gateway.wait())
            receiving = asyncio.create_task(self.receive())
            done, _ = await asyncio.wait({self._stopping, receiving}, return_when=asyncio.FIRST_COMPLETED)

            if receiving in done:
                # Note that we check for a received message first, because if both completed at
                # the same time, we don't want to discard that message.
                msg = await receiving
            else:
                # This has to be the stopping task, which we join into the current task (even
                # though that doesn't give any meaningful value in the return).
                await self._stopping
                receiving.cancel()
                return

            op = msg.get("op")
            data = msg.get("d")
            seq = msg.get("s")
            event = msg.get("t")

            if seq:
                self.sequence = seq

            if op == OPCODE.DISPATCH:
                asyncio.create_task(self.dispatch_event(data, seq, event))
                continue

            # This may try to reconnect the connection so it is best to wait
            # for it to complete before receiving more - that way there's less
            # possible race conditions to consider.
            await self.dispatch_opcode(data, op)

    async def dispatch_opcode(self, data, op: OPCODE) -> None:
        match op:

            case OPCODE.HEARTBEAT:
                logger.debug("Received heartbeat request from gateway")
                return await self.send_heartbeat()

            case OPCODE.HEARTBEAT_ACK:
                self.latency.append(time.perf_counter() - self._last_heartbeat)

                if self._last_heartbeat != 0 and self.latency[-1] >= 15:
                    logger.warning(
                        f"High Latency! shard ID {self.shard[0]} heartbeat took {self.latency[-1]:.1f}s to be acknowledged!"
                    )
                else:
                    logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds")

                return self._acknowledged.set()

            case OPCODE.RECONNECT:
                logger.debug("Gateway requested reconnect. Reconnecting...")
                return await self.reconnect(resume=True, url=self.ws_resume_url)

            case OPCODE.INVALIDATE_SESSION:
                logger.warning("Gateway has invalidated session! Reconnecting...")
                return await self.reconnect()

            case _:
                return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")

    async def dispatch_event(self, data, seq, event) -> None:
        match event:
            case "READY":
                self._ready.set()
                self._trace = data.get("_trace", [])
                self.sequence = seq
                self.session_id = data["session_id"]
                self.ws_resume_url = (
                    f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream"
                )
                logger.info(f"Shard {self.shard[0]} has connected to gateway!")
                logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
                # todo: future polls, improve guild caching here. run the debugger. you'll see why
                return self.state.client.dispatch(events.WebsocketReady(data))

            case "RESUMED":
                logger.info(f"Successfully resumed connection! Session_ID: {self.session_id}")
                self.state.client.dispatch(events.Resume())
                return

            case "GUILD_MEMBERS_CHUNK":
                asyncio.create_task(self._process_member_chunk(data.copy()))

            case _:
                # the above events are "special", and are handled by the gateway itself, the rest can be dispatched
                event_name = f"raw_{event.lower()}"
                processor = self.state.client.processors.get(event_name)
                if processor:
                    try:
                        asyncio.create_task(processor(events.RawGatewayEvent(data.copy(), override_name=event_name)))
                    except Exception as ex:
                        logger.error(f"Failed to run event processor for {event_name}: {ex}")
                else:
                    logger.debug(f"No processor for `{event_name}`")

        self.state.client.dispatch(events.RawGatewayEvent(data.copy(), override_name="raw_gateway_event"))
        self.state.client.dispatch(events.RawGatewayEvent(data.copy(), override_name=f"raw_{event.lower()}"))

    def close(self) -> None:
        """Shutdown the websocket connection."""
        self._close_gateway.set()

    async def _identify(self) -> None:
        """Send an identify payload to the gateway."""
        if self.ws is None:
            raise RuntimeError
        payload = {
            "op": OPCODE.IDENTIFY,
            "d": {
                "token": self.state.client.http.token,
                "intents": self.state.intents,
                "shard": self.shard,
                "large_threshold": 250,
                "properties": {"os": sys.platform, "browser": "naff", "device": "naff"},
                "presence": self.state.presence,
            },
            "compress": True,
        }

        serialized = OverriddenJson.dumps(payload)
        await self.ws.send_str(serialized)

        logger.debug(
            f"Shard ID {self.shard[0]} has identified itself to Gateway, requesting intents: {self.state.intents}!"
        )

    async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None:
        self.state.clear_ready()
        self._ready.clear()
        await super().reconnect(resume=resume, code=code, url=url)

    async def _resume_connection(self) -> None:
        """Send a resume payload to the gateway."""
        if self.ws is None:
            raise RuntimeError

        payload = {
            "op": OPCODE.RESUME,
            "d": {"token": self.state.client.http.token, "seq": self.sequence, "session_id": self.session_id},
        }

        serialized = OverriddenJson.dumps(payload)
        await self.ws.send_str(serialized)

        logger.debug(f"{self.shard[0]} is attempting to resume a connection")

    async def send_heartbeat(self) -> None:
        await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True)
        logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")

    async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None:
        """Update the bot's presence status."""
        payload = dict_filter_none(
            {
                "since": int(since if since else time.time() * 1000),
                "activities": [activity] if activity else [],
                "status": status,
                "afk": False,
            }
        )
        await self.send_json({"op": OPCODE.PRESENCE, "d": payload})

    async def request_member_chunks(
        self,
        guild_id: "Snowflake_Type",
        query="",
        *,
        limit,
        user_ids=None,
        presences=False,
        nonce=None,
    ) -> None:
        payload = {
            "op": OPCODE.REQUEST_MEMBERS,
            "d": dict_filter_none(
                {
                    "guild_id": guild_id,
                    "presences": presences,
                    "limit": limit,
                    "nonce": nonce,
                    "user_ids": user_ids,
                    "query": query,
                }
            ),
        }
        await self.send_json(payload)

    async def _process_member_chunk(self, chunk: dict) -> None:

        guild = self.state.client.cache.get_guild(to_snowflake(chunk.get("guild_id")))
        if guild:
            return asyncio.create_task(guild.process_member_chunk(chunk))
        raise ValueError(f"No guild exists for {chunk.get('guild_id')}")

    async def voice_state_update(
        self, guild_id: "Snowflake_Type", channel_id: "Snowflake_Type", muted: bool = False, deafened: bool = False
    ) -> None:
        """Update the bot's voice state."""
        payload = {
            "op": OPCODE.VOICE_STATE,
            "d": {"guild_id": guild_id, "channel_id": channel_id, "self_mute": muted, "self_deaf": deafened},
        }
        await self.send_json(payload)

async inherited method send(self, data, bypass)

Send data to the websocket.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | str | The data to send | required |
| bypass | | Whether the rate limit should be ignored for this send (used for heartbeats) | False |
Source code in naff/api/gateway/gateway.py
async def send(self, data: str, bypass=False) -> None:
    """
    Write a string to the websocket while holding the send/reconnect lock.

    Args:
        data: The data to send
        bypass: When truthy, skip the rate limiter for this send (used for heartbeats)

    """
    logger.debug(f"Sending data to websocket: {data}")

    async with self._race_lock:
        # Nothing can be written without a socket; warn and give up.
        if self.ws is None:
            logger.warning("Attempted to send data while websocket is not connected!")
            return None

        if not bypass:
            await self.rl_manager.rate_limit()

        await self.ws.send_str(data)

property readonly average_latency: float

Get the average latency of the connection.

async method run(self)

Start receiving events from the websocket.

Source code in naff/api/gateway/gateway.py
async def run(self) -> None:
    """Receive payloads from the websocket and route them until the gateway is told to close."""
    while True:
        if self._stopping is None:
            self._stopping = asyncio.create_task(self._close_gateway.wait())
        recv_task = asyncio.create_task(self.receive())
        finished, _pending = await asyncio.wait({self._stopping, recv_task}, return_when=asyncio.FIRST_COMPLETED)

        if recv_task not in finished:
            # Only the stopping task can have completed. Join it into the current task (its
            # result carries nothing useful), drop the pending receive and leave the loop.
            await self._stopping
            recv_task.cancel()
            return

        # A received message takes priority: if both tasks finished together, the message
        # must not be discarded.
        payload = await recv_task

        op, data, seq, event = (payload.get(field) for field in ("op", "d", "s", "t"))

        if seq:
            self.sequence = seq

        if op != OPCODE.DISPATCH:
            # Handling an opcode may reconnect, so wait for it to finish before receiving
            # again - that leaves fewer race conditions to consider.
            await self.dispatch_opcode(data, op)
        else:
            asyncio.create_task(self.dispatch_event(data, seq, event))

async inherited method send_json(self, data, bypass)

Send JSON data to the websocket.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | dict | The data to send | required |
| bypass | | Whether the rate limit should be ignored for this send (used for heartbeats) | False |
Source code in naff/api/gateway/gateway.py
async def send_json(self, data: dict, bypass=False) -> None:
    """
    Serialize a payload with `OverriddenJson` and pass it to `send`.

    Args:
        data: The data to send
        bypass: Forwarded to `send`; whether the rate limit is ignored for this send (used for heartbeats)

    """
    await self.send(OverriddenJson.dumps(data), bypass)

async inherited method receive(self, force)

Receive a full event payload from the WebSocket.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| force | bool | Whether to force the receiving, ignoring safety measures such as the read-lock. This option also means that exceptions are raised when a reconnection would normally be tried. | False |
Source code in naff/api/gateway/gateway.py
async def receive(self, force: bool = False) -> dict:
    """
    Receive a full event payload from the WebSocket.

    Args:
        force:
            Whether to force the receiving, ignoring safety measures such as the read-lock.
            This option also means that exceptions are raised when a reconnection would normally
            be tried.

    Returns:
        The payload after JSON decoding.

    Raises:
        WebSocketClosed: If a close message arrives with a close code of 4000 or above
        RuntimeError: If the socket closes, or starts closing, while `force` is set

    """
    buffer = bytearray()

    while True:
        if not force:
            # If we are currently reconnecting in another task, wait for it to complete.
            await self._closed.wait()

        resp = await self.ws.receive()

        if resp.type is WSMsgType.CLOSE:
            logger.debug(f"Disconnecting from gateway! Reason: {resp.data}::{resp.extra}")
            if resp.data >= 4000:
                # This should propagate to __aexit__() which will forcefully shut down everything
                # and cleanup correctly.
                raise WebSocketClosed(resp.data)

            if force:
                raise RuntimeError("Discord unexpectedly wants to close the WebSocket during force receive!")

            await self.reconnect(code=resp.data, resume=resp.data != 1000)
            continue

        elif resp.type is WSMsgType.CLOSED:
            if force:
                raise RuntimeError("Discord unexpectedly closed the underlying socket during force receive!")

            if not self._closed.is_set():
                # Because we are waiting for the event before we receive, this shouldn't be
                # possible - the CLOSING message should be returned instead. Either way, if this
                # is possible after all we can just wait for the event to be set.
                await self._closed.wait()
            else:
                # This is an odd corner-case where the underlying socket connection was closed
                # unexpectedly without communicating the WebSocket closing handshake. We'll have
                # to reconnect ourselves.
                await self.reconnect(resume=True)
            # Go back to receiving explicitly, like the other close branches, rather than
            # relying on the payload check below to skip this message.
            continue

        elif resp.type is WSMsgType.CLOSING:
            if force:
                raise RuntimeError("WebSocket is unexpectedly closing during force receive!")

            # This happens when the keep-alive handler is reconnecting the connection even
            # though we waited for the event before hand, because it got to run while we waited
            # for data to come in. We can just wait for the event again.
            await self._closed.wait()
            continue

        if resp.data is None:
            continue

        if isinstance(resp.data, bytes):
            buffer.extend(resp.data)

            if len(resp.data) < 4 or resp.data[-4:] != b"\x00\x00\xff\xff":
                # message isn't complete yet, wait
                continue

            msg = self._zlib.decompress(buffer)
            # These bytes have been consumed by the streaming decompressor. Start a fresh
            # buffer so that a decode failure below cannot feed them in a second time.
            buffer = bytearray()
            msg = msg.decode("utf-8")
        else:
            msg = resp.data

        try:
            msg = OverriddenJson.loads(msg)
        except Exception as e:
            # Best effort: log the undecodable payload's error and wait for the next message.
            logger.error(e)
            continue

        return msg

method close(self)

Shutdown the websocket connection.

Source code in naff/api/gateway/gateway.py
def close(self) -> None:
    """Request shutdown of the websocket connection by setting the close-gateway event."""
    shutdown_flag = self._close_gateway
    shutdown_flag.set()

async method send_heartbeat(self)

Send a heartbeat to the gateway.

Source code in naff/api/gateway/gateway.py
async def send_heartbeat(self) -> None:
    """Send a heartbeat carrying the current sequence to the gateway, bypassing the rate limit."""
    heartbeat = {"op": OPCODE.HEARTBEAT, "d": self.sequence}
    await self.send_json(heartbeat, bypass=True)
    logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")

async method change_presence(self, activity, status, since)

Update the bot's presence status.

Source code in naff/api/gateway/gateway.py
async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None:
    """
    Send a presence update for the bot.

    Args:
        activity: The activity to show; a falsy value sends an empty activities list
        status: The status to set
        since: Value for the `since` field; when falsy, the current time in milliseconds is used

    """
    timestamp = since if since else time.time() * 1000
    activities = [activity] if activity else []
    presence = dict_filter_none(
        {"since": int(timestamp), "activities": activities, "status": status, "afk": False}
    )
    await self.send_json({"op": OPCODE.PRESENCE, "d": presence})

async method voice_state_update(self, guild_id, channel_id, muted, deafened)

Update the bot's voice state.

Source code in naff/api/gateway/gateway.py
async def voice_state_update(
    self, guild_id: "Snowflake_Type", channel_id: "Snowflake_Type", muted: bool = False, deafened: bool = False
) -> None:
    """
    Send a voice-state payload for the bot.

    Args:
        guild_id: Sent as `guild_id`
        channel_id: Sent as `channel_id`
        muted: Sent as `self_mute`
        deafened: Sent as `self_deaf`

    """
    voice_state = {
        "guild_id": guild_id,
        "channel_id": channel_id,
        "self_mute": muted,
        "self_deaf": deafened,
    }
    await self.send_json({"op": OPCODE.VOICE_STATE, "d": voice_state})