SeaCogs/hotreload/hotreload.py

189 lines
8.5 KiB
Python
Raw Normal View History

import py_compile
from asyncio import run_coroutine_threadsafe
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, List, Sequence, Tuple

from red_commons.logging import RedTraceLogger, getLogger
from redbot.core import Config, checks, commands
from redbot.core.bot import Red
from redbot.core.core_commands import CoreLogic
from redbot.core.utils.chat_formatting import bold, box, humanize_list
from watchdog.events import FileSystemEvent, FileSystemMovedEvent, RegexMatchingEventHandler
from watchdog.observers import Observer

if TYPE_CHECKING:
    from watchdog.observers import ObserverType
class HotReload(commands.Cog):
    """Automatically reload cogs in local cog paths on file change."""

    __author__ = ["[cswimr](https://www.coastalcommits.com/cswimr)"]
    __git__ = "https://www.coastalcommits.com/cswimr/SeaCogs"
    __version__ = "1.4.0"
    __documentation__ = "https://seacogs.coastalcommits.com/hotreload/"

    def __init__(self, bot: Red) -> None:
        # Registers the two global settings used by the commands below and
        # prepares an empty observer list; no observer is started here.
        super().__init__()
        self.bot: Red = bot
        self.config = Config.get_conf(self, identifier=294518358420750336, force_registration=True)
        self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload")
        # Observers started by this cog instance. Only one is expected at a time.
        self.observers: List[ObserverType] = []
        self.config.register_global(notify_channel=None, compile_before_reload=False)

        # Raise the level of watchdog's inotify buffer logger to INFO.
        watchdog_loggers = [getLogger(name="watchdog.observers.inotify_buffer")]
        for watchdog_logger in watchdog_loggers:
            watchdog_logger.setLevel("INFO")  # SHUT UP!!!!

    async def cog_load(self) -> None:
        """Start the observer when the cog is loaded."""
        # Scheduled as a task so loading the cog does not wait on observer setup.
        self.bot.loop.create_task(self.start_observer())

    async def cog_unload(self) -> None:
        """Stop the observer when the cog is unloaded."""
        for observer in self.observers:
            observer.stop()
            observer.join()
            self.logger.info("Stopped observer. No longer watching for file changes.")
        # Drop the stopped observers so they are not reported or stopped again.
        self.observers.clear()

    def format_help_for_context(self, ctx: commands.Context) -> str:
        # Appends version, author and documentation lines to the default help text.
        pre_processed = super().format_help_for_context(ctx) or ""
        # Add a separating newline only if the base help has no blank line already.
        n = "\n" if "\n\n" not in pre_processed else ""
        text = [
            f"{pre_processed}{n}",
            f"{bold('Cog Version:')} [{self.__version__}]({self.__git__})",
            f"{bold('Author:')} {humanize_list(self.__author__)}",
            f"{bold('Documentation:')} {self.__documentation__}",
        ]
        return "\n".join(text)

    async def get_paths(self) -> Tuple[Path, ...]:
        """Retrieve user defined paths.

        Returns:
            A tuple of the cog manager's user defined paths as ``Path`` objects.
            A real tuple is returned (not a generator) so callers can iterate
            it more than once.
        """
        cog_manager = self.bot._cog_mgr  # noqa: SLF001 # We have to use this private method because there is no public API to get user defined paths
        cog_paths = await cog_manager.user_defined_paths()
        return tuple(Path(path) for path in cog_paths)

    async def start_observer(self) -> None:
        """Start the observer to watch for file changes.

        Any observer left over from an earlier call is stopped and discarded
        first, so at most one observer is ever active. The new observer gets
        one recursive schedule per existing user defined cog path.
        """
        paths = await self.get_paths()

        # Observers are only recorded after they were started, so stop()/join()
        # is always valid on the entries found here.
        while self.observers:
            stale_observer = self.observers.pop()
            stale_observer.stop()
            stale_observer.join()
            self.logger.debug("Stopped hanging observer.")

        observer = Observer()
        for path in paths:
            if not path.exists():
                self.logger.warning("Path %s does not exist. Skipping.", path)
                continue
            self.logger.debug("Adding observer schedule for path %s.", path)
            observer.schedule(event_handler=HotReloadHandler(cog=self, path=path), path=path, recursive=True)
        observer.start()
        # Record the observer only once it is running, so cog_unload never
        # tries to join a thread that was never started.
        self.observers.append(observer)
        self.logger.info("Started observer. Watching for file changes.")

    @checks.is_owner()
    @commands.group(name="hotreload")
    async def hotreload_group(self, ctx: commands.Context) -> None:
        """HotReload configuration commands."""
        pass

    @hotreload_group.command(name="notifychannel")
    async def hotreload_notifychannel(self, ctx: commands.Context, channel: commands.TextChannelConverter) -> None:
        """Set the channel to send notifications to."""
        await self.config.notify_channel.set(channel.id)
        await ctx.send(f"Notifications will be sent to {channel.mention}.")

    @hotreload_group.command(name="compile")
    async def hotreload_compile(self, ctx: commands.Context, compile_before_reload: bool) -> None:
        """Set whether to compile modified files before reloading."""
        await self.config.compile_before_reload.set(compile_before_reload)
        await ctx.send(f"I {'will' if compile_before_reload else 'will not'} compile modified files before hotreloading cogs.")

    @hotreload_group.command(name="list")
    async def hotreload_list(self, ctx: commands.Context) -> None:
        """List the currently active observers."""
        if not self.observers:
            await ctx.send("No observers are currently active.")
            return
        # humanize_list expects strings, so render each observer explicitly.
        observer_names = [str(observer) for observer in self.observers]
        await ctx.send(f"Currently active observers (If there are more than one of these, report an issue): {box(humanize_list(observer_names, style='unit'))}")
class HotReloadHandler(RegexMatchingEventHandler):
    """Handler for file changes.

    One handler is created per watched cog path. It only receives events for
    paths matching ``.*\\.py$`` and reloads the package(s) the changed file
    belongs to, where the package name is the first path component below the
    watched path.
    """

    def __init__(self, cog: HotReload, path: Path) -> None:
        """Create a handler for ``path`` that reports back to ``cog``.

        Args:
            cog: The owning cog; provides the bot, its event loop and the config.
            path: The watched root path; event paths are resolved relative to it.
        """
        super().__init__(regexes=[r".*\.py$"])
        self.cog: HotReload = cog
        self.path: Path = path
        self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload.Observer")

    def on_any_event(self, event: FileSystemEvent) -> None:
        """Handle filesystem events.

        Directory events and event types other than moved, deleted, created
        and modified are ignored. For everything else a reload of the affected
        package(s) is submitted to the bot's event loop.
        """
        if event.is_directory:
            return

        allowed_events = ("moved", "deleted", "created", "modified")
        if event.event_type not in allowed_events:
            return

        relative_src_path = Path(event.src_path).relative_to(self.path)
        src_package_name = relative_src_path.parts[0]
        cogs_to_reload = [src_package_name]

        if isinstance(event, FileSystemMovedEvent):
            dest = f" to {event.dest_path}"
            relative_dest_path = Path(event.dest_path).relative_to(self.path)
            dest_package_name = relative_dest_path.parts[0]
            # A move between packages affects both of them.
            if dest_package_name != src_package_name:
                cogs_to_reload.append(dest_package_name)
        else:
            dest = ""

        self.logger.info("File %s has been %s%s.", event.src_path, event.event_type, dest)

        # reload_cogs is a coroutine, so it is handed to the bot's loop rather
        # than called directly from this handler.
        future = run_coroutine_threadsafe(
            coro=self.reload_cogs(
                cog_names=cogs_to_reload,
                paths=[Path(p) for p in (event.src_path, getattr(event, "dest_path", None)) if p],
            ),
            loop=self.cog.bot.loop,
        )
        # Nothing waits on this future, so without a callback any exception
        # raised during the reload would be dropped silently.
        future.add_done_callback(self._log_reload_failure)

    def _log_reload_failure(self, future: "concurrent.futures.Future") -> None:
        """Log the exception of a finished reload future, if it has one."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            self.logger.error("Hot reload failed with an unhandled exception.", exc_info=error)

    async def reload_cogs(self, cog_names: Sequence[str], paths: Sequence[Path]) -> None:
        """Reload modified cogs.

        Args:
            cog_names: Package names to reload.
            paths: The changed files; compiled first when the
                ``compile_before_reload`` setting is enabled.

        If the compile step is enabled and a file fails to compile, nothing is
        reloaded. After a reload, a message is sent to the configured
        notification channel when one is set and can be resolved.
        """
        # Honour the setting controlled by the "hotreload compile" command.
        compile_first = await self.cog.config.compile_before_reload()
        if compile_first and not self.compile_modified_files(cog_names, paths):
            return

        core_logic = CoreLogic(bot=self.cog.bot)
        self.logger.info("Reloading cogs: %s", humanize_list(cog_names, style="unit"))
        await core_logic._reload(pkg_names=cog_names)  # noqa: SLF001 # We have to use this private method because there is no public API to reload other cogs
        self.logger.info("Reloaded cogs: %s", humanize_list(cog_names, style="unit"))

        channel = self.cog.bot.get_channel(await self.cog.config.notify_channel())
        if channel:
            await channel.send(f"Reloaded cogs: {humanize_list(cog_names, style='unit')}")

    def compile_modified_files(self, cog_names: Sequence[str], paths: Sequence[Path]) -> bool:
        """Compile modified files to ensure they are valid Python files.

        Args:
            cog_names: Package names about to be reloaded; used in log output only.
            paths: Files to check. Missing files and non-``.py`` paths are skipped.

        Returns:
            ``False`` as soon as one file fails to compile, ``True`` otherwise.
            An ``OSError`` during the step is logged and that file is skipped,
            so it does not block the reload.
        """
        for path in paths:
            if not path.exists() or path.suffix != ".py":
                self.logger.debug("Path %s does not exist or does not point to a Python file. Skipping compilation step.", path)
                continue

            try:
                # py_compile writes its output by file name. An open
                # NamedTemporaryFile cannot be reopened by name on every
                # platform, so the bytecode goes into a throwaway directory.
                with TemporaryDirectory() as temp_dir:
                    self.logger.debug("Attempting to compile %s", path)
                    py_compile.compile(file=path, cfile=str(Path(temp_dir) / "hotreload_check.pyc"), doraise=True)
                    self.logger.debug("Successfully compiled %s", path)
            except py_compile.PyCompileError as e:
                # Keep the logged traceback limited to the compile error itself.
                e.__suppress_context__ = True
                self.logger.exception("%s failed to compile. Not reloading cogs %s.", path, humanize_list(cog_names, style="unit"))
                return False
            except OSError:
                self.logger.exception("Failed to create tempfile for compilation step. Skipping.")
        return True