package ws import ( "context" "encoding/json" "errors" "homestead/homestead_gateway/util" "net/http" "strings" "time" "github.com/gorilla/websocket" ) // (de-)register func (wsg *WebsocketGateway) listen(srv *http.Server, addr string, channel chan error) { wsg.logger.Info("Gateway listening.", "addr", addr) if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { channel <- err } close(channel) } func (wsg *WebsocketGateway) deafen(srv *http.Server) { shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() wsg.logger.Info("Shutting down Websocket Gateway.") _ = srv.Shutdown(shutdownCtx) wsg.closeAll() } // responses func (wsg *WebsocketGateway) flush(c *websocket.Conn, m GatewayMessageOut) error { _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) return c.WriteJSON(m) } func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) _ = json.NewEncoder(w).Encode(map[string]interface{}{"message": message, "code": code}) } func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) { _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) _ = conn.WriteMessage(websocket.PingMessage, nil) } func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int, close bool) { _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) _ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code}) if close { util.CloseConn(conn) } } func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content interface{}) error { _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := conn.WriteJSON(content); err != nil { wsg.logger.Error("Failed to respond to connection.", "remote", conn.RemoteAddr().String(), "err", err) util.CloseConnWithControlMessage(conn, websocket.CloseAbnormalClosure, "Connection error.") return err } return nil } // util func (wsg *WebsocketGateway) validateAndUpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { if !wsg.validateApiKey(r) { wsg.sendHttpError(w, "Unauthorized", 401) wsg.logger.Warn("Authorization failed", "remote", r.RemoteAddr) return nil, errors.New("unauthorized") } conn, err := wsg.upgrader.Upgrade(w, r, nil) if err != nil { wsg.logger.Error("Upgrade error.", "remote", r.RemoteAddr) return nil, err } return conn, err } func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool { apiKey := r.URL.Query().Get("api_key") if apiKey == "" { apiKey = r.Header.Get("X-API-Key") } return !(apiKey == "" || apiKey != wsg.apiKey) } func (wsg *WebsocketGateway) loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) wsg.logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start)) }) } // connections func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool { if typ == "bot" { wsg.registry.botMu.Lock() if wsg.registry.bot != nil && wsg.registry.bot.Conn != nil { wsg.registry.botMu.Unlock() return false } wsg.registry.botMu.Unlock() wsg.registry.RegisterBot(conn) return true } wsg.registry.RegisterMod(channelId, serverId, conn) return true } func (wsg *WebsocketGateway) unregisterConn(typ, channelId string) { if typ == "bot" { wsg.registry.UnregisterBot() return } wsg.registry.UnregisterMod(channelId) } func (wsg *WebsocketGateway) closeAll() { wsg.logger.Info("Closing all websocket connections.") wsg.registry.UnregisterBot() wsg.registry.ForEach(func(channelID string) { wsg.registry.UnregisterMod(channelID) }) } // func NewUpgrader() websocket.Upgrader { return websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true // local by default; change for production }, } } func NewRegistry(queueCap int) *Registry { return &Registry{ entries: make(map[string]*ChannelEntry), queueCap: queueCap, } } func NewBoundedQueue(cap int) *BoundedQueue { if cap <= 0 { cap = 128 } return &BoundedQueue{buf: make([]GatewayMessageOut, cap), capacity: cap} } func newChannelEntry(channel string, cap int) *ChannelEntry { return &ChannelEntry{ Channel: channel, Queue: NewBoundedQueue(cap), } } // func (m *GatewayMessageIn) Validate() error { if strings.TrimSpace(m.ID) == "" { return errors.New("id missing") } if strings.TrimSpace(m.MsgID) == "" { return errors.New("msg_id missing") } if strings.TrimSpace(m.Author.ID) == "" { return errors.New("author.id missing") } if strings.TrimSpace(m.Content) == "" { return errors.New("content missing") } if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" { return errors.New("destination.channel_id missing") } return nil } func (c *ConnWrapper) Alive() bool { return c != nil && c.Conn != nil }