Cooldowns
attrs
class
BaseCommand (DictSerializationMixin, CallbackObject)
¶
An object all commands inherit from. Outlines the basic structure of a command, and handles checks.
Attributes:
Name | Type | Description |
---|---|---|
extension |
Any |
The extension this command belongs to. |
enabled |
bool |
Whether this command is enabled |
checks |
list |
Any checks that must be run before this command can be run |
callback |
Callable[..., Coroutine] |
The coroutine to be called for this command |
error_callback |
Callable[..., Coroutine] |
The coroutine to be called when an error occurs |
pre_run_callback |
Callable[..., Coroutine] |
A coroutine to be called before this command is run but after the checks |
post_run_callback |
Callable[..., Coroutine] |
A coroutine to be called after this command has run |
Attr attributes:
Name | Type | Description |
---|---|---|
extension |
Any |
The extension this command belongs to |
enabled |
bool |
Whether this can be run at all |
checks |
list |
Any checks that must be checked before the command can run |
cooldown |
Cooldown |
An optional cooldown to apply to the command |
max_concurrency |
MaxConcurrency |
An optional maximum number of concurrent instances to apply to the command |
callback |
Callable[..., Coroutine] |
The coroutine to be called for this command |
error_callback |
Callable[..., Coroutine] |
The coroutine to be called when an error occurs |
pre_run_callback |
Callable[..., Coroutine] |
The coroutine to be called before the command is executed, but after the checks |
post_run_callback |
Callable[..., Coroutine] |
The coroutine to be called after the command has executed |
Source code in naff/models/naff/command.py
@define()
class BaseCommand(DictSerializationMixin, CallbackObject):
"""
An object all commands inherit from. Outlines the basic structure of a command, and handles checks.
Attributes:
extension: The extension this command belongs to.
enabled: Whether this command is enabled
checks: Any checks that must be run before this command can be run
callback: The coroutine to be called for this command
error_callback: The coroutine to be called when an error occurs
pre_run_callback: A coroutine to be called before this command is run **but** after the checks
post_run_callback: A coroutine to be called after this command has run
"""
extension: Any = field(default=None, metadata=docs("The extension this command belongs to") | no_export_meta)
enabled: bool = field(default=True, metadata=docs("Whether this can be run at all") | no_export_meta)
checks: list = field(
factory=list, metadata=docs("Any checks that must be *checked* before the command can run") | no_export_meta
)
cooldown: Cooldown = field(
default=MISSING, metadata=docs("An optional cooldown to apply to the command") | no_export_meta
)
max_concurrency: MaxConcurrency = field(
default=MISSING,
metadata=docs("An optional maximum number of concurrent instances to apply to the command") | no_export_meta,
)
callback: Callable[..., Coroutine] = field(
default=None, metadata=docs("The coroutine to be called for this command") | no_export_meta
)
error_callback: Callable[..., Coroutine] = field(
default=None, metadata=no_export_meta | docs("The coroutine to be called when an error occurs")
)
pre_run_callback: Callable[..., Coroutine] = field(
default=None,
metadata=no_export_meta
| docs("The coroutine to be called before the command is executed, **but** after the checks"),
)
post_run_callback: Callable[..., Coroutine] = field(
default=None, metadata=no_export_meta | docs("The coroutine to be called after the command has executed")
)
def __attrs_post_init__(self) -> None:
if self.callback is not None:
if hasattr(self.callback, "checks"):
self.checks += self.callback.checks
if hasattr(self.callback, "cooldown"):
self.cooldown = self.callback.cooldown
if hasattr(self.callback, "max_concurrency"):
self.max_concurrency = self.callback.max_concurrency
def __hash__(self) -> int:
return id(self)
async def __call__(self, context: "Context", *args, **kwargs) -> None:
"""
Calls this command.
Args:
context: The context of this command
args: Any
kwargs: Any
"""
# signals if a semaphore has been acquired, for exception handling
# if present assume one will be acquired
max_conc_acquired = self.max_concurrency is not MISSING
try:
if await self._can_run(context):
if self.pre_run_callback is not None:
await self.call_with_binding(self.pre_run_callback, context, *args, **kwargs)
if self.extension is not None and self.extension.extension_prerun:
for prerun in self.extension.extension_prerun:
await prerun(context, *args, **kwargs)
await self.call_callback(self.callback, context)
if self.post_run_callback is not None:
await self.call_with_binding(self.post_run_callback, context, *args, **kwargs)
if self.extension is not None and self.extension.extension_postrun:
for postrun in self.extension.extension_postrun:
await postrun(context, *args, **kwargs)
except Exception as e:
# if a MaxConcurrencyReached-exception is raised a connection was never acquired
max_conc_acquired = not isinstance(e, MaxConcurrencyReached)
if self.error_callback:
await self.error_callback(e, context, *args, **kwargs)
elif self.extension and self.extension.extension_error:
await self.extension.extension_error(e, context, *args, **kwargs)
else:
raise
finally:
if self.max_concurrency is not MISSING and max_conc_acquired:
await self.max_concurrency.release(context)
@staticmethod
def _get_converter_function(anno: type[Converter] | Converter, name: str) -> Callable[[Context, str], Any]:
num_params = len(get_parameters(anno.convert))
# if we have three parameters for the function, it's likely it has a self parameter
# so we need to get rid of it by initing - typehinting hates this, btw!
# the below line will error out if we aren't supposed to init it, so that works out
try:
actual_anno: Converter = anno() if num_params == 3 else anno # type: ignore
except TypeError:
raise ValueError(
f"{get_object_name(anno)} for {name} is invalid: converters must have exactly 2 arguments."
) from None
# we can only get to this point while having three params if we successfully inited
if num_params == 3:
num_params -= 1
if num_params != 2:
raise ValueError(
f"{get_object_name(anno)} for {name} is invalid: converters must have exactly 2 arguments."
)
return actual_anno.convert
async def try_convert(self, converter: Optional[Callable], context: "Context", value: Any) -> Any:
if converter is None:
return value
return await maybe_coroutine(converter, context, value)
def param_config(self, annotation: Any, name: str) -> Tuple[Callable, Optional[dict]]:
# This thing is complicated. NAFF-annotations can either be annotated directly, or they can be annotated with Annotated[str, CMD_*]
# This helper function handles both cases, and returns a tuple of the converter and its config (if any)
if annotation is None:
return None
if typing.get_origin(annotation) is Annotated and (args := typing.get_args(annotation)):
for ann in args:
v = getattr(ann, name, None)
if v is not None:
return (ann, v)
return (annotation, getattr(annotation, name, None))
async def call_callback(self, callback: Callable, context: "Context") -> None:
_call = callback
if self.has_binding:
callback = functools.partial(callback, None, None)
else:
callback = functools.partial(callback, None)
parameters = get_parameters(callback)
args = []
kwargs = {}
if len(parameters) == 0:
# if no params, user only wants context
return await self.call_with_binding(_call, context)
c_args = copy.copy(context.args)
for param in parameters.values():
if isinstance(param.annotation, Converter):
# for any future dev looking at this:
# this checks if the class here has a convert function
# it does NOT check if the annotation is actually a subclass of Converter
# this is an intended behavior for Protocols with the runtime_checkable decorator
convert = functools.partial(
self.try_convert, self._get_converter_function(param.annotation, param.name), context
)
else:
convert = functools.partial(self.try_convert, None, context)
func, config = self.param_config(param.annotation, "_annotation_dat")
if config:
# if user has used an naff-annotation, run the annotation, and pass the result to the user
local = {"context": context, "extension": self.extension, "param": param.name}
ano_args = [local[c] for c in config["args"]]
if param.kind != param.POSITIONAL_ONLY:
kwargs[param.name] = func(*ano_args)
else:
args.append(func(*ano_args))
continue
elif param.name in context.kwargs:
# if parameter is in kwargs, user obviously wants it, pass it
if param.kind != param.POSITIONAL_ONLY:
kwargs[param.name] = await convert(context.kwargs[param.name])
else:
args.append(await convert(context.kwargs[param.name]))
if context.kwargs[param.name] in c_args:
c_args.remove(context.kwargs[param.name])
elif param.default is not param.empty:
kwargs[param.name] = param.default
else:
if not str(param).startswith("*"):
if param.kind != param.KEYWORD_ONLY:
try:
args.append(await convert(c_args.pop(0)))
except IndexError:
raise ValueError(
f"{context.invoke_target} expects {len([p for p in parameters.values() if p.default is p.empty]) + len(callback.args)}"
f" arguments but received {len(context.args)} instead"
) from None
else:
raise ValueError(f"Unable to resolve argument: {param.name}")
if any(kwargs_reg.match(str(param)) for param in parameters.values()):
# if user has `**kwargs` pass all remaining kwargs
kwargs = kwargs | {k: v for k, v in context.kwargs.items() if k not in kwargs}
if any(args_reg.match(str(param)) for param in parameters.values()):
# user has `*args` pass all remaining args
args = args + [await convert(c) for c in c_args]
return await self.call_with_binding(_call, context, *args, **kwargs)
async def _can_run(self, context: Context) -> bool:
"""
Determines if this command can be run.
Args:
context: The context of the command
"""
max_conc_acquired = False # signals if a semaphore has been acquired, for exception handling
try:
if not self.enabled:
return False
for _c in self.checks:
if not await _c(context):
raise CommandCheckFailure(self, _c, context)
if self.extension and self.extension.extension_checks:
for _c in self.extension.extension_checks:
if not await _c(context):
raise CommandCheckFailure(self, _c, context)
if self.max_concurrency is not MISSING:
if not await self.max_concurrency.acquire(context):
raise MaxConcurrencyReached(self, self.max_concurrency)
if self.cooldown is not MISSING:
if not await self.cooldown.acquire_token(context):
raise CommandOnCooldown(self, await self.cooldown.get_cooldown(context))
return True
except Exception:
if max_conc_acquired:
await self.max_concurrency.release(context)
raise
def error(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
"""A decorator to declare a coroutine as one that will be run upon an error."""
if not asyncio.iscoroutinefunction(call):
raise TypeError("Error handler must be coroutine")
self.error_callback = call
return call
def pre_run(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
"""A decorator to declare a coroutine as one that will be run before the command."""
if not asyncio.iscoroutinefunction(call):
raise TypeError("pre_run must be coroutine")
self.pre_run_callback = call
return call
def post_run(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
"""A decorator to declare a coroutine as one that will be run after the command has."""
if not asyncio.iscoroutinefunction(call):
raise TypeError("post_run must be coroutine")
self.post_run_callback = call
return call
async
inherited
method
call_with_binding(self, callback, *args, **kwargs)
¶
Call a given method using this objects _binding.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
callback |
Callable[..., Coroutine[Any, Any, Any]] |
The callback to call. |
required |
Source code in naff/models/naff/command.py
async def call_with_binding(self, callback: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> Any:
"""
Call a given method using this objects _binding.
Args:
callback: The callback to call.
"""
if self._binding:
return await callback(self._binding, *args, **kwargs)
return await callback(*args, **kwargs)
inherited
method
update_from_dict(self, data)
¶
Updates object attribute(s) with new json data received from discord api.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Dict[str, Any] |
The json data received from discord api. |
required |
Returns:
Type | Description |
---|---|
~T |
The updated object class instance. |
Source code in naff/models/naff/command.py
def update_from_dict(self: Type[const.T], data: Dict[str, Any]) -> const.T:
"""
Updates object attribute(s) with new json data received from discord api.
Args:
data: The json data received from discord api.
Returns:
The updated object class instance.
"""
data = self._process_dict(data)
for key, value in self._filter_kwargs(data, self._get_keys()).items():
# todo improve
setattr(self, key, value)
return self
inherited
method
to_dict(self)
¶
Exports object into dictionary representation, ready to be sent to discord api.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The exported dictionary. |
Source code in naff/models/naff/command.py
def to_dict(self) -> Dict[str, Any]:
"""
Exports object into dictionary representation, ready to be sent to discord api.
Returns:
The exported dictionary.
"""
self._check_object()
return serializer.to_dict(self)
method
error(self, call)
¶
A decorator to declare a coroutine as one that will be run upon an error.
Source code in naff/models/naff/command.py
def error(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
"""A decorator to declare a coroutine as one that will be run upon an error."""
if not asyncio.iscoroutinefunction(call):
raise TypeError("Error handler must be coroutine")
self.error_callback = call
return call
method
pre_run(self, call)
¶
A decorator to declare a coroutine as one that will be run before the command.
Source code in naff/models/naff/command.py
def pre_run(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
"""A decorator to declare a coroutine as one that will be run before the command."""
if not asyncio.iscoroutinefunction(call):
raise TypeError("pre_run must be coroutine")
self.pre_run_callback = call
return call
method
post_run(self, call)
¶
A decorator to declare a coroutine as one that will be run after the command has.
Source code in naff/models/naff/command.py
def post_run(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
"""A decorator to declare a coroutine as one that will be run after the command has."""
if not asyncio.iscoroutinefunction(call):
raise TypeError("post_run must be coroutine")
self.post_run_callback = call
return call
function
check(check)
¶
Add a check to a command.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
check |
Callable[['Context'], Awaitable[bool]] |
A coroutine as a check for this command |
required |
Source code in naff/models/naff/command.py
def check(check: Callable[["Context"], Awaitable[bool]]) -> Callable[[Coroutine], Coroutine]:
"""
Add a check to a command.
Args:
check: A coroutine as a check for this command
"""
def wrapper(coro: Coroutine) -> Coroutine:
if isinstance(coro, BaseCommand):
coro.checks.append(check)
return coro
if not hasattr(coro, "checks"):
coro.checks = []
coro.checks.append(check)
return coro
return wrapper
function
cooldown(bucket, rate, interval)
¶
Add a cooldown to a command.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bucket |
Buckets |
The bucket used to track cooldowns |
required |
rate |
int |
How many commands may be ran per interval |
required |
interval |
float |
How many seconds to wait for a cooldown |
required |
Source code in naff/models/naff/command.py
def cooldown(bucket: Buckets, rate: int, interval: float) -> Callable[[Coroutine], Coroutine]:
"""
Add a cooldown to a command.
Args:
bucket: The bucket used to track cooldowns
rate: How many commands may be ran per interval
interval: How many seconds to wait for a cooldown
"""
def wrapper(coro: Coroutine) -> Coroutine:
cooldown_obj = Cooldown(bucket, rate, interval)
coro.cooldown = cooldown_obj
return coro
return wrapper
function
max_concurrency(bucket, concurrent)
¶
Add a maximum number of concurrent instances to the command.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bucket |
Buckets |
The bucket to enforce the maximum within |
required |
concurrent |
int |
The maximum number of concurrent instances to allow |
required |
Source code in naff/models/naff/command.py
def max_concurrency(bucket: Buckets, concurrent: int) -> Callable[[Coroutine], Coroutine]:
"""
Add a maximum number of concurrent instances to the command.
Args:
bucket: The bucket to enforce the maximum within
concurrent: The maximum number of concurrent instances to allow
"""
def wrapper(coro: Coroutine) -> Coroutine:
max_conc = MaxConcurrency(concurrent, bucket)
coro.max_concurrency = max_conc
return coro
return wrapper