156 lines
4.1 KiB
Go
156 lines
4.1 KiB
Go
package ws
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"homestead/homestead_gateway/util/cache"
	"log/slog"
	"net/http"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
// (de-)register
|
|
|
|
func (wsg *WebsocketGateway) listen(srv *http.Server, addr string, channel chan error) {
|
|
wsg.logger.Info("Gateway listening.", "addr", addr)
|
|
|
|
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
channel <- err
|
|
}
|
|
|
|
close(channel)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) deafen(srv *http.Server) {
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
wsg.logger.Info("Shutting down Websocket Gateway.")
|
|
|
|
_ = srv.Shutdown(shutdownCtx)
|
|
wsg.closeAll()
|
|
}
|
|
|
|
// responses
|
|
|
|
func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(code)
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{"message": message, "code": code})
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
_ = conn.WriteMessage(websocket.PingMessage, nil)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int, close bool) {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
_ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code})
|
|
if close {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content interface{}) error {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
|
|
if err := conn.WriteJSON(content); err != nil {
|
|
wsg.logger.Error("Failed to respond to connection.", "remote", conn.RemoteAddr().String(), "err", err)
|
|
_ = conn.Close()
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// util
|
|
|
|
func (wsg *WebsocketGateway) validateAndUpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
|
if !wsg.validateApiKey(r) {
|
|
wsg.sendHttpError(w, "Unauthorized", 401)
|
|
wsg.logger.Warn("Authorization failed", "remote", r.RemoteAddr)
|
|
return nil, errors.New("unauthorized")
|
|
}
|
|
|
|
conn, err := wsg.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
wsg.logger.Error("Upgrade error.", "remote", r.RemoteAddr)
|
|
return nil, err
|
|
}
|
|
|
|
return conn, err
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool {
|
|
apiKey := r.URL.Query().Get("api_key")
|
|
if apiKey == "" {
|
|
apiKey = r.Header.Get("X-API-Key")
|
|
}
|
|
|
|
return !(apiKey == "" || apiKey != wsg.apiKey)
|
|
}
|
|
|
|
func writeJSONSafe(c *websocket.Conn, v interface{}) error {
|
|
_ = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
if err := c.WriteJSON(v); err != nil {
|
|
// caller handles logging
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loggingMiddleware wraps next so that every request produces one Info log
// entry on logger once next has returned.
//
// The entry records the remote address, the URL path (the query string is not
// included) and the time spent inside next.
func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)

		logger.Info("Incoming HTTP request.",
			"remote", r.RemoteAddr,
			"path", r.URL.Path,
			"duration", elapsed,
		)
	}

	return http.HandlerFunc(logged)
}
|
|
|
|
// connections
|
|
|
|
func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) bool {
|
|
if typ == "bot" {
|
|
if wsg.bot != nil {
|
|
return false
|
|
}
|
|
|
|
wsg.bot = conn
|
|
return true
|
|
}
|
|
|
|
wsg.conns.Set(meta.ID, conn, &meta)
|
|
return true
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) {
|
|
if typ == "bot" {
|
|
_ = wsg.bot.Close()
|
|
wsg.bot = nil
|
|
return
|
|
}
|
|
|
|
wsg.conns.RemoveById(meta.ID)
|
|
_ = conn.Close()
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) closeAll() {
|
|
wsg.logger.Info("Closing all websocket connections.")
|
|
|
|
if wsg.bot != nil {
|
|
_ = wsg.bot.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second))
|
|
_ = wsg.bot.Close()
|
|
}
|
|
|
|
wsg.conns.Range(func(id string, c *websocket.Conn, meta *cache.ConnectionMetaData) {
|
|
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second))
|
|
_ = c.Close()
|
|
})
|
|
|
|
wsg.conns.Clear()
|
|
}
|