diff --git a/.docs/hotreload.md b/.docs/hotreload.md index e2d4cb3..0144e5d 100644 --- a/.docs/hotreload.md +++ b/.docs/hotreload.md @@ -13,6 +13,10 @@ This is useful for development, as it allows you to make changes to your cogs an ## Commands +### hotreload compile + +Determines if the cog should try to compile a modified Python file before reloading the associated cog. Useful for catching syntax errors. Disabled by default. + ### hotreload notifychannel Set the channel where hotreload will send notifications when a cog is reloaded. diff --git a/hotreload/hotreload.py b/hotreload/hotreload.py index 9357fcd..ee223f2 100644 --- a/hotreload/hotreload.py +++ b/hotreload/hotreload.py @@ -1,5 +1,7 @@ +import py_compile from asyncio import run_coroutine_threadsafe from pathlib import Path +from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, List, Sequence, Tuple from red_commons.logging import RedTraceLogger, getLogger @@ -19,7 +21,7 @@ class HotReload(commands.Cog): __author__ = ["[cswimr](https://www.coastalcommits.com/cswimr)"] __git__ = "https://www.coastalcommits.com/cswimr/SeaCogs" - __version__ = "1.3.3" + __version__ = "1.4.0" __documentation__ = "https://seacogs.coastalcommits.com/hotreload/" def __init__(self, bot: Red) -> None: @@ -28,7 +30,7 @@ class HotReload(commands.Cog): self.config = Config.get_conf(self, identifier=294518358420750336, force_registration=True) self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload") self.observers: List[ObserverType] = [] - self.config.register_global(notify_channel=None) + self.config.register_global(notify_channel=None, compile_before_reload=False) watchdog_loggers = [getLogger(name="watchdog.observers.inotify_buffer")] for watchdog_logger in watchdog_loggers: watchdog_logger.setLevel("INFO") # SHUT UP!!!! @@ -94,6 +96,12 @@ class HotReload(commands.Cog): await self.config.notify_channel.set(channel.id) await ctx.send(f"Notifications will be sent to {channel.mention}.") + @hotreload_group.command(name="compile") + async def hotreload_compile(self, ctx: commands.Context, compile_before_reload: bool) -> None: + """Set whether to compile modified files before reloading.""" + await self.config.compile_before_reload.set(compile_before_reload) + await ctx.send(f"I {'will' if compile_before_reload else 'will not'} compile modified files before hotreloading cogs.") + @hotreload_group.command(name="list") async def hotreload_list(self, ctx: commands.Context) -> None: """List the currently active observers.""" @@ -136,10 +144,19 @@ class HotReloadHandler(RegexMatchingEventHandler): self.logger.info("File %s has been %s%s.", event.src_path, event.event_type, dest) - run_coroutine_threadsafe(self.reload_cogs(cogs_to_reload), loop=self.cog.bot.loop) + run_coroutine_threadsafe( + coro=self.reload_cogs( + cog_names=cogs_to_reload, + paths=[Path(p) for p in (event.src_path, getattr(event, "dest_path", None)) if p], + ), + loop=self.cog.bot.loop, + ) + + async def reload_cogs(self, cog_names: Sequence[str], paths: Sequence[Path]) -> None: + """Reload modified cogs.""" + if not self.compile_modified_files(cog_names, paths): + return - async def reload_cogs(self, cog_names: Sequence[str]) -> None: - """Reload modified cog.""" core_logic = CoreLogic(bot=self.cog.bot) self.logger.info("Reloading cogs: %s", humanize_list(cog_names, style="unit")) await core_logic._reload(pkg_names=cog_names) # noqa: SLF001 # We have to use this private method because there is no public API to reload other cogs @@ -148,3 +165,24 @@ class HotReloadHandler(RegexMatchingEventHandler): channel = self.cog.bot.get_channel(await self.cog.config.notify_channel()) if channel: await channel.send(f"Reloaded cogs: {humanize_list(cog_names, style='unit')}") + + def compile_modified_files(self, cog_names: Sequence[str], paths: Sequence[Path]) -> bool: + """Compile modified files to ensure they are valid Python files.""" + for path in paths: + if not path.exists() or path.suffix != ".py": + self.logger.debug("Path %s does not exist or does not point to a Python file. Skipping compilation step.", path) + continue + + try: + with NamedTemporaryFile() as temp_file: + self.logger.debug("Attempting to compile %s", path) + py_compile.compile(file=path, cfile=temp_file.name, doraise=True) + self.logger.debug("Successfully compiled %s", path) + + except py_compile.PyCompileError as e: + e.__suppress_context__ = True + self.logger.exception("%s failed to compile. Not reloading cogs %s.", path, humanize_list(cog_names, style="unit")) + return False + except OSError: + self.logger.exception("Failed to create tempfile for compilation step. Skipping.") + return True