package ws

import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"log/slog"
	"net/http"
	"time"

	"github.com/gorilla/websocket"
)

// (de-)register

// listen runs srv.ListenAndServe and blocks until the server stops.
// Any error other than http.ErrServerClosed (which signals a deliberate
// shutdown) is sent on channel; channel is closed afterwards in all cases.
// NOTE(review): presumably started in its own goroutine; if channel is
// unbuffered, the send blocks until someone receives.
func (wsg *WebsocketGateway) listen(srv *http.Server, addr string, channel chan error) {
	wsg.logger.Info("Gateway listening.", "addr", addr)
	if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		channel <- err
	}
	close(channel)
}

// deafen gracefully shuts srv down, allowing up to 5 seconds, and then
// closes every registered websocket connection. The explicit closeAll is
// needed because http.Server.Shutdown neither closes nor waits for
// hijacked connections such as websockets.
func (wsg *WebsocketGateway) deafen(srv *http.Server) {
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	wsg.logger.Info("Shutting down Websocket Gateway.")
	if err := srv.Shutdown(shutdownCtx); err != nil {
		// Typically context.DeadlineExceeded when in-flight requests outlive
		// the timeout; websockets are still closed below regardless.
		wsg.logger.Error("Gateway shutdown error.", "err", err)
	}
	wsg.closeAll()
}

// responses

// sendHttpError writes a JSON body {"message": message, "code": code}
// with HTTP status code to w.
func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// The status line is already sent, so an encoding/write failure cannot
	// be reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(map[string]any{"message": message, "code": code})
}

// sendWebsocketError makes a best-effort attempt (5s write deadline) to send
// {"message": message, "code": code} over conn, then closes conn.
// Errors are ignored because the connection is being torn down anyway.
func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int) {
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	_ = conn.WriteJSON(map[string]any{"message": message, "code": code})
	_ = conn.Close()
}

// sendWebsocketResponse writes content to conn as JSON (via writeJSONSafe).
// On failure it logs the error, closes conn, and returns the write error.
func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content any) error {
	if err := writeJSONSafe(conn, content); err != nil {
		wsg.logger.Error("Failed to respond to connection.", "remote", conn.RemoteAddr().String(), "err", err)
		_ = conn.Close()
		return err
	}
	return nil
}

// util

// validateAndUpgradeConnection checks the request's API key and, if valid,
// upgrades the HTTP connection to a websocket.
//
// On an invalid key it replies 401 with a JSON error body and returns a
// non-nil error. On upgrade failure it returns the upgrader's error; the
// upgrader is expected to have replied to the client itself (gorilla's
// websocket.Upgrader does) — NOTE(review): confirm wsg.upgrader is one.
func (wsg *WebsocketGateway) validateAndUpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
	if !wsg.validateApiKey(r) {
		wsg.sendHttpError(w, "Unauthorized", http.StatusUnauthorized)
		wsg.logger.Warn("Authorization failed", "remote", r.RemoteAddr)
		return nil, errors.New("unauthorized")
	}

	conn, err := wsg.upgrader.Upgrade(w, r, nil)
	if err != nil {
		wsg.logger.Error("Upgrade error.", "remote", r.RemoteAddr, "err", err)
		return nil, err
	}
	return conn, nil
}

// validateApiKey reports whether r carries the gateway's API key, taken from
// the "api_key" query parameter or, if that is empty, the "X-API-Key" header.
// An empty key is always rejected.
func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool {
	apiKey := r.URL.Query().Get("api_key")
	if apiKey == "" {
		apiKey = r.Header.Get("X-API-Key")
	}
	if apiKey == "" {
		return false
	}
	// Constant-time comparison so response timing does not reveal how many
	// leading bytes of a guessed key were correct.
	return subtle.ConstantTimeCompare([]byte(apiKey), []byte(wsg.apiKey)) == 1
}

// writeJSONSafe writes v to c as JSON with a 5-second write deadline and
// returns the write error unlogged; callers are responsible for logging.
func writeJSONSafe(c *websocket.Conn, v any) error {
	// Best effort: if setting the deadline fails, WriteJSON will surface the
	// underlying connection problem anyway.
	_ = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
	return c.WriteJSON(v)
}

// loggingMiddleware wraps next and logs remote address, method, path, and
// elapsed time after next returns. The query string is not logged, which
// keeps an "api_key" query parameter out of the logs.
//
// w is passed through unwrapped on purpose: websocket upgrades need the
// original writer's http.Hijacker implementation. If next serves a websocket
// inline, "duration" covers the whole connection lifetime.
func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		next.ServeHTTP(w, r)
		logger.Info("http request", "remote", r.RemoteAddr, "method", r.Method, "path", r.URL.Path, "duration", time.Since(start))
	})
}

// connections

// registerConn records c and its metadata in the gateway's connection set.
func (wsg *WebsocketGateway) registerConn(c *websocket.Conn, meta connectionMetaData) {
	wsg.connsMu.Lock()
	wsg.conns[c] = meta
	wsg.connsMu.Unlock()
}

// unregisterConn removes c from the gateway's connection set (no-op if absent).
func (wsg *WebsocketGateway) unregisterConn(c *websocket.Conn) {
	wsg.connsMu.Lock()
	delete(wsg.conns, c)
	wsg.connsMu.Unlock()
}

// closeAll sends a normal-closure close frame ("Shutting down.") to every
// registered connection and closes it, allowing up to one second per frame.
//
// The connection set is snapshotted under connsMu and the network I/O happens
// after the lock is released, so that slow or stalled peers do not block
// registerConn/unregisterConn for the duration of shutdown.
func (wsg *WebsocketGateway) closeAll() {
	wsg.connsMu.Lock()
	conns := make([]*websocket.Conn, 0, len(wsg.conns))
	for c := range wsg.conns {
		conns = append(conns, c)
	}
	wsg.connsMu.Unlock()

	wsg.logger.Info("Closing all websocket connections.")
	closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down.")
	for _, c := range conns {
		// WriteControl and Close are safe to call concurrently with a
		// connection's other methods, so in-flight readers/writers are fine.
		_ = c.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second))
		_ = c.Close()
	}
}