attrs class BaseCommand (DictSerializationMixin, CallbackObject)

An object all commands inherit from. Outlines the basic structure of a command, and handles checks.


Name Type Description
extension Any

The extension this command belongs to.

enabled bool

Whether this command is enabled

checks list

Any checks that must be run before this command can be run

callback Callable[..., Coroutine]

The coroutine to be called for this command

error_callback Callable[..., Coroutine]

The coroutine to be called when an error occurs

pre_run_callback Callable[..., Coroutine]

A coroutine to be called before this command is run but after the checks

post_run_callback Callable[..., Coroutine]

A coroutine to be called after this command has run

Attr attributes:

Source code in naff/models/naff/
class BaseCommand(DictSerializationMixin, CallbackObject):
    """
    An object all commands inherit from. Outlines the basic structure of a command, and handles checks.

    Attributes:
        extension: The extension this command belongs to.
        enabled: Whether this command is enabled.
        checks: Any checks that must be run before this command can be run.
        cooldown: An optional cooldown to apply to the command.
        max_concurrency: An optional maximum number of concurrent instances to apply to the command.
        callback: The coroutine to be called for this command.
        error_callback: The coroutine to be called when an error occurs.
        pre_run_callback: A coroutine to be called before this command is run **but** after the checks.
        post_run_callback: A coroutine to be called after this command has run.

    """

    # NOTE(review): the fields below use attrs-style `field(...)` and the class defines
    # `__attrs_post_init__`, so an attrs class decorator is expected directly above the
    # `class` statement. It is not visible in this excerpt - confirm it is present.

    extension: Any = field(default=None, metadata=docs("The extension this command belongs to") | no_export_meta)

    enabled: bool = field(default=True, metadata=docs("Whether this can be run at all") | no_export_meta)
    checks: list = field(
        factory=list, metadata=docs("Any checks that must be *checked* before the command can run") | no_export_meta
    )
    cooldown: Cooldown = field(
        default=MISSING, metadata=docs("An optional cooldown to apply to the command") | no_export_meta
    )
    # MISSING (not None) is the "not configured" sentinel: every use below tests `is not MISSING`.
    max_concurrency: MaxConcurrency = field(
        default=MISSING,
        metadata=docs("An optional maximum number of concurrent instances to apply to the command") | no_export_meta,
    )

    callback: Callable[..., Coroutine] = field(
        default=None, metadata=docs("The coroutine to be called for this command") | no_export_meta
    )
    error_callback: Callable[..., Coroutine] = field(
        default=None, metadata=no_export_meta | docs("The coroutine to be called when an error occurs")
    )
    pre_run_callback: Callable[..., Coroutine] = field(
        default=None,
        metadata=no_export_meta
        | docs("The coroutine to be called before the command is executed, **but** after the checks"),
    )
    post_run_callback: Callable[..., Coroutine] = field(
        default=None, metadata=no_export_meta | docs("The coroutine to be called after the command has executed")
    )

    def __attrs_post_init__(self) -> None:
        """Adopt checks, cooldown and max-concurrency settings that decorators attached to the callback."""
        if self.callback is not None:
            if hasattr(self.callback, "checks"):
                self.checks += self.callback.checks
            if hasattr(self.callback, "cooldown"):
                self.cooldown = self.callback.cooldown
            if hasattr(self.callback, "max_concurrency"):
                self.max_concurrency = self.callback.max_concurrency

    def __hash__(self) -> int:
        """Hash by object identity."""
        return id(self)

    async def __call__(self, context: "Context", *args, **kwargs) -> None:
        """
        Calls this command.

        Runs the checks, then the pre-run hooks, the callback and the post-run hooks. An exception is
        passed to the command's error callback, or else to the extension's error handler; when neither
        exists it is re-raised.

        Args:
            context: The context of this command
            args: Any
            kwargs: Any

        """
        # signals if a semaphore has been acquired, for exception handling
        # if present assume one will be acquired
        max_conc_acquired = self.max_concurrency is not MISSING

        try:
            if await self._can_run(context):
                if self.pre_run_callback is not None:
                    await self.call_with_binding(self.pre_run_callback, context, *args, **kwargs)

                if self.extension is not None and self.extension.extension_prerun:
                    for prerun in self.extension.extension_prerun:
                        await prerun(context, *args, **kwargs)

                await self.call_callback(self.callback, context)

                if self.post_run_callback is not None:
                    await self.call_with_binding(self.post_run_callback, context, *args, **kwargs)

                if self.extension is not None and self.extension.extension_postrun:
                    for postrun in self.extension.extension_postrun:
                        await postrun(context, *args, **kwargs)

        except Exception as e:
            # if a MaxConcurrencyReached-exception is raised a connection was never acquired
            max_conc_acquired = not isinstance(e, MaxConcurrencyReached)

            if self.error_callback:
                await self.error_callback(e, context, *args, **kwargs)
            elif self.extension and self.extension.extension_error:
                await self.extension.extension_error(e, context, *args, **kwargs)
            else:
                # nobody is registered to handle the error, so do not swallow it
                raise
        finally:
            # release on success as well as on failure, otherwise the slot is leaked
            if self.max_concurrency is not MISSING and max_conc_acquired:
                await self.max_concurrency.release(context)

    @staticmethod
    def _get_converter_function(anno: type[Converter] | Converter, name: str) -> Callable[[Context, str], Any]:
        """
        Return the `convert` callable of a converter annotation.

        Args:
            anno: The converter class or instance used as a parameter annotation
            name: The name of the annotated parameter, used in error messages

        Returns:
            The `convert` attribute of `anno`, or of a fresh instance of `anno` when instantiation was needed

        Raises:
            ValueError: If `convert` does not take exactly two arguments (besides `self`)

        """
        num_params = len(get_parameters(anno.convert))

        # if we have three parameters for the function, it's likely it has a self parameter
        # so we need to get rid of it by initing - typehinting hates this, btw!
        # the below line will error out if we aren't supposed to init it, so that works out
        try:
            actual_anno: Converter = anno() if num_params == 3 else anno  # type: ignore
        except TypeError:
            raise ValueError(
                f"{get_object_name(anno)} for {name} is invalid: converters must have exactly 2 arguments."
            ) from None

        # we can only get to this point while having three params if we successfully inited
        if num_params == 3:
            num_params -= 1

        if num_params != 2:
            raise ValueError(
                f"{get_object_name(anno)} for {name} is invalid: converters must have exactly 2 arguments."
            )

        return actual_anno.convert

    async def try_convert(self, converter: Optional[Callable], context: "Context", value: Any) -> Any:
        """
        Run `value` through `converter`, or return it untouched when there is no converter.

        Args:
            converter: The conversion callable, or None to skip conversion
            context: The context passed to the converter
            value: The raw value to convert

        Returns:
            The converted value, or `value` itself when `converter` is None

        """
        if converter is None:
            return value
        return await maybe_coroutine(converter, context, value)

    def param_config(self, annotation: Any, name: str) -> Tuple[Callable, Optional[dict]]:
        """
        Look up the attribute `name` on an annotation, directly or inside `Annotated[...]`.

        Args:
            annotation: The parameter annotation to inspect
            name: The attribute name holding the annotation's config

        Returns:
            A tuple of (annotation object, config). The config is None when the attribute is absent;
            both items are None when `annotation` is None.

        """
        # This thing is complicated. NAFF-annotations can either be annotated directly, or they can be annotated with Annotated[str, CMD_*]
        # This helper function handles both cases, and returns a tuple of the converter and its config (if any)
        if annotation is None:
            # always hand back a 2-tuple: call_callback unpacks the result unconditionally
            return (None, None)
        if typing.get_origin(annotation) is Annotated and (args := typing.get_args(annotation)):
            for ann in args:
                v = getattr(ann, name, None)
                if v is not None:
                    return (ann, v)
        return (annotation, getattr(annotation, name, None))

    async def call_callback(self, callback: Callable, context: "Context") -> None:
        """
        Resolve the callback's parameters from the context and invoke it.

        Each parameter is filled, in order of preference, from a NAFF annotation, from `context.kwargs`,
        from its declared default, or from the next unused entry of `context.args`.

        Args:
            callback: The callable to invoke
            context: The context supplying `args` and `kwargs`

        Raises:
            ValueError: If too few positional arguments were received, or a keyword-only
                parameter cannot be resolved

        """
        _call = callback
        # bind placeholders for the leading parameters (two when the command has a binding,
        # one otherwise) so that only the user-facing parameters are inspected below
        if self.has_binding:
            callback = functools.partial(callback, None, None)
        else:
            callback = functools.partial(callback, None)
        parameters = get_parameters(callback)
        args = []
        kwargs = {}
        if len(parameters) == 0:
            # if no params, user only wants context
            return await self.call_with_binding(_call, context)

        c_args = copy.copy(context.args)
        for param in parameters.values():
            if isinstance(param.annotation, Converter):
                # for any future dev looking at this:
                # this checks if the class here has a convert function
                # it does NOT check if the annotation is actually a subclass of Converter
                # this is an intended behavior for Protocols with the runtime_checkable decorator
                convert = functools.partial(
                    self.try_convert, self._get_converter_function(param.annotation, param.name), context
                )
            else:
                convert = functools.partial(self.try_convert, None, context)
            func, config = self.param_config(param.annotation, "_annotation_dat")
            if config:
                # if user has used an naff-annotation, run the annotation, and pass the result to the user
                local = {"context": context, "extension": self.extension, "param": param.name}
                ano_args = [local[c] for c in config["args"]]
                if param.kind != param.POSITIONAL_ONLY:
                    kwargs[param.name] = func(*ano_args)
                else:
                    args.append(func(*ano_args))
                continue
            elif param.name in context.kwargs:
                # if parameter is in kwargs, user obviously wants it, pass it
                if param.kind != param.POSITIONAL_ONLY:
                    kwargs[param.name] = await convert(context.kwargs[param.name])
                else:
                    args.append(await convert(context.kwargs[param.name]))
                if context.kwargs[param.name] in c_args:
                    # the value has been consumed by name, so don't hand it out again positionally
                    c_args.remove(context.kwargs[param.name])
            elif param.default is not param.empty:
                kwargs[param.name] = param.default
            else:
                if not str(param).startswith("*"):
                    if param.kind != param.KEYWORD_ONLY:
                        try:
                            args.append(await convert(c_args.pop(0)))
                        except IndexError:
                            raise ValueError(
                                f"{context.invoke_target} expects {len([p for p in parameters.values() if p.default is p.empty]) + len(callback.args)}"
                                f" arguments but received {len(context.args)} instead"
                            ) from None
                    else:
                        raise ValueError(f"Unable to resolve argument: {param.name}")

        if any(kwargs_reg.match(str(param)) for param in parameters.values()):
            # if user has `**kwargs` pass all remaining kwargs
            kwargs = kwargs | {k: v for k, v in context.kwargs.items() if k not in kwargs}
        if any(args_reg.match(str(param)) for param in parameters.values()):
            # user has `*args` pass all remaining args
            # note: `convert` here is the converter built for the last parameter of the loop above
            args = args + [await convert(c) for c in c_args]
        return await self.call_with_binding(_call, context, *args, **kwargs)

    async def _can_run(self, context: Context) -> bool:
        """
        Determines if this command can be run.

        Args:
            context: The context of the command

        Returns:
            False when the command is disabled, True when every check passed

        Raises:
            CommandCheckFailure: If a command or extension check returns a falsy value
            MaxConcurrencyReached: If no concurrency slot could be acquired
            CommandOnCooldown: If no cooldown token could be acquired

        """
        # signals if a semaphore has been acquired, for exception handling
        # NOTE(review): nothing in this method ever sets this to True, so the release in the
        # `except` below never runs; `__call__` releases the slot instead - confirm that is intended
        max_conc_acquired = False

        try:
            if not self.enabled:
                return False

            for _c in self.checks:
                if not await _c(context):
                    raise CommandCheckFailure(self, _c, context)

            if self.extension and self.extension.extension_checks:
                for _c in self.extension.extension_checks:
                    if not await _c(context):
                        raise CommandCheckFailure(self, _c, context)

            if self.max_concurrency is not MISSING:
                if not await self.max_concurrency.acquire(context):
                    raise MaxConcurrencyReached(self, self.max_concurrency)

            if self.cooldown is not MISSING:
                if not await self.cooldown.acquire_token(context):
                    raise CommandOnCooldown(self, await self.cooldown.get_cooldown(context))

            return True

        except Exception:
            if max_conc_acquired:
                await self.max_concurrency.release(context)
            # propagate the failure: `__call__` routes it to the error handlers
            raise

    def error(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
        """A decorator to declare a coroutine as one that will be run upon an error."""
        if not asyncio.iscoroutinefunction(call):
            raise TypeError("Error handler must be coroutine")
        self.error_callback = call
        return call

    def pre_run(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
        """A decorator to declare a coroutine as one that will be run before the command."""
        if not asyncio.iscoroutinefunction(call):
            raise TypeError("pre_run must be coroutine")
        self.pre_run_callback = call
        return call

    def post_run(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
        """A decorator to declare a coroutine as one that will be run after the command has run."""
        if not asyncio.iscoroutinefunction(call):
            raise TypeError("post_run must be coroutine")
        self.post_run_callback = call
        return call

function check(check)

Add a check to a command.


Name Type Description Default
check Callable[['Context'], Awaitable[bool]]

A coroutine as a check for this command

Source code in naff/models/naff/
def check(check: Callable[["Context"], Awaitable[bool]]) -> Callable[[Coroutine], Coroutine]:
    """
    Add a check to a command.

    Args:
        check: A coroutine as a check for this command

    Returns:
        A decorator that appends `check` to the decorated object's `checks` list and
        returns that same object.

    """

    def wrapper(coro: Coroutine) -> Coroutine:
        # One path serves both kinds of target: command objects already carry a `checks`
        # list, while a plain coroutine function gets one created the first time it is decorated.
        if not hasattr(coro, "checks"):
            coro.checks = []
        coro.checks.append(check)
        return coro

    return wrapper

function cooldown(bucket, rate, interval)

Add a cooldown to a command.


Name Type Description Default
bucket Buckets

The bucket used to track cooldowns

rate int

How many commands may be run per interval

interval float

How many seconds to wait for a cooldown

Source code in naff/models/naff/
def cooldown(bucket: Buckets, rate: int, interval: float) -> Callable[[Coroutine], Coroutine]:
    """
    Add a cooldown to a command.

    Args:
        bucket: The bucket used to track cooldowns
        rate: How many commands may be run per interval
        interval: How many seconds to wait for a cooldown

    Returns:
        A decorator that stores a new `Cooldown` on the decorated object's `cooldown`
        attribute and returns that same object.

    """

    def wrapper(coro: Coroutine) -> Coroutine:
        # a fresh Cooldown is built per decorated object, replacing any earlier `cooldown` attribute
        cooldown_obj = Cooldown(bucket, rate, interval)

        coro.cooldown = cooldown_obj

        return coro

    return wrapper

function max_concurrency(bucket, concurrent)

Add a maximum number of concurrent instances to the command.


Name Type Description Default
bucket Buckets

The bucket to enforce the maximum within

concurrent int

The maximum number of concurrent instances to allow

Source code in naff/models/naff/
def max_concurrency(bucket: Buckets, concurrent: int) -> Callable[[Coroutine], Coroutine]:
    """
    Add a maximum number of concurrent instances to the command.

    Args:
        bucket: The bucket to enforce the maximum within
        concurrent: The maximum number of concurrent instances to allow

    Returns:
        A decorator that stores a new `MaxConcurrency` on the decorated object's
        `max_concurrency` attribute and returns that same object.

    """

    def wrapper(coro: Coroutine) -> Coroutine:
        # note the argument order: MaxConcurrency takes the limit first, then the bucket
        max_conc = MaxConcurrency(concurrent, bucket)

        coro.max_concurrency = max_conc

        return coro

    return wrapper