From 8f7db6256bc0a149e22885de6ce089d393a22748 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 22:30:36 +0100 Subject: [PATCH] eod commit --- controller/controller.go | 2 +- main.go | 6 +++ util/cache/conn_cache.go | 88 +++++++++++++++++++++++++++++++++++ util/cache/structs.go | 21 +++++++++ ws/handlers.go | 7 +-- ws/structs.go | 18 +++----- ws/temp.go | 7 +-- ws/util.go | 26 +++++------ ws/websocket.go | 99 +++++++++++++++++++--------------------- 9 files changed, 190 insertions(+), 84 deletions(-) create mode 100644 util/cache/conn_cache.go diff --git a/controller/controller.go b/controller/controller.go index 285aa66..86873b4 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -16,7 +16,7 @@ func NewGatewayController(cfg config.Config) GatewayController { botHandler := ws.NewLoggingBotHandler(wsl) return GatewayController{ - Websocket: ws.NewWsGateway(cfg.Gateway, wsl, wCloseFn, modHandler, botHandler), + Websocket: ws.NewWebsocketGateway(cfg.Gateway, wsl, wCloseFn, modHandler, botHandler), HttpServer: HttpGateway{}, } } diff --git a/main.go b/main.go index fa5670f..f868478 100644 --- a/main.go +++ b/main.go @@ -19,3 +19,9 @@ func main() { panic(err) } } + +/** TODO + +- queue for messages, both ways (Ack "queued" instead of "completed"), filled queue drops oldest entry + +*/ diff --git a/util/cache/conn_cache.go b/util/cache/conn_cache.go new file mode 100644 index 0000000..a49a6fe --- /dev/null +++ b/util/cache/conn_cache.go @@ -0,0 +1,88 @@ +package cache + +import ( + "github.com/gorilla/websocket" +) + +func NewConnectionCache() *ConnectionCache { + c := &ConnectionCache{} + c.state.Store(&connectionMappings{ + id2c: make(map[string]*conn), + }) + + return c +} + +func (c *ConnectionCache) Set(id string, connection *websocket.Conn, meta *ConnectionMetaData) { + c.mu.Lock() + defer c.mu.Unlock() + + current := c.state.Load() + next := current.clone() + + next.id2c[id] = &conn{ + connection: connection, + meta: meta, + } + + c.state.Store(next) +} + +func (c *ConnectionCache) GetById(id string) (*websocket.Conn, *ConnectionMetaData, bool) { + m := c.state.Load() + if conn, ok := m.id2c[id]; ok { + return conn.connection, conn.meta, true + } + + return nil, nil, false +} + +func (c *ConnectionCache) RemoveById(id string) { + c.mu.Lock() + defer c.mu.Unlock() + + current := c.state.Load() + if _, ok := current.id2c[id]; !ok { + return + } + + next := current.clone() + delete(next.id2c, id) + + c.state.Store(next) +} + +func (c *ConnectionCache) Range(fn func(id string, connection *websocket.Conn, meta *ConnectionMetaData)) { + m := c.state.Load() + for id, conn := range m.id2c { + fn(id, conn.connection, conn.meta) + } +} + +func (c *ConnectionCache) GetAllConnections() []*websocket.Conn { + m := c.state.Load() + conns := make([]*websocket.Conn, 0, len(m.id2c)) + for _, conn := range m.id2c { + conns = append(conns, conn.connection) + } + return conns +} + +func (c *ConnectionCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.state.Store(&connectionMappings{ + id2c: make(map[string]*conn), + }) +} + +func (cm *connectionMappings) clone() *connectionMappings { + newCM := &connectionMappings{ + id2c: make(map[string]*conn, len(cm.id2c)), + } + for k, v := range cm.id2c { + newCM.id2c[k] = v + } + return newCM +} diff --git a/util/cache/structs.go b/util/cache/structs.go index e484c32..37aac37 100644 --- a/util/cache/structs.go +++ b/util/cache/structs.go @@ -3,6 +3,8 @@ package cache import ( "sync" "sync/atomic" + + "github.com/gorilla/websocket" ) type Cache struct { @@ -10,7 +12,26 @@ type Cache struct { state atomic.Pointer[mappings] } +type ConnectionCache struct { + mu sync.Mutex + state atomic.Pointer[connectionMappings] +} + type mappings struct { s2c map[string]string c2s map[string]string } + +type connectionMappings struct { + id2c map[string]*conn +} + +type ConnectionMetaData struct { + ConnectionType string // "mod" or "bot" + ID string // server_id or bot_id for logging +} + +type conn struct { + connection *websocket.Conn + meta *ConnectionMetaData +} diff --git a/ws/handlers.go b/ws/handlers.go index c387568..5b0c696 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -2,6 +2,7 @@ package ws import ( "encoding/json" + "homestead/homestead_gateway/util/cache" "net/http" "time" @@ -46,7 +47,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - meta := connectionMetaData{connectionType: handshake.Type} + meta := cache.ConnectionMetaData{ConnectionType: handshake.Type} switch handshake.Type { case "mod": @@ -58,7 +59,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - meta.id = mhs.ServerID + meta.ID = mhs.ServerID if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { return @@ -78,7 +79,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - meta.id = bhs.BotID + meta.ID = bhs.BotID if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { return diff --git a/ws/structs.go b/ws/structs.go index bb70593..1bbe67e 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -1,11 +1,9 @@ package ws import ( - "context" "encoding/json" "homestead/homestead_gateway/util/cache" "log/slog" - "sync" "time" "github.com/gorilla/websocket" @@ -17,9 +15,10 @@ type WebsocketGateway struct { bodySizeBytes int64 upgrader websocket.Upgrader - connsMu sync.Mutex - conns map[*websocket.Conn]connectionMetaData - cache cache.Cache + + cache *cache.Cache + botConn *websocket.Conn + conns *cache.ConnectionCache modHandler ModHandler botHandler BotHandler @@ -108,14 +107,9 @@ type BotHandshake struct { } type ModHandler interface { - Handle(ctx context.Context, msg GatewayModMessageIn) error + Handle(conn *websocket.Conn, msg GatewayModMessageIn) error } type BotHandler interface { - Handle(ctx context.Context, msg GatewayBotMessageIn) error -} - -type connectionMetaData struct { - connectionType string // "mod" or "bot" - id string // server_id or bot_id for logging + Handle(conn *websocket.Conn, msg GatewayBotMessageIn) error } diff --git a/ws/temp.go b/ws/temp.go index d40f5c4..c62057b 100644 --- a/ws/temp.go +++ b/ws/temp.go @@ -1,10 +1,11 @@ package ws import ( - "context" "encoding/json" "log/slog" "time" + + "github.com/gorilla/websocket" ) type LoggingModHandler struct { @@ -23,7 +24,7 @@ func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { return &LoggingBotHandler{logger: logger} } -func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) error { +func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayModMessageIn) error { // For now, just log and pretend it's being forwarded // TODO: Look up channel_id from database using server // TODO: Forward to bot connection(s) @@ -46,7 +47,7 @@ func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) return nil } -func (h *LoggingBotHandler) Handle(ctx context.Context, msg GatewayBotMessageIn) error { +func (h *LoggingBotHandler) Handle(conn *websocket.Conn, msg GatewayBotMessageIn) error { // For now, just log and pretend it's being forwarded // TODO: Look up server_id from database using channel_id // TODO: Forward to mod connection(s) diff --git a/ws/util.go b/ws/util.go index 4822779..f695de3 100644 --- a/ws/util.go +++ b/ws/util.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "homestead/homestead_gateway/util/cache" "log/slog" "net/http" "time" @@ -41,6 +42,11 @@ func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string _ = json.NewEncoder(w).Encode(map[string]interface{}{"message": message, "code": code}) } +func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) { + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _ = conn.WriteMessage(websocket.PingMessage, nil) +} + func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int) { _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) _ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code}) @@ -105,26 +111,20 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { // connections -func (wsg *WebsocketGateway) registerConn(c *websocket.Conn, meta connectionMetaData) { - wsg.connsMu.Lock() - wsg.conns[c] = meta - wsg.connsMu.Unlock() +func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData) { + wsg.conns.Set(meta.ID, conn, &meta) } -func (wsg *WebsocketGateway) unregisterConn(c *websocket.Conn) { - wsg.connsMu.Lock() - delete(wsg.conns, c) - wsg.connsMu.Unlock() +func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData) { + wsg.conns.RemoveById(meta.ID) + _ = conn.Close() } func (wsg *WebsocketGateway) closeAll() { - wsg.connsMu.Lock() - defer wsg.connsMu.Unlock() - wsg.logger.Info("Closing all websocket connections.") - for c := range wsg.conns { + wsg.conns.Range(func(id string, c *websocket.Conn, meta *cache.ConnectionMetaData) { _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) _ = c.Close() - } + }) } diff --git a/ws/websocket.go b/ws/websocket.go index 4422c4b..a5ba800 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "homestead/homestead_gateway/util/cache" "homestead/homestead_gateway/util/config" "log/slog" "net" @@ -15,11 +16,16 @@ import ( "github.com/gorilla/websocket" ) -func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { +func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - apiKey: cfg.Websocket, + logger: logger, + closeFn: closefn, + modHandler: modH, + botHandler: botH, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + cache: cache.NewCache(), + conns: cache.NewConnectionCache(), upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -27,11 +33,7 @@ func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() return true // local by default; change for production }, }, - conns: make(map[*websocket.Conn]connectionMetaData), bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, - port: cfg.HttpPort, - modHandler: modH, - botHandler: botH, } } @@ -69,35 +71,29 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // -func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaData) { +func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { defer func() { - wsg.unregisterConn(c) - _ = c.Close() - wsg.logger.Info("mod disconnected", "server_id", meta.id, "remote", c.RemoteAddr().String()) + wsg.unregisterConn(conn, meta) + wsg.logger.Info("Mod-Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) }() - pingTicker := time.NewTicker(30 * time.Second) - defer pingTicker.Stop() + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() - // Send pings in a separate goroutine go func() { - for range pingTicker.C { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "server_id", meta.id, "err", err) - return - } - wsg.logger.Debug("sent ping to mod", "server_id", meta.id) + for range ticker.C { + wsg.sendWebsocketPing(conn) } }() for { - typ, data, err := c.ReadMessage() + typ, data, err := conn.ReadMessage() + if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("unexpected mod close", "server_id", meta.id, "err", err) + wsg.logger.Warn("unexpected mod close", "server_id", meta.ID, "err", err) } else { - wsg.logger.Debug("mod read error", "server_id", meta.id, "err", err) + wsg.logger.Debug("mod read error", "server_id", meta.ID, "err", err) } return } @@ -108,34 +104,33 @@ func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaD var msg GatewayModMessageIn if err := json.Unmarshal(data, &msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - wsg.logger.Warn("invalid json from mod", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) + wsg.logger.Warn("invalid json from mod", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } msg.ReceivedAt = time.Now().UTC() if err := msg.Validate(); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - wsg.logger.Warn("mod message validation failed", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) + wsg.logger.Warn("mod message validation failed", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } // Handle the message (forward to bot, enrich, etc.) - if err := wsg.modHandler.Handle(context.Background(), msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) - wsg.logger.Error("mod handler error", "server_id", meta.id, "err", err) + if err := wsg.modHandler.Handle(conn, msg); err != nil { + _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) + wsg.logger.Error("mod handler error", "server_id", meta.ID, "err", err) continue } - _ = writeJSONSafe(c, map[string]string{"status": "ok"}) + _ = writeJSONSafe(conn, map[string]string{"status": "completed"}) // or "queued" } } -func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaData) { +func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { defer func() { - wsg.unregisterConn(c) - _ = c.Close() - wsg.logger.Info("bot disconnected", "bot_id", meta.id, "remote", c.RemoteAddr().String()) + wsg.unregisterConn(conn, meta) + wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) }() pingTicker := time.NewTicker(30 * time.Second) @@ -144,22 +139,22 @@ func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaD // Send pings in a separate goroutine go func() { for range pingTicker.C { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + wsg.logger.Debug("write ping failed", "bot_id", meta.ID, "err", err) return } - wsg.logger.Debug("sent ping to bot", "bot_id", meta.id) + wsg.logger.Debug("sent ping to bot", "bot_id", meta.ID) } }() for { - typ, data, err := c.ReadMessage() + typ, data, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("unexpected bot close", "bot_id", meta.id, "err", err) + wsg.logger.Warn("unexpected bot close", "bot_id", meta.ID, "err", err) } else { - wsg.logger.Debug("bot read error", "bot_id", meta.id, "err", err) + wsg.logger.Debug("bot read error", "bot_id", meta.ID, "err", err) } return } @@ -170,25 +165,25 @@ func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaD var msg GatewayBotMessageIn if err := json.Unmarshal(data, &msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - wsg.logger.Warn("invalid json from bot", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) + wsg.logger.Warn("invalid json from bot", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } msg.ReceivedAt = time.Now().UTC() if err := msg.Validate(); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - wsg.logger.Warn("bot message validation failed", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) + wsg.logger.Warn("bot message validation failed", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } // Handle the message (forward to mod, enrich, etc.) - if err := wsg.botHandler.Handle(context.Background(), msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) - wsg.logger.Error("bot handler error", "bot_id", meta.id, "err", err) + if err := wsg.botHandler.Handle(conn, msg); err != nil { + _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) + wsg.logger.Error("bot handler error", "bot_id", meta.ID, "err", err) continue } - _ = writeJSONSafe(c, map[string]string{"status": "ok"}) + _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) } }