import py_compile
from asyncio import Task, run_coroutine_threadsafe
from concurrent.futures import Future
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Generator, List, Optional, Sequence

import discord
from red_commons.logging import RedTraceLogger, getLogger
from redbot.core import Config, checks, commands
from redbot.core.bot import Red
from redbot.core.core_commands import CoreLogic
from redbot.core.utils.chat_formatting import bold, box, humanize_list
from typing_extensions import override
from watchdog.events import FileSystemEvent, FileSystemMovedEvent, RegexMatchingEventHandler
from watchdog.observers import Observer
from watchdog.observers.api import BaseObserver


class HotReload(commands.Cog):
    """Automatically reload cogs in local cog paths on file change."""

    __author__ = ["[cswimr](https://www.coastalcommits.com/cswimr)"]
    __git__ = "https://www.coastalcommits.com/cswimr/SeaCogs"
    __version__ = "1.4.1"
    __documentation__ = "https://seacogs.coastalcommits.com/hotreload/"

    def __init__(self, bot: Red) -> None:
        """Set up configuration defaults, logging and observer bookkeeping.

        Args:
            bot: The running Red instance this cog is attached to.
        """
        super().__init__()
        self.bot: Red = bot
        self.config: Config = Config.get_conf(cog_instance=self, identifier=294518358420750336, force_registration=True)
        self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload")
        # Observers that have been started and not yet stopped.
        self.observers: List[BaseObserver] = []
        # Strong reference to the startup task: the event loop only keeps weak
        # references to tasks, so an unreferenced task may be garbage-collected.
        self._observer_task: "Optional[Task[None]]" = None
        self.config.register_global(notify_channel=None, compile_before_reload=False)

        # Raise the level of the inotify buffer logger so its debug output is dropped.
        watchdog_loggers = [getLogger(name="watchdog.observers.inotify_buffer")]
        for watchdog_logger in watchdog_loggers:
            watchdog_logger.setLevel("INFO")

    @override
    async def cog_load(self) -> None:
        """Start the observer when the cog is loaded."""
        self._observer_task = self.bot.loop.create_task(self.start_observer())

    @override
    async def cog_unload(self) -> None:
        """Stop the observer when the cog is unloaded."""
        # If startup has not finished yet, cancel it so it cannot start an
        # observer after the cog is already gone.
        if self._observer_task is not None and not self._observer_task.done():
            self._observer_task.cancel()
        self._observer_task = None

        while self.observers:
            self._stop_observer(self.observers.pop())
        self.logger.info("Stopped observer. No longer watching for file changes.")

    def _stop_observer(self, observer: BaseObserver) -> None:
        """Stop a single observer and wait for its thread to finish."""
        observer.stop()
        # Joining a thread that was never started raises RuntimeError.
        if observer.is_alive():
            observer.join()

    @override
    def format_help_for_context(self, ctx: commands.Context) -> str:
        """Append version, author and documentation details to the cog help text."""
        pre_processed = super().format_help_for_context(ctx) or ""
        n = "\n" if "\n\n" not in pre_processed else ""
        text = [
            f"{pre_processed}{n}",
            f"{bold('Cog Version:')} [{self.__version__}]({self.__git__})",
            f"{bold('Author:')} {humanize_list(self.__author__)}",
            f"{bold('Documentation:')} {self.__documentation__}",
        ]
        return "\n".join(text)

    async def get_paths(self) -> Generator[Path, None, None]:
        """Retrieve user defined paths.

        Returns:
            A one-shot generator of the user defined cog paths as ``Path`` objects.
        """
        # We have to use this private attribute because there is no public API to get user defined paths
        cog_manager = self.bot._cog_mgr  # noqa: SLF001
        cog_paths = await cog_manager.user_defined_paths()
        return (Path(path) for path in cog_paths)

    async def start_observer(self) -> None:
        """Start the observer to watch for file changes.

        Observers left over from an earlier call are stopped and discarded first,
        so at most one observer is active afterwards.
        """
        paths = await self.get_paths()

        # An observer thread cannot be started twice, so replace leftovers
        # instead of trying to reuse them.
        while self.observers:
            self._stop_observer(self.observers.pop())
            self.logger.debug("Stopped hanging observer.")

        observer = Observer()
        for path in paths:
            if not path.exists():
                self.logger.warning("Path %s does not exist. Skipping.", path)
                continue
            self.logger.debug("Adding observer schedule for path %s.", path)
            observer.schedule(event_handler=HotReloadHandler(cog=self, path=path), path=str(path), recursive=True)
        observer.start()
        # Only track observers that are actually running.
        self.observers.append(observer)
        self.logger.info("Started observer. Watching for file changes.")

    @checks.is_owner()
    @commands.group(name="hotreload")
    async def hotreload_group(self, ctx: commands.Context) -> None:
        """HotReload configuration commands."""

    @hotreload_group.command(name="notifychannel")
    async def hotreload_notifychannel(self, ctx: commands.Context, channel: discord.TextChannel) -> None:
        """Set the channel to send notifications to."""
        await self.config.notify_channel.set(channel.id)
        await ctx.send(f"Notifications will be sent to {channel.mention}.")

    @hotreload_group.command(name="compile")  # type: ignore
    async def hotreload_compile(self, ctx: commands.Context, compile_before_reload: bool) -> None:
        """Set whether to compile modified files before reloading."""
        await self.config.compile_before_reload.set(compile_before_reload)
        await ctx.send(f"I {'will' if compile_before_reload else 'will not'} compile modified files before hotreloading cogs.")

    @hotreload_group.command(name="list")  # type: ignore
    async def hotreload_list(self, ctx: commands.Context) -> None:
        """List the currently active observers."""
        if not self.observers:
            await ctx.send("No observers are currently active.")
            return
        await ctx.send(f"Currently active observers (If there are more than one of these, report an issue): {box(humanize_list([str(o) for o in self.observers], style='unit'))}")


class HotReloadHandler(RegexMatchingEventHandler):
    """Handler for file changes."""

    def __init__(self, cog: HotReload, path: Path) -> None:
        """Create a handler for ``*.py`` files below one watched cog path.

        Args:
            cog: The owning cog; provides the bot, its event loop and the config.
            path: The watched root. The first path component below it is used
                as the name of the cog package to reload.
        """
        super().__init__(regexes=[r".*\.py$"])
        self.cog: HotReload = cog
        self.path: Path = path
        self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload.Observer")

    def _package_name(self, raw_path: object) -> Optional[str]:
        """Return the top-level package below the watched root, or ``None``.

        ``None`` is returned when ``raw_path`` is not located below
        ``self.path`` or is the watched root itself.
        """
        try:
            relative = Path(str(raw_path)).relative_to(self.path)
        except ValueError:
            return None
        return relative.parts[0] if relative.parts else None

    def _log_reload_failure(self, future: "Future[None]") -> None:
        """Log an exception raised by a scheduled ``reload_cogs`` call."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            self.logger.error("Unhandled exception while reloading cogs.", exc_info=error)

    def on_any_event(self, event: FileSystemEvent) -> None:
        """Handle filesystem events.

        Called by watchdog outside the bot's event loop, so the reload itself
        is handed over to the loop via ``run_coroutine_threadsafe``.
        """
        if event.is_directory:
            return

        allowed_events = ("moved", "deleted", "created", "modified")
        if event.event_type not in allowed_events:
            return

        cogs_to_reload: List[str] = []
        src_package_name = self._package_name(event.src_path)
        if src_package_name is not None:
            cogs_to_reload.append(src_package_name)

        if isinstance(event, FileSystemMovedEvent):
            dest = f" to {event.dest_path}"
            # A move across packages affects both the source and the destination cog.
            dest_package_name = self._package_name(event.dest_path)
            if dest_package_name is not None and dest_package_name not in cogs_to_reload:
                cogs_to_reload.append(dest_package_name)
        else:
            dest = ""

        self.logger.info("File %s has been %s%s.", event.src_path, event.event_type, dest)

        if not cogs_to_reload:
            self.logger.debug("Event for %s is outside of %s. Nothing to reload.", event.src_path, self.path)
            return

        future = run_coroutine_threadsafe(
            coro=self.reload_cogs(
                cog_names=cogs_to_reload,
                paths=[Path(str(p)) for p in (event.src_path, getattr(event, "dest_path", None)) if p],
            ),
            loop=self.cog.bot.loop,
        )
        # Without a callback, an exception raised by the coroutine stays in the
        # discarded future and is never reported.
        future.add_done_callback(self._log_reload_failure)

    async def reload_cogs(self, cog_names: Sequence[str], paths: Sequence[Path]) -> None:
        """Reload modified cogs.

        Args:
            cog_names: Names of the cog packages to reload.
            paths: Files that triggered the reload. They are compiled first when
                the ``compile_before_reload`` setting is enabled.
        """
        # Honour the owner's `[p]hotreload compile` setting.
        if await self.cog.config.compile_before_reload() and not self.compile_modified_files(cog_names, paths):
            return

        core_logic = CoreLogic(bot=self.cog.bot)
        self.logger.info("Reloading cogs: %s", humanize_list(cog_names, style="unit"))
        # We have to use this private method because there is no public API to reload other cogs
        await core_logic._reload(pkg_names=cog_names)  # noqa: SLF001
        self.logger.info("Reloaded cogs: %s", humanize_list(cog_names, style="unit"))

        channel = self.cog.bot.get_channel(await self.cog.config.notify_channel())
        if channel and isinstance(channel, discord.TextChannel):
            await channel.send(f"Reloaded cogs: {humanize_list(cog_names, style='unit')}")

    def compile_modified_files(self, cog_names: Sequence[str], paths: Sequence[Path]) -> bool:
        """Compile modified files to ensure they are valid Python files.

        Args:
            cog_names: Cog names, used only in the failure log message.
            paths: Files to check. Missing files and files without a ``.py``
                suffix are skipped.

        Returns:
            ``False`` as soon as one file fails to compile, ``True`` otherwise.
            An ``OSError`` during the step is logged and treated as a skip.
        """
        for path in paths:
            if not path.exists() or path.suffix != ".py":
                self.logger.debug("Path %s does not exist or does not point to a Python file. Skipping compilation step.", path)
                continue

            try:
                # The bytecode goes to a throwaway file; only the outcome matters.
                with NamedTemporaryFile() as temp_file:
                    self.logger.debug("Attempting to compile %s", path)
                    py_compile.compile(file=str(path), cfile=temp_file.name, doraise=True)
                    self.logger.debug("Successfully compiled %s", path)
            except py_compile.PyCompileError as e:
                # Keep the logged traceback limited to the compile error itself.
                e.__suppress_context__ = True
                self.logger.exception("%s failed to compile. Not reloading cogs %s.", path, humanize_list(cog_names, style="unit"))
                return False
            except OSError:
                self.logger.exception("Failed to create tempfile for compilation step. Skipping.")
        return True