diff --git a/config.toml b/config.toml index b41f343..8850d7b 100644 --- a/config.toml +++ b/config.toml @@ -7,10 +7,4 @@ rotation = 3 # in days http_port = 3333 websocket = "gateway" body_size = 2 # in MB -queue_max = 8192 - -[database] -host_dsn = "" -username = "" -password = "" -database = "" +queue_max = 32 diff --git a/util/config/structs.go b/util/config/structs.go index 4700c84..5437791 100644 --- a/util/config/structs.go +++ b/util/config/structs.go @@ -3,9 +3,8 @@ package config import "log/slog" type Config struct { - Log LogConfig `toml:"log"` - Gateway GatewayConfig `toml:"gateway"` - Database DatabaseConfig `toml:"database"` + Log LogConfig `toml:"log"` + Gateway GatewayConfig `toml:"gateway"` } type GatewayConfig struct { @@ -20,10 +19,3 @@ type LogConfig struct { Directory string `toml:"directory"` Rotation int `toml:"rotation"` } - -type DatabaseConfig struct { - HostDSN string `toml:"host_dsn"` - Username string `toml:"username"` - Password string `toml:"password"` - Database string `toml:"database"` -} diff --git a/ws/registry.go b/ws/registry.go index 263fcc4..be8d6a5 100644 --- a/ws/registry.go +++ b/ws/registry.go @@ -13,14 +13,17 @@ func (q *BoundedQueue) Enqueue(m GatewayMessageOut) bool { if q.capacity == 0 { return false } + if q.length < q.capacity { q.buf[(q.start+q.length)%q.capacity] = m q.length++ return true } + // overwrite oldest q.buf[q.start] = m q.start = (q.start + 1) % q.capacity + return true } @@ -30,12 +33,15 @@ func (q *BoundedQueue) PopAll() []GatewayMessageOut { if q.length == 0 { return nil } + out := make([]GatewayMessageOut, 0, q.length) for i := 0; i < q.length; i++ { out = append(out, q.buf[(q.start+i)%q.capacity]) } + q.start = 0 q.length = 0 + return out } @@ -64,6 +70,19 @@ func (r *Registry) getOrCreate(channel string) *ChannelEntry { return e } +func (r *Registry) ForEach(cb func(channelID string)) { + r.mu.RLock() + ids := make([]string, 0, len(r.entries)) + for id := range r.entries { + ids = append(ids, id) + } + r.mu.RUnlock() + + for _, id := range ids { + cb(id) + } +} + // // RegisterMod : map channel_id -> mod conn (serverID) @@ -71,9 +90,12 @@ func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) e := r.getOrCreate(channelID) e.mu.Lock() defer e.mu.Unlock() + if e.Mod != nil && e.Mod.Conn != nil { + _ = e.Mod.Conn.Close() } + e.Mod = &ConnWrapper{Conn: conn, ServerID: serverID, LastSeen: time.Now()} // flush queued bot->mod messages for this channel // caller should use FlushChannelWithSender to perform actual sends @@ -82,9 +104,11 @@ func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) // RegisterBot : single connection for bot. after registration call FlushAllToBotWithSender func (r *Registry) RegisterBot(conn *websocket.Conn) { r.botMu.Lock() + if r.bot != nil && r.bot.Conn != nil { _ = r.bot.Conn.Close() } + r.bot = &ConnWrapper{Conn: conn, LastSeen: time.Now()} r.botMu.Unlock() } @@ -103,13 +127,7 @@ func (r *Registry) UnregisterMod(channelID string) { e.mu.Unlock() if modConn != nil && modConn.Conn != nil { - _ = modConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) - _ = modConn.Conn.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), - time.Now().Add(time.Second), - ) - _ = modConn.Conn.Close() + closeConn(modConn.Conn) } } @@ -120,13 +138,7 @@ func (r *Registry) UnregisterBot() { r.botMu.Unlock() if botConn != nil && botConn.Conn != nil { - _ = botConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) - _ = botConn.Conn.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), - time.Now().Add(time.Second), - ) - _ = botConn.Conn.Close() + closeConn(botConn.Conn) } } @@ -135,6 +147,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu r.botMu.Lock() b := r.bot r.botMu.Unlock() + if b != nil && b.Conn != nil { if err := sendOverConn(b.Conn, out); err == nil { return true, false, nil @@ -147,6 +160,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu e.mu.Lock() enq := e.Queue.Enqueue(out) e.mu.Unlock() + if !enq { return false, false, fmt.Errorf("queue disabled") } @@ -157,6 +171,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu e.mu.Lock() mod := e.Mod e.mu.Unlock() + if mod != nil && mod.Conn != nil { if err := sendOverConn(mod.Conn, out); err == nil { return true, false, nil @@ -168,6 +183,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu e.mu.Lock() enq := e.Queue.Enqueue(out) e.mu.Unlock() + if !enq { return false, false, fmt.Errorf("queue disabled") } @@ -183,11 +199,13 @@ func (r *Registry) FlushChannelWithSender(channelID string, sendOverConn func(*w if e == nil { return } + e.mu.Lock() if e.Mod == nil || e.Mod.Conn == nil { e.mu.Unlock() return } + msgs := e.Queue.PopAll() modConn := e.Mod.Conn e.mu.Unlock() @@ -221,6 +239,7 @@ func (r *Registry) FlushAllToBotWithSender(sendOverConn func(*websocket.Conn, Ga e.mu.Lock() msgs := e.Queue.PopAll() e.mu.Unlock() + if len(msgs) == 0 { continue } diff --git a/ws/structs.go b/ws/structs.go index 4623240..f601775 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -67,19 +67,19 @@ type Destination struct { type GatewayMessageIn struct { Type string - ID string `json:"id"` // where am I from (channel_id or server_id) - MsgID string `json:"msg_id"` // msg id - Destination Destination `json:"destination,omitempty"` // where do I wanna go (channel_id or empty if from Bot) - Author User `json:"author"` // who sent the message - Content string `json:"content"` // message content - Meta map[string]interface{} `json:"meta,omitempty"` // additional metadata - Ts time.Time `json:"ts,omitempty"` // timestamp - ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) + ID string `json:"id"` + MsgID string `json:"msg_id"` + Destination Destination `json:"destination,omitempty"` + Author User `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts time.Time `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` } type GatewayMessageOut struct { - Type string `json:"type"` // "mod"|"bot" - ID string `json:"channel_id,omitempty"` // message.Destination.ID + Type string `json:"type"` + ID string `json:"channel_id,omitempty"` Author User `json:"author"` Content string `json:"content"` Meta map[string]interface{} `json:"meta,omitempty"` diff --git a/ws/util.go b/ws/util.go index 38e0ac1..3e9a8b7 100644 --- a/ws/util.go +++ b/ws/util.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "log/slog" "net/http" "strings" "time" @@ -99,16 +98,26 @@ func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool { return !(apiKey == "" || apiKey != wsg.apiKey) } -func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { +func (wsg *WebsocketGateway) loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) - logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start)) + wsg.logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start)) }) } // connections +func closeConn(conn *websocket.Conn) { + _ = conn.SetWriteDeadline(time.Now().Add(time.Second)) + _ = conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), + time.Now().Add(time.Second), + ) + _ = conn.Close() +} + func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool { if typ == "bot" { wsg.registry.botMu.Lock() @@ -140,20 +149,23 @@ func (wsg *WebsocketGateway) closeAll() { wsg.registry.UnregisterBot() - wsg.registry.mu.RLock() - channelIDs := make([]string, 0, len(wsg.registry.entries)) - for channelID := range wsg.registry.entries { - channelIDs = append(channelIDs, channelID) - } - wsg.registry.mu.RUnlock() - - for _, channelID := range channelIDs { + wsg.registry.ForEach(func(channelID string) { wsg.registry.UnregisterMod(channelID) - } + }) } // +func NewUpgrader() websocket.Upgrader { + return websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // local by default; change for production + }, + } +} + func NewRegistry(queueCap int) *Registry { return &Registry{ entries: make(map[string]*ChannelEntry), diff --git a/ws/websocket.go b/ws/websocket.go index f01cf80..d392fac 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -17,18 +17,12 @@ import ( func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - port: cfg.HttpPort, - apiKey: cfg.Websocket, - registry: NewRegistry(32), - upgrader: websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true // local by default; change for production - }, - }, + logger: logger, + closeFn: closefn, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + upgrader: NewUpgrader(), + registry: NewRegistry(cfg.QueueSize), bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, } } @@ -47,7 +41,7 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error srv := &http.Server{ Addr: listenAddr, - Handler: loggingMiddleware(wsg.logger, mux), + Handler: wsg.loggingMiddleware(mux), BaseContext: func(l net.Listener) context.Context { return ctx }, } errCh := make(chan error, 1) @@ -65,121 +59,6 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // -//func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { -// defer func() { -// wsg.unregisterConn(conn, meta, "mod") -// wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) -// }() -// -// ticker := time.NewTicker(30 * time.Second) -// defer ticker.Stop() -// -// go func() { -// for range ticker.C { -// wsg.sendWebsocketPing(conn) -// } -// }() -// -// for { -// typ, data, err := conn.ReadMessage() -// -// if err != nil { -// if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { -// wsg.logger.Warn("Mod-Client unexpectedly closed the connection.", "err", err) -// } -// return -// } -// -// if typ != websocket.TextMessage && typ != websocket.BinaryMessage { -// continue -// } -// -// var msg GatewayModMessageIn -// if err := json.Unmarshal(data, &msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) -// wsg.logger.Warn("invalid json from mod", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// msg.ReceivedAt = time.Now().UTC() -// if err := msg.Validate(); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) -// wsg.logger.Warn("mod message validation failed", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// // Handle the message (forward to bot, enrich, etc.) -// if err := wsg.modHandler.Handle(conn, msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) -// wsg.logger.Error("mod handler error", "server_id", meta.ID, "err", err) -// continue -// } -// -// _ = writeJSONSafe(conn, map[string]string{"status": "completed"}) // or "queued" -// } -//} -// -//func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { -// defer func() { -// wsg.unregisterConn(conn, meta, "bot") -// wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) -// }() -// -// pingTicker := time.NewTicker(30 * time.Second) -// defer pingTicker.Stop() -// -// // Send pings in a separate goroutine -// go func() { -// for range pingTicker.C { -// _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) -// if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { -// wsg.logger.Debug("write ping failed", "bot_id", meta.ID, "err", err) -// return -// } -// wsg.logger.Debug("sent ping to bot", "bot_id", meta.ID) -// } -// }() -// -// for { -// typ, data, err := conn.ReadMessage() -// if err != nil { -// if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { -// wsg.logger.Warn("unexpected bot close", "bot_id", meta.ID, "err", err) -// } else { -// wsg.logger.Debug("bot read error", "bot_id", meta.ID, "err", err) -// } -// return -// } -// -// if typ != websocket.TextMessage && typ != websocket.BinaryMessage { -// continue -// } -// -// var msg GatewayBotMessageIn -// if err := json.Unmarshal(data, &msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) -// wsg.logger.Warn("invalid json from bot", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// msg.ReceivedAt = time.Now().UTC() -// if err := msg.Validate(); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) -// wsg.logger.Warn("bot message validation failed", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// // Handle the message (forward to mod, enrich, etc.) -// if err := wsg.botHandler.Handle(conn, msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) -// wsg.logger.Error("bot handler error", "bot_id", meta.ID, "err", err) -// continue -// } -// -// _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) -// } -//} - func (wsg *WebsocketGateway) read(conn *websocket.Conn, _type, channelId string) { defer func() { wsg.unregisterConn(_type, channelId)