From 3d6774586eff40a135a7820d8595947b20730983 Mon Sep 17 00:00:00 2001 From: Overlord Date: Sun, 7 Dec 2025 17:11:37 +0100 Subject: [PATCH] Gateway working, beta --- bot.go | 195 ++++++++++++++++++++++++++++++++++ sim.go | 36 +++---- util/cache/cache.go | 91 ---------------- util/cache/conn_cache.go | 88 ---------------- util/cache/structs.go | 37 ------- util/queue/structs.go | 66 ------------ ws/handlers.go | 40 +++---- ws/registry.go | 219 +++++++++++++++++++++++++++++++++++++++ ws/structs.go | 50 +++++++-- ws/temp.go | 47 --------- ws/util.go | 119 ++++++++++++++++----- ws/validate.go | 27 ----- ws/websocket.go | 107 ++++++------------- 13 files changed, 616 insertions(+), 506 deletions(-) create mode 100644 bot.go delete mode 100644 util/cache/cache.go delete mode 100644 util/cache/conn_cache.go delete mode 100644 util/cache/structs.go delete mode 100644 util/queue/structs.go create mode 100644 ws/registry.go delete mode 100644 ws/temp.go delete mode 100644 ws/validate.go diff --git a/bot.go b/bot.go new file mode 100644 index 0000000..b4069ae --- /dev/null +++ b/bot.go @@ -0,0 +1,195 @@ +//go:build simbot + +package main + +import ( + "encoding/json" + "flag" + "log" + "net/url" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gorilla/websocket" +) + +const ( + gatewayURL = "ws://localhost:3333/sync" + apiKey = "gateway" + // must match the mod's channel id used by your mod simulator + channelID = "123456789" +) + +type Handshake struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +type BotHandshake struct { + ChannelId string `json:"channel_id"` // match gateway field exactly +} + +type GatewayAck struct { + Status string `json:"status"` + Type string `json:"type"` +} + +type User struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type Destination struct { + ID string `json:"channel_id,omitempty"` +} + +type GatewayMessageIn struct { + ID string `json:"id"` + MsgID string `json:"msg_id"` + Destination Destination `json:"destination,omitempty"` + Author User `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts time.Time `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` +} + +func main() { + var ( + botID = flag.String("bot", "sim-bot-1", "bot id") + sendAfter = flag.Duration("send-after", 0, "optional: send a bot->mod message after this delay (e.g. 2s)") + sendMsg = flag.String("msg", "Hello from bot!", "optional bot->mod test message content") + ) + flag.Parse() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) + + u, err := url.Parse(gatewayURL) + if err != nil { + log.Fatalf("Failed to parse URL: %v", err) + } + q := u.Query() + q.Set("api_key", apiKey) + u.RawQuery = q.Encode() + + log.Printf("Connecting to %s", u.String()) + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + // we intentionally don't defer conn.Close() here immediately; we'll do on shutdown + log.Println("Connected to gateway") + + // handle server pings by replying a Pong (safe) + conn.SetPingHandler(func(appData string) error { + log.Println("Received ping from server, sending pong") + return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) + }) + + // send bot handshake (must at least include bot_id) + bhs := BotHandshake{ChannelId: channelID} + data, err := json.Marshal(bhs) + if err != nil { + _ = conn.Close() + log.Fatalf("Failed to marshal bot handshake: %v", err) + } + hs := Handshake{Type: "bot", Data: data} + if err := conn.WriteJSON(hs); err != nil { + _ = conn.Close() + log.Fatalf("Failed to send handshake: %v", err) + } + log.Println("Handshake sent (bot)") + + // Instead of ReadJSON, read raw message so we can see whatever the server returns (or why it closes). + // This avoids a silent parsing failure if server returns non-JSON or closes immediately. + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + msgType, raw, err := conn.ReadMessage() + if err != nil { + _ = conn.Close() + log.Fatalf("Failed to read handshake response: %v", err) + } + _ = conn.SetReadDeadline(time.Time{}) // clear deadline + + if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { + log.Printf("Raw handshake reply: %s", string(raw)) + var ack GatewayAck + if err := json.Unmarshal(raw, &ack); err != nil { + log.Printf("Handshake reply is not JSON or unmarshal failed: %v", err) + } else { + log.Printf("Parsed ack: status=%q, type=%q", ack.Status, ack.Type) + } + } else { + log.Printf("Handshake reply was control frame or unexpected type=%d", msgType) + } + + // From here, start the normal read loop. + done := make(chan struct{}) + go func() { + defer close(done) + for { + msgType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + log.Printf("WebSocket error: %v", err) + } else { + log.Printf("Connection closed or read error: %v", err) + } + return + } + + if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { + log.Printf("Received from gateway: %s", string(message)) + } + } + }() + + // optionally send a bot->mod message after a delay + if *sendAfter > 0 { + go func() { + time.Sleep(*sendAfter) + msg := GatewayMessageIn{ + MsgID: "bot-msg-001", + ID: channelID, // bot reports channel id as ID + Author: User{ + ID: *botID, + Name: "SimBot", + }, + Content: *sendMsg, + Ts: time.Now().UTC(), + } + // set a write deadline + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send bot->mod test message: %v", err) + return + } + log.Printf("Sent bot->mod test message (channel=%s)", channelID) + _ = conn.SetWriteDeadline(time.Time{}) + }() + } + + log.Println("Bot simulator running. Press Ctrl+C to exit.") + + // Wait for interrupt or read loop done + for { + select { + case <-done: + log.Println("Connection closed by server") + _ = conn.Close() + return + case <-interrupt: + log.Println("Interrupt received, closing connection...") + // politely close + _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + select { + case <-done: + case <-time.After(time.Second): + } + _ = conn.Close() + return + } + } +} diff --git a/sim.go b/sim.go index 6530720..ce2a105 100644 --- a/sim.go +++ b/sim.go @@ -18,6 +18,8 @@ const ( gatewayURL = "ws://localhost:3333/sync" apiKey = "gateway" serverID = "test-server-001" + // THE CHANNEL ID the mod says it serves. Must match gateway expectation. + channelID = "123456789" ) type Handshake struct { @@ -25,8 +27,10 @@ type Handshake struct { Data json.RawMessage `json:"data"` } +// ModHandshake now includes ChannelID type ModHandshake struct { - ServerID string `json:"server_id"` + ServerID string `json:"server_id"` + ChannelID string `json:"channel_id"` } type GatewayAck struct { @@ -58,7 +62,6 @@ func main() { interrupt := make(chan os.Signal, 1) signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) - // Build WebSocket URL with API key u, err := url.Parse(gatewayURL) if err != nil { log.Fatalf("Failed to parse URL: %v", err) @@ -69,7 +72,6 @@ func main() { log.Printf("Connecting to %s", u.String()) - // Connect to WebSocket conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { log.Fatalf("Failed to connect: %v", err) @@ -78,7 +80,8 @@ func main() { log.Println("Connected to gateway") - // Set up ping handler - respond to pings from server + // respond to pings (server ping -> client must pong). Using SetPingHandler is fine, + // but WriteControl for Pong is acceptable too. conn.SetPingHandler(func(appData string) error { log.Println("Received ping from server, sending pong") err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) @@ -89,8 +92,11 @@ func main() { return nil }) - // Send handshake - modHS := ModHandshake{ServerID: serverID} + // Build and send handshake including channel id + modHS := ModHandshake{ + ServerID: serverID, + ChannelID: channelID, + } modHSData, err := json.Marshal(modHS) if err != nil { log.Fatalf("Failed to marshal mod handshake: %v", err) @@ -104,21 +110,17 @@ func main() { if err := conn.WriteJSON(handshake); err != nil { log.Fatalf("Failed to send handshake: %v", err) } - log.Println("Handshake sent") - // Read acknowledgment + // Read acknowledgment (some servers might not reply with JSON; handle errors) var ack GatewayAck if err := conn.ReadJSON(&ack); err != nil { log.Fatalf("Failed to read acknowledgment: %v", err) } + log.Printf("Received acknowledgment: status=%q, type=%q", ack.Status, ack.Type) - log.Printf("Received acknowledgment: status=%s, type=%s", ack.Status, ack.Type) - - // Channel for incoming messages done := make(chan struct{}) - // Read loop - handles incoming messages and processes control frames go func() { defer close(done) for { @@ -131,21 +133,19 @@ func main() { } return } - - // Only log text/binary messages (ping/pong handled by handlers) if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { log.Printf("Received from server: %s", string(message)) } } }() - // Optional: Send a test message after connecting - time.Sleep(2 * time.Second) + // Optional: send a test message after connecting + time.Sleep(1 * time.Second) testMsg := GatewayMessageIn{ MsgID: "test-msg-001", ID: serverID, Destination: Destination{ - ID: "123456789", + ID: channelID, }, Author: User{ ID: "player-uuid-123", @@ -163,7 +163,6 @@ func main() { log.Println("Connection established. Responding to pings. Press Ctrl+C to disconnect.") - // Wait for interrupt or connection close for { select { case <-done: @@ -172,7 +171,6 @@ func main() { case <-interrupt: log.Println("Interrupt received, closing connection...") - // Send close message err := conn.WriteMessage( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), diff --git a/util/cache/cache.go b/util/cache/cache.go deleted file mode 100644 index 57d682c..0000000 --- a/util/cache/cache.go +++ /dev/null @@ -1,91 +0,0 @@ -package cache - -func NewCache() *Cache { - c := &Cache{} - c.state.Store(&mappings{ - s2c: make(map[string]string), - c2s: make(map[string]string), - }) - - return c -} - -func (c *Cache) Set(serverId, channelId string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - next := current.clone() - - if oldCh, ok := next.s2c[serverId]; ok && oldCh != channelId { - delete(next.c2s, oldCh) - } - if oldSrv, ok := next.c2s[channelId]; ok && oldSrv != serverId { - delete(next.s2c, oldSrv) - } - - next.s2c[serverId] = channelId - next.c2s[channelId] = serverId - - c.state.Store(next) -} - -func (c *Cache) GetByServerId(serverId string) (string, bool) { - m := c.state.Load() - val, ok := m.s2c[serverId] - return val, ok -} - -func (c *Cache) GetByChannelId(channelId string) (string, bool) { - m := c.state.Load() - val, ok := m.c2s[channelId] - return val, ok -} - -func (c *Cache) RemoveByServerId(serverId string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - if _, ok := current.s2c[serverId]; !ok { - return - } - - next := current.clone() - if channelId, ok := next.s2c[serverId]; ok { - delete(next.s2c, serverId) - delete(next.c2s, channelId) - } - c.state.Store(next) -} - -func (c *Cache) RemoveByChannelId(channelId string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - if _, ok := current.c2s[channelId]; !ok { - return - } - - next := current.clone() - if serverId, ok := next.c2s[channelId]; ok { - delete(next.c2s, channelId) - delete(next.s2c, serverId) - } - c.state.Store(next) -} - -func (m *mappings) clone() *mappings { - newM := &mappings{ - s2c: make(map[string]string, len(m.s2c)), - c2s: make(map[string]string, len(m.c2s)), - } - for k, v := range m.s2c { - newM.s2c[k] = v - } - for k, v := range m.c2s { - newM.c2s[k] = v - } - return newM -} diff --git a/util/cache/conn_cache.go b/util/cache/conn_cache.go deleted file mode 100644 index a49a6fe..0000000 --- a/util/cache/conn_cache.go +++ /dev/null @@ -1,88 +0,0 @@ -package cache - -import ( - "github.com/gorilla/websocket" -) - -func NewConnectionCache() *ConnectionCache { - c := &ConnectionCache{} - c.state.Store(&connectionMappings{ - id2c: make(map[string]*conn), - }) - - return c -} - -func (c *ConnectionCache) Set(id string, connection *websocket.Conn, meta *ConnectionMetaData) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - next := current.clone() - - next.id2c[id] = &conn{ - connection: connection, - meta: meta, - } - - c.state.Store(next) -} - -func (c *ConnectionCache) GetById(id string) (*websocket.Conn, *ConnectionMetaData, bool) { - m := c.state.Load() - if conn, ok := m.id2c[id]; ok { - return conn.connection, conn.meta, true - } - - return nil, nil, false -} - -func (c *ConnectionCache) RemoveById(id string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - if _, ok := current.id2c[id]; !ok { - return - } - - next := current.clone() - delete(next.id2c, id) - - c.state.Store(next) -} - -func (c *ConnectionCache) Range(fn func(id string, connection *websocket.Conn, meta *ConnectionMetaData)) { - m := c.state.Load() - for id, conn := range m.id2c { - fn(id, conn.connection, conn.meta) - } -} - -func (c *ConnectionCache) GetAllConnections() []*websocket.Conn { - m := c.state.Load() - conns := make([]*websocket.Conn, 0, len(m.id2c)) - for _, conn := range m.id2c { - conns = append(conns, conn.connection) - } - return conns -} - -func (c *ConnectionCache) Clear() { - c.mu.Lock() - defer c.mu.Unlock() - - c.state.Store(&connectionMappings{ - id2c: make(map[string]*conn), - }) -} - -func (cm *connectionMappings) clone() *connectionMappings { - newCM := &connectionMappings{ - id2c: make(map[string]*conn, len(cm.id2c)), - } - for k, v := range cm.id2c { - newCM.id2c[k] = v - } - return newCM -} diff --git a/util/cache/structs.go b/util/cache/structs.go deleted file mode 100644 index 37aac37..0000000 --- a/util/cache/structs.go +++ /dev/null @@ -1,37 +0,0 @@ -package cache - -import ( - "sync" - "sync/atomic" - - "github.com/gorilla/websocket" -) - -type Cache struct { - mu sync.Mutex - state atomic.Pointer[mappings] -} - -type ConnectionCache struct { - mu sync.Mutex - state atomic.Pointer[connectionMappings] -} - -type mappings struct { - s2c map[string]string - c2s map[string]string -} - -type connectionMappings struct { - id2c map[string]*conn -} - -type ConnectionMetaData struct { - ConnectionType string // "mod" or "bot" - ID string // server_id or bot_id for logging -} - -type conn struct { - connection *websocket.Conn - meta *ConnectionMetaData -} diff --git a/util/queue/structs.go b/util/queue/structs.go deleted file mode 100644 index 0588606..0000000 --- a/util/queue/structs.go +++ /dev/null @@ -1,66 +0,0 @@ -package queue - -import ( - "sync" -) - -// Queue is a thin wrapper around a buffered channel. -type Queue[T any] struct { - ch chan T - closedMu sync.Mutex - closed bool - closedCh chan struct{} -} - -func NewQueue[T any](capacity int) *Queue[T] { - if capacity <= 0 { - panic("capacity > 0 required") - } - - return &Queue[T]{ - ch: make(chan T, capacity), - closedCh: make(chan struct{}), - } -} - -// Cap returns capacity. -func (q *Queue[T]) Cap() int { return cap(q.ch) } - -// Len returns current length (snapshot). -func (q *Queue[T]) Len() int { return len(q.ch) } - -// Enqueue returns immediately: true if enqueued, false otherwise. -func (q *Queue[T]) Enqueue(v T) bool { - select { - case q.ch <- v: - return true - default: - return false - } -} - -// Dequeue returns immediately (item, true) or (zero, false) if empty. -func (q *Queue[T]) Dequeue() (T, bool) { - var zero T - select { - case v := <-q.ch: - return v, true - default: - return zero, false - } -} - -// Close closes the queue. Further Enqueue attempts return ErrClosed. Consumers drain until channel empty then see ErrClosed. -func (q *Queue[T]) Close() { - q.closedMu.Lock() - - if q.closed { - q.closedMu.Unlock() - return - } - - q.closed = true - close(q.closedCh) - close(q.ch) - q.closedMu.Unlock() -} diff --git a/ws/handlers.go b/ws/handlers.go index d05b927..512bf9e 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -2,7 +2,6 @@ package ws import ( "encoding/json" - "homestead/homestead_gateway/util/cache" "net/http" "time" @@ -47,52 +46,53 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) return } - meta := cache.ConnectionMetaData{ConnectionType: handshake.Type} - switch handshake.Type { case "mod": var mhs ModHandshake - if err := json.Unmarshal(handshake.Data, &mhs); err != nil { wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400, true) wsg.logger.Warn("Malformed mod handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta.ID = mhs.ServerID - - if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { + if mhs.ServerID == "" || mhs.ChannelID == "" { + wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400, true) return } - wsg.registerConn(conn, meta, "mod") - wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) + if !wsg.registerConn(conn, "mod", mhs.ChannelID, mhs.ServerID) { + wsg.sendWebsocketError(conn, "Failed to register mod.", 500, true) + return + } - go wsg.read(conn, meta, "mod") + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}) + wsg.registry.FlushChannelWithSender(mhs.ChannelID, wsg.flush) + + wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String()) + go wsg.read(conn, "mod", mhs.ChannelID) case "bot": var bhs BotHandshake - if err := json.Unmarshal(handshake.Data, &bhs); err != nil { wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400, true) - wsg.logger.Warn("Malformed bot handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta.ID = bhs.BotID + if bhs.ChannelId == "" { + wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400, true) + return + } - if ok := wsg.registerConn(conn, meta, "bot"); !ok { + if !wsg.registerConn(conn, "bot", bhs.ChannelId, "") { wsg.sendWebsocketError(conn, "Bot already connected.", 409, true) return } - if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { - return - } + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}) + wsg.registry.FlushAllToBotWithSender(wsg.flush) - wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) - - go wsg.read(conn, meta, "bot") + wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String()) + go wsg.read(conn, "bot", bhs.ChannelId) default: wsg.sendWebsocketError(conn, "Unknown handshake.", 400, true) diff --git a/ws/registry.go b/ws/registry.go new file mode 100644 index 0000000..64be288 --- /dev/null +++ b/ws/registry.go @@ -0,0 +1,219 @@ +package ws + +import ( + "fmt" + "time" + + "github.com/gorilla/websocket" +) + +func (q *BoundedQueue) Enqueue(m GatewayMessageOut) bool { + q.mu.Lock() + defer q.mu.Unlock() + if q.capacity == 0 { + return false + } + if q.length < q.capacity { + q.buf[(q.start+q.length)%q.capacity] = m + q.length++ + return true + } + // overwrite oldest + q.buf[q.start] = m + q.start = (q.start + 1) % q.capacity + return true +} + +func (q *BoundedQueue) PopAll() []GatewayMessageOut { + q.mu.Lock() + defer q.mu.Unlock() + if q.length == 0 { + return nil + } + out := make([]GatewayMessageOut, 0, q.length) + for i := 0; i < q.length; i++ { + out = append(out, q.buf[(q.start+i)%q.capacity]) + } + q.start = 0 + q.length = 0 + return out +} + +func (q *BoundedQueue) Len() int { + q.mu.Lock() + defer q.mu.Unlock() + return q.length +} + +// + +func (r *Registry) getOrCreate(channel string) *ChannelEntry { + r.mu.RLock() + e := r.entries[channel] + r.mu.RUnlock() + if e != nil { + return e + } + + r.mu.Lock() + defer r.mu.Unlock() + if e = r.entries[channel]; e == nil { + e = newChannelEntry(channel, r.queueCap) + r.entries[channel] = e + } + return e +} + +// + +// RegisterMod : map channel_id -> mod conn (serverID) +func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) { + e := r.getOrCreate(channelID) + e.mu.Lock() + defer e.mu.Unlock() + if e.Mod != nil && e.Mod.Conn != nil { + _ = e.Mod.Conn.Close() + } + e.Mod = &ConnWrapper{Conn: conn, ServerID: serverID, LastSeen: time.Now()} + // flush queued bot->mod messages for this channel + // caller should use FlushChannelWithSender to perform actual sends +} + +// UnregisterMod : remove mod for a channel +func (r *Registry) UnregisterMod(channelID string) { + r.mu.RLock() + e := r.entries[channelID] + r.mu.RUnlock() + if e == nil { + return + } + e.mu.Lock() + defer e.mu.Unlock() + if e.Mod != nil && e.Mod.Conn != nil { + _ = e.Mod.Conn.Close() + } + e.Mod = nil +} + +// RegisterBot : single connection for bot. after registration call FlushAllToBotWithSender +func (r *Registry) RegisterBot(conn *websocket.Conn) { + r.botMu.Lock() + if r.bot != nil && r.bot.Conn != nil { + _ = r.bot.Conn.Close() + } + r.bot = &ConnWrapper{Conn: conn, LastSeen: time.Now()} + r.botMu.Unlock() +} + +func (r *Registry) UnregisterBot() { + r.botMu.Lock() + if r.bot != nil && r.bot.Conn != nil { + _ = r.bot.Conn.Close() + } + r.bot = nil + r.botMu.Unlock() +} + +func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn func(*websocket.Conn, GatewayMessageOut) error) (delivered bool, queued bool, err error) { + if out.Type == "mod" { + r.botMu.Lock() + b := r.bot + r.botMu.Unlock() + if b != nil && b.Conn != nil { + if err := sendOverConn(b.Conn, out); err == nil { + return true, false, nil + } + _ = b.Conn.Close() + r.UnregisterBot() + } + + e := r.getOrCreate(channelID) + e.mu.Lock() + enq := e.Queue.Enqueue(out) + e.mu.Unlock() + if !enq { + return false, false, fmt.Errorf("queue disabled") + } + return false, true, nil + } + + e := r.getOrCreate(channelID) + e.mu.Lock() + mod := e.Mod + e.mu.Unlock() + if mod != nil && mod.Conn != nil { + if err := sendOverConn(mod.Conn, out); err == nil { + return true, false, nil + } + _ = mod.Conn.Close() + r.UnregisterMod(channelID) + } + + e.mu.Lock() + enq := e.Queue.Enqueue(out) + e.mu.Unlock() + if !enq { + return false, false, fmt.Errorf("queue disabled") + } + return false, true, nil +} + +// + +func (r *Registry) FlushChannelWithSender(channelID string, sendOverConn func(*websocket.Conn, GatewayMessageOut) error) { + r.mu.RLock() + e := r.entries[channelID] + r.mu.RUnlock() + if e == nil { + return + } + e.mu.Lock() + if e.Mod == nil || e.Mod.Conn == nil { + e.mu.Unlock() + return + } + msgs := e.Queue.PopAll() + modConn := e.Mod.Conn + e.mu.Unlock() + + for _, m := range msgs { + if err := sendOverConn(modConn, m); err != nil { + // if send fails, re-enqueue (best-effort), drop-oldest logic applies + e.mu.Lock() + _ = e.Queue.Enqueue(m) + e.mu.Unlock() + } + } +} + +func (r *Registry) FlushAllToBotWithSender(sendOverConn func(*websocket.Conn, GatewayMessageOut) error) { + r.botMu.Lock() + b := r.bot + r.botMu.Unlock() + if b == nil || b.Conn == nil { + return + } + + r.mu.RLock() + entries := make([]*ChannelEntry, 0, len(r.entries)) + for _, e := range r.entries { + entries = append(entries, e) + } + r.mu.RUnlock() + + for _, e := range entries { + e.mu.Lock() + msgs := e.Queue.PopAll() + e.mu.Unlock() + if len(msgs) == 0 { + continue + } + for _, m := range msgs { + if err := sendOverConn(b.Conn, m); err != nil { + e.mu.Lock() + _ = e.Queue.Enqueue(m) + e.mu.Unlock() + } + } + } +} diff --git a/ws/structs.go b/ws/structs.go index 96ea5b0..4623240 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -2,9 +2,8 @@ package ws import ( "encoding/json" - "homestead/homestead_gateway/util/cache" - "homestead/homestead_gateway/util/queue" "log/slog" + "sync" "time" "github.com/gorilla/websocket" @@ -17,16 +16,46 @@ type WebsocketGateway struct { upgrader websocket.Upgrader - cache *cache.Cache - queue *queue.Queue[GatewayMessageOut] - - bot *websocket.Conn - conns *cache.ConnectionCache + registry *Registry logger *slog.Logger closeFn func() error } +// + +type Registry struct { + mu sync.RWMutex + entries map[string]*ChannelEntry + queueCap int + + botMu sync.Mutex + bot *ConnWrapper +} + +type ConnWrapper struct { + Conn *websocket.Conn + ServerID string // set for mods (the server_id) + LastSeen time.Time +} + +type BoundedQueue struct { + mu sync.Mutex + buf []GatewayMessageOut + start int + length int + capacity int +} + +type ChannelEntry struct { + mu sync.Mutex + Channel string + Mod *ConnWrapper + Queue *BoundedQueue +} + +// + type User struct { ID string `json:"id"` Name string `json:"name"` @@ -64,15 +93,18 @@ type GatewayAck struct { Type string `json:"type"` } +// + type Handshake struct { Type string `json:"type"` // "mod" or "bot" Data json.RawMessage `json:"data"` } type ModHandshake struct { - ServerID string `json:"server_id"` + ServerID string `json:"server_id"` + ChannelID string `json:"channel_id"` } type BotHandshake struct { - BotID string `json:"bot_id"` + ChannelId string `json:"channel_id"` } diff --git a/ws/temp.go b/ws/temp.go deleted file mode 100644 index 86ca50b..0000000 --- a/ws/temp.go +++ /dev/null @@ -1,47 +0,0 @@ -package ws - -//type LoggingModHandler struct { -// logger *slog.Logger -//} -// -//type LoggingBotHandler struct { -// logger *slog.Logger -//} -// -//func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { -// return &LoggingModHandler{logger: logger} -//} -// -//func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { -// return &LoggingBotHandler{logger: logger} -//} -// -//func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayMessageIn) error { -// // For now, just log and pretend it's being forwarded -// // TODO: Look up channel_id from database using server -// // TODO: Forward to bot connection(s) -// -// fwd := GatewayMessageOut{ -// Type: "mod", -// ChannelID: "TODO", // will come from database lookup -// Author: msg.Author, -// Content: msg.Content, -// Meta: msg.Meta, -// Ts: msg.Ts, -// ReceivedAt: msg.ReceivedAt, -// ForwardedAt: time.Now().UTC(), -// } -// -// _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) -// -// if err := conn.WriteJSON(fwd); err != nil { -// _ = conn.Close() -// return err -// } -// -// b, _ := json.Marshal(fwd) -// h.logger.Info("received mod message", "msg_id", msg.MsgID, "server", msg.Server, "Author", msg.Author.Name, "content", msg.Content) -// h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) -// -// return nil -//} diff --git a/ws/util.go b/ws/util.go index 6003517..2226d64 100644 --- a/ws/util.go +++ b/ws/util.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "errors" - "homestead/homestead_gateway/util/cache" "log/slog" "net/http" + "strings" "time" "github.com/gorilla/websocket" @@ -36,6 +36,11 @@ func (wsg *WebsocketGateway) deafen(srv *http.Server) { // responses +func (wsg *WebsocketGateway) flush(c *websocket.Conn, m GatewayMessageOut) error { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + return c.WriteJSON(m) +} + func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) @@ -94,15 +99,6 @@ func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool { return !(apiKey == "" || apiKey != wsg.apiKey) } -func writeJSONSafe(c *websocket.Conn, v interface{}) error { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteJSON(v); err != nil { - // caller handles logging - return err - } - return nil -} - func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -113,43 +109,112 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { // connections -func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) bool { +func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool { if typ == "bot" { - if wsg.bot != nil { + wsg.registry.botMu.Lock() + if wsg.registry.bot != nil && wsg.registry.bot.Conn != nil { + wsg.registry.botMu.Unlock() return false } + wsg.registry.botMu.Unlock() - wsg.bot = conn + wsg.registry.RegisterBot(conn) return true } - wsg.conns.Set(meta.ID, conn, &meta) + wsg.registry.RegisterMod(channelId, serverId, conn) return true } -func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) { +func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, typ, channelId string) { if typ == "bot" { - _ = wsg.bot.Close() - wsg.bot = nil + wsg.registry.UnregisterBot() + if conn != nil { + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) + _ = conn.Close() + } return } - wsg.conns.RemoveById(meta.ID) - _ = conn.Close() + wsg.registry.UnregisterMod(channelId) + + if conn != nil { + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) + _ = conn.Close() + } } func (wsg *WebsocketGateway) closeAll() { wsg.logger.Info("Closing all websocket connections.") - if wsg.bot != nil { - _ = wsg.bot.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) - _ = wsg.bot.Close() + wsg.registry.UnregisterBot() + + wsg.registry.mu.RLock() + entries := make([]*ChannelEntry, 0, len(wsg.registry.entries)) + for _, e := range wsg.registry.entries { + entries = append(entries, e) + } + wsg.registry.mu.RUnlock() + + for _, e := range entries { + e.mu.Lock() + modConn := e.Mod + if modConn != nil { + e.Mod = nil + } + e.mu.Unlock() + + if modConn != nil && modConn.Conn != nil { + _ = modConn.Conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) + _ = modConn.Conn.Close() + } + } +} + +// + +func NewRegistry(queueCap int) *Registry { + return &Registry{ + entries: make(map[string]*ChannelEntry), + queueCap: queueCap, + } +} + +func NewBoundedQueue(cap int) *BoundedQueue { + if cap <= 0 { + cap = 128 + } + return &BoundedQueue{buf: make([]GatewayMessageOut, cap), capacity: cap} +} + +func newChannelEntry(channel string, cap int) *ChannelEntry { + return &ChannelEntry{ + Channel: channel, + Queue: NewBoundedQueue(cap), + } +} + +// + +func (m *GatewayMessageIn) Validate() error { + if strings.TrimSpace(m.ID) == "" { + return errors.New("id missing") + } + if strings.TrimSpace(m.MsgID) == "" { + return errors.New("msg_id missing") + } + if strings.TrimSpace(m.Author.ID) == "" { + return errors.New("author.id missing") + } + if strings.TrimSpace(m.Content) == "" { + return errors.New("content missing") } - wsg.conns.Range(func(id string, c *websocket.Conn, meta *cache.ConnectionMetaData) { - _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) - _ = c.Close() - }) + if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" { + return errors.New("destination.channel_id missing") + } - wsg.conns.Clear() + return nil } + +func (c *ConnWrapper) Alive() bool { return c != nil && c.Conn != nil } diff --git a/ws/validate.go b/ws/validate.go deleted file mode 100644 index 1431e47..0000000 --- a/ws/validate.go +++ /dev/null @@ -1,27 +0,0 @@ -package ws - -import ( - "errors" - "strings" -) - -func (m *GatewayMessageIn) Validate() error { - if strings.TrimSpace(m.ID) == "" { - return errors.New("id missing") - } - if strings.TrimSpace(m.MsgID) == "" { - return errors.New("msg_id missing") - } - if strings.TrimSpace(m.Author.ID) == "" { - return errors.New("author.id missing") - } - if strings.TrimSpace(m.Content) == "" { - return errors.New("content missing") - } - - if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" { - return errors.New("destination.channel_id missing") - } - - return nil -} diff --git a/ws/websocket.go b/ws/websocket.go index c9e8d64..e999b22 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -4,9 +4,7 @@ import ( "context" "encoding/json" "fmt" - "homestead/homestead_gateway/util/cache" "homestead/homestead_gateway/util/config" - "homestead/homestead_gateway/util/queue" "log/slog" "net" "net/http" @@ -19,13 +17,11 @@ import ( func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - port: cfg.HttpPort, - apiKey: cfg.Websocket, - cache: cache.NewCache(), - queue: queue.NewQueue[GatewayMessageOut](32), - conns: cache.NewConnectionCache(), + logger: logger, + closeFn: closefn, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + registry: NewRegistry(32), upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -184,9 +180,9 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // } //} -func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMetaData, _type string) { +func (wsg *WebsocketGateway) read(conn *websocket.Conn, _type, channelId string) { defer func() { - wsg.unregisterConn(conn, meta, _type) + wsg.unregisterConn(conn, _type, channelId) wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String()) }() @@ -235,70 +231,16 @@ func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMet continue } - var ok bool - var destConn *websocket.Conn - - switch message.Type { - case "mod": - if wsg.bot != nil { - ok = true - destConn = wsg.bot - } else { - ok = false - } - case "bot": - var id string - id, ok = wsg.cache.GetByChannelId(message.ID) - if ok { - var dest *websocket.Conn - dest, _, ok = wsg.conns.GetById(id) - if ok { - destConn = dest - } else { - wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) - wsg.logger.Error("Invalid cache structure.", "remote", conn.RemoteAddr().String(), "id", id) - return - } - } - default: - panic("invalid message type") - } - - if ok { - if destConn == nil { - wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) - wsg.logger.Error("Destination connection unavailable.", "remote", conn.RemoteAddr().String()) - return - } - - err = wsg.sendWebsocketResponse(destConn, GatewayMessageOut{ - Type: message.Type, - ID: message.Destination.ID, - Author: message.Author, - Content: message.Content, - Meta: message.Meta, - Ts: message.Ts, - ReceivedAt: message.ReceivedAt, - ForwardedAt: time.Now().UTC(), - }) - - if err != nil { - _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) - wsg.logger.Error("Failed to forward message.", "remote", conn.RemoteAddr().String(), "err", err) - continue - } - - _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "completed", Type: message.Type}) - continue - } - - if message.Type == "mod" { - wsg.cache.Set(message.ID, message.Destination.ID) + var outID string + if _type == "mod" { + outID = channelId + } else { + outID = message.ID } out := GatewayMessageOut{ Type: message.Type, - ID: message.Destination.ID, + ID: outID, Author: message.Author, Content: message.Content, Meta: message.Meta, @@ -307,13 +249,28 @@ func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMet ForwardedAt: time.Now().UTC(), } - queued := wsg.queue.Enqueue(out) - if !queued { + delivered, queued, err := wsg.registry.Send(out.ID, out, func(c *websocket.Conn, m GatewayMessageOut) error { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + return c.WriteJSON(m) + }) + + if err != nil { + wsg.logger.Error("registry send error", "err", err) _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) - wsg.logger.Warn("Failed to queue message.", "remote", conn.RemoteAddr().String()) continue } - _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "queued", Type: message.Type}) + if delivered { + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "completed", Type: message.Type}) + continue + } + + if queued { + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "queued", Type: message.Type}) + continue + } + + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) + } }