2025-01-29 23:24:19 +00:00
import py_compile
2025-01-25 18:51:27 -05:00
from asyncio import run_coroutine_threadsafe
from pathlib import Path
2025-01-29 23:24:19 +00:00
from tempfile import NamedTemporaryFile
2025-02-01 16:57:45 +00:00
from typing import Generator , List , Sequence
2025-01-25 18:51:27 -05:00
2025-02-01 16:57:45 +00:00
import discord
2025-01-25 18:51:27 -05:00
from red_commons . logging import RedTraceLogger , getLogger
2025-01-26 10:06:14 -05:00
from redbot . core import Config , checks , commands
2025-01-25 18:51:27 -05:00
from redbot . core . bot import Red
from redbot . core . core_commands import CoreLogic
2025-01-26 15:18:05 +00:00
from redbot . core . utils . chat_formatting import bold , box , humanize_list
2025-02-01 16:57:45 +00:00
from typing_extensions import override
2025-01-25 19:25:29 -05:00
from watchdog . events import FileSystemEvent , FileSystemMovedEvent , RegexMatchingEventHandler
2025-01-26 21:23:08 +00:00
from watchdog . observers import Observer
2025-02-01 16:57:45 +00:00
from watchdog . observers . api import BaseObserver
2025-01-25 18:51:27 -05:00
class HotReload(commands.Cog):
    """Automatically reload cogs in local cog paths on file change."""

    __author__ = ["[cswimr](https://www.coastalcommits.com/cswimr)"]
    __git__ = "https://www.coastalcommits.com/cswimr/SeaCogs"
    __version__ = "1.4.1"
    __documentation__ = "https://seacogs.coastalcommits.com/hotreload/"

    def __init__(self, bot: Red) -> None:
        super().__init__()
        self.bot: Red = bot
        self.config: Config = Config.get_conf(cog_instance=self, identifier=294518358420750336, force_registration=True)
        self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload")
        # Observers currently owned by this cog. start_observer() stops and removes any
        # existing entries before adding a new one, so this should hold at most one.
        self.observers: List[BaseObserver] = []
        # Background task created by cog_load() to run start_observer(). Holding the
        # reference keeps the task from being garbage-collected before it finishes and
        # lets cog_unload() cancel it if it has not completed yet.
        self._observer_task = None
        self.config.register_global(notify_channel=None, compile_before_reload=False)
        watchdog_loggers = [getLogger(name="watchdog.observers.inotify_buffer")]
        for watchdog_logger in watchdog_loggers:
            watchdog_logger.setLevel("INFO")  # SHUT UP!!!!

    @override
    async def cog_load(self) -> None:
        """Start the observer when the cog is loaded.

        The observer is started from a background task so that loading the cog does not
        wait for the cog paths to be scheduled.
        """
        self._observer_task = self.bot.loop.create_task(self.start_observer())

    @override
    async def cog_unload(self) -> None:
        """Stop the observer when the cog is unloaded."""
        if self._observer_task is not None and not self._observer_task.done():
            # start_observer() has not finished yet; cancel it so it cannot start an
            # observer after this cog has been unloaded.
            self._observer_task.cancel()
        while self.observers:
            self._stop_observer(self.observers.pop())
        self.logger.info("Stopped observer. No longer watching for file changes.")

    @staticmethod
    def _stop_observer(observer: BaseObserver) -> None:
        """Signal an observer to stop and wait for its thread to exit, if it was ever started."""
        observer.stop()
        # Thread.join() raises RuntimeError for a thread that was never started, and a
        # thread that already exited has nothing to wait for.
        if observer.is_alive():
            observer.join()

    @override
    def format_help_for_context(self, ctx: commands.Context) -> str:
        pre_processed = super().format_help_for_context(ctx) or ""
        n = "\n" if "\n\n" not in pre_processed else ""
        text = [
            f"{pre_processed}{n}",
            f"{bold('Cog Version:')} [{self.__version__}]({self.__git__})",
            f"{bold('Author:')} {humanize_list(self.__author__)}",
            f"{bold('Documentation:')} {self.__documentation__}",
        ]
        return "\n".join(text)

    async def get_paths(self) -> Generator[Path, None, None]:
        """Retrieve user defined paths.

        Returns a lazy generator yielding one ``Path`` per path reported by Red's cog
        manager as user defined; it can only be iterated once.
        """
        cog_manager = self.bot._cog_mgr  # noqa: SLF001 # We have to use this private method because there is no public API to get user defined paths
        cog_paths = await cog_manager.user_defined_paths()
        return (Path(path) for path in cog_paths)

    async def start_observer(self) -> None:
        """Start the observer to watch for file changes.

        Any observers already tracked by this cog are stopped and removed first, so only
        the newly created observer remains in ``self.observers``. Paths that do not exist
        are skipped with a warning.
        """
        paths = await self.get_paths()

        # Tear down stale observers instead of restarting them: a watchdog observer is a
        # thread, and a thread cannot be started twice.
        while self.observers:
            self._stop_observer(self.observers.pop())
            self.logger.debug("Stopped hanging observer.")

        observer = Observer()
        for path in paths:
            if not path.exists():
                self.logger.warning("Path %s does not exist. Skipping.", path)
                continue
            self.logger.debug("Adding observer schedule for path %s.", path)
            observer.schedule(event_handler=HotReloadHandler(cog=self, path=path), path=str(path), recursive=True)
        observer.start()
        self.observers.append(observer)
        self.logger.info("Started observer. Watching for file changes.")

    @checks.is_owner()
    @commands.group(name="hotreload")
    async def hotreload_group(self, ctx: commands.Context) -> None:
        """HotReload configuration commands."""
        pass

    @hotreload_group.command(name="notifychannel")
    async def hotreload_notifychannel(self, ctx: commands.Context, channel: discord.TextChannel) -> None:
        """Set the channel to send notifications to."""
        await self.config.notify_channel.set(channel.id)
        await ctx.send(f"Notifications will be sent to {channel.mention}.")

    @hotreload_group.command(name="compile")  # type: ignore
    async def hotreload_compile(self, ctx: commands.Context, compile_before_reload: bool) -> None:
        """Set whether to compile modified files before reloading."""
        await self.config.compile_before_reload.set(compile_before_reload)
        await ctx.send(f"I {'will' if compile_before_reload else 'will not'} compile modified files before hotreloading cogs.")

    @hotreload_group.command(name="list")  # type: ignore
    async def hotreload_list(self, ctx: commands.Context) -> None:
        """List the currently active observers."""
        if not self.observers:
            await ctx.send("No observers are currently active.")
            return
        await ctx.send(f"Currently active observers (If there are more than one of these, report an issue): {box(humanize_list([str(o) for o in self.observers], style='unit'))}")
2025-01-26 15:18:05 +00:00
2025-01-25 18:51:27 -05:00
class HotReloadHandler(RegexMatchingEventHandler):
    """Handler for file changes in a single watched cog path."""

    def __init__(self, cog: HotReload, path: Path) -> None:
        super().__init__(regexes=[r".*\.py$"])
        self.cog: HotReload = cog
        # Root cog path this handler was scheduled on; event paths are resolved relative to it.
        self.path: Path = path
        self.logger: RedTraceLogger = getLogger(name="red.SeaCogs.HotReload.Observer")

    def _package_name(self, path: Path) -> "str | None":
        """Return the first path component of ``path`` relative to the watched root.

        Returns None (after logging a warning) when ``path`` is not inside the watched
        root, instead of letting ``relative_to`` raise ``ValueError`` on the watchdog
        observer thread.
        """
        try:
            relative = path.relative_to(self.path)
        except ValueError:
            self.logger.warning("Path %s is not inside watched path %s. Ignoring it.", path, self.path)
            return None
        if not relative.parts:
            return None
        return relative.parts[0]

    def on_any_event(self, event: FileSystemEvent) -> None:
        """Handle filesystem events.

        Called from the watchdog observer thread. The affected cog packages are worked
        out here, and the actual reload is handed off to the bot's event loop with
        ``run_coroutine_threadsafe``.
        """
        from os import fsdecode

        if event.is_directory:
            return

        allowed_events = ("moved", "deleted", "created", "modified")
        if event.event_type not in allowed_events:
            return

        # watchdog event paths may be str or bytes; fsdecode handles both, whereas
        # str(b"...") would produce the literal text "b'...'".
        src_path = Path(fsdecode(event.src_path))
        src_package_name = self._package_name(src_path)
        if src_package_name is None:
            return
        cogs_to_reload = [src_package_name]
        changed_paths = [src_path]

        if isinstance(event, FileSystemMovedEvent):
            dest_path = Path(fsdecode(event.dest_path))
            dest = f"to {dest_path}"
            changed_paths.append(dest_path)
            dest_package_name = self._package_name(dest_path)
            if dest_package_name is not None and dest_package_name != src_package_name:
                cogs_to_reload.append(dest_package_name)
        else:
            dest = ""

        self.logger.info("File %s has been %s %s.", src_path, event.event_type, dest)

        future = run_coroutine_threadsafe(
            coro=self.reload_cogs(cog_names=cogs_to_reload, paths=changed_paths),
            loop=self.cog.bot.loop,
        )
        # Nobody awaits this future, so surface any exception it ends with.
        future.add_done_callback(self._log_reload_failure)

    def _log_reload_failure(self, future) -> None:
        """Done-callback for the ``concurrent.futures.Future`` returned by ``run_coroutine_threadsafe``.

        Logs any exception raised by ``reload_cogs`` that would otherwise be stored on
        the discarded future and never reported.
        """
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            self.logger.error("Failed to hotreload cogs.", exc_info=exc)

    async def reload_cogs(self, cog_names: Sequence[str], paths: Sequence[Path]) -> None:
        """Reload modified cogs.

        When the ``compile_before_reload`` setting is enabled, the changed files are
        compiled first and the reload is aborted if any of them fails. After reloading,
        a notification is sent to the configured notify channel if it resolves to a
        text channel.
        """
        if await self.cog.config.compile_before_reload() and not self.compile_modified_files(cog_names, paths):
            return

        cog_list = humanize_list(cog_names, style="unit")
        core_logic = CoreLogic(bot=self.cog.bot)
        self.logger.info("Reloading cogs: %s", cog_list)
        await core_logic._reload(pkg_names=cog_names)  # noqa: SLF001 # We have to use this private method because there is no public API to reload other cogs
        self.logger.info("Reloaded cogs: %s", cog_list)

        channel = self.cog.bot.get_channel(await self.cog.config.notify_channel())
        if channel and isinstance(channel, discord.TextChannel):
            await channel.send(f"Reloaded cogs: {cog_list}")

    def compile_modified_files(self, cog_names: Sequence[str], paths: Sequence[Path]) -> bool:
        """Compile modified files to ensure they are valid Python files.

        Paths that no longer exist or are not ``.py`` files are skipped. Returns False
        (after logging the error) as soon as one file fails to compile, and True
        otherwise. An OS error during the compilation step is logged and that file is
        skipped. The generated bytecode is written into a throwaway temporary directory
        and discarded.
        """
        from tempfile import TemporaryDirectory

        for path in paths:
            if not path.exists() or path.suffix != ".py":
                self.logger.debug("Path %s does not exist or does not point to a Python file. Skipping compilation step.", path)
                continue
            try:
                # Use a fresh directory rather than the name of an open NamedTemporaryFile:
                # that name cannot be reliably reused while the file is open on Windows.
                with TemporaryDirectory() as temp_dir:
                    self.logger.debug("Attempting to compile %s", path)
                    py_compile.compile(file=str(path), cfile=str(Path(temp_dir) / "hotreload.pyc"), doraise=True)
                    self.logger.debug("Successfully compiled %s", path)
            except py_compile.PyCompileError as e:
                # Hide the chained SyntaxError so the logged traceback is not duplicated.
                e.__suppress_context__ = True
                self.logger.exception("%s failed to compile. Not reloading cogs %s.", path, humanize_list(cog_names, style="unit"))
                return False
            except OSError:
                self.logger.exception("Failed to run compilation step for %s. Skipping.", path)
        return True