From 8f75e6491f6cfb8d154560ea06813a4aeace912b Mon Sep 17 00:00:00 2001 From: Overlord Date: Sun, 7 Dec 2025 17:21:20 +0100 Subject: [PATCH] qol change --- ws/registry.go | 54 ++++++++++++++++++++++++++++++++----------------- ws/util.go | 31 ++++++---------------------- ws/websocket.go | 2 +- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/ws/registry.go b/ws/registry.go index 64be288..263fcc4 100644 --- a/ws/registry.go +++ b/ws/registry.go @@ -79,22 +79,6 @@ func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) // caller should use FlushChannelWithSender to perform actual sends } -// UnregisterMod : remove mod for a channel -func (r *Registry) UnregisterMod(channelID string) { - r.mu.RLock() - e := r.entries[channelID] - r.mu.RUnlock() - if e == nil { - return - } - e.mu.Lock() - defer e.mu.Unlock() - if e.Mod != nil && e.Mod.Conn != nil { - _ = e.Mod.Conn.Close() - } - e.Mod = nil -} - // RegisterBot : single connection for bot. after registration call FlushAllToBotWithSender func (r *Registry) RegisterBot(conn *websocket.Conn) { r.botMu.Lock() @@ -105,13 +89,45 @@ func (r *Registry) RegisterBot(conn *websocket.Conn) { r.botMu.Unlock() } +func (r *Registry) UnregisterMod(channelID string) { + r.mu.RLock() + e := r.entries[channelID] + r.mu.RUnlock() + if e == nil { + return + } + + e.mu.Lock() + modConn := e.Mod + e.Mod = nil + e.mu.Unlock() + + if modConn != nil && modConn.Conn != nil { + _ = modConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) + _ = modConn.Conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), + time.Now().Add(time.Second), + ) + _ = modConn.Conn.Close() + } +} + func (r *Registry) UnregisterBot() { r.botMu.Lock() - if r.bot != nil && r.bot.Conn != nil { - _ = r.bot.Conn.Close() - } + botConn := r.bot r.bot = nil r.botMu.Unlock() + + if botConn != nil && botConn.Conn != nil { + _ = botConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) + _ = botConn.Conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), + time.Now().Add(time.Second), + ) + _ = botConn.Conn.Close() + } } func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn func(*websocket.Conn, GatewayMessageOut) error) (delivered bool, queued bool, err error) { diff --git a/ws/util.go b/ws/util.go index 2226d64..38e0ac1 100644 --- a/ws/util.go +++ b/ws/util.go @@ -126,22 +126,13 @@ func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, return true } -func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, typ, channelId string) { +func (wsg *WebsocketGateway) unregisterConn(typ, channelId string) { if typ == "bot" { wsg.registry.UnregisterBot() - if conn != nil { - _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) - _ = conn.Close() - } return } wsg.registry.UnregisterMod(channelId) - - if conn != nil { - _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) - _ = conn.Close() - } } func (wsg *WebsocketGateway) closeAll() { @@ -150,24 +141,14 @@ func (wsg *WebsocketGateway) closeAll() { wsg.registry.UnregisterBot() wsg.registry.mu.RLock() - entries := make([]*ChannelEntry, 0, len(wsg.registry.entries)) - for _, e := range wsg.registry.entries { - entries = append(entries, e) + channelIDs := make([]string, 0, len(wsg.registry.entries)) + for channelID := range wsg.registry.entries { + channelIDs = append(channelIDs, channelID) } wsg.registry.mu.RUnlock() - for _, e := range entries { - e.mu.Lock() - modConn := e.Mod - if modConn != nil { - e.Mod = nil - } - e.mu.Unlock() - - if modConn != nil && modConn.Conn != nil { - _ = modConn.Conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) - _ = modConn.Conn.Close() - } + for _, channelID := range channelIDs { + wsg.registry.UnregisterMod(channelID) } } diff --git a/ws/websocket.go b/ws/websocket.go index e999b22..f01cf80 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -182,7 +182,7 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error func (wsg *WebsocketGateway) read(conn *websocket.Conn, _type, channelId string) { defer func() { - wsg.unregisterConn(conn, _type, channelId) + wsg.unregisterConn(_type, channelId) wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String()) }()