Files
HomesteadGateway/ws/util.go

205 lines
5.1 KiB
Go

package ws
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/websocket"

	"homestead/homestead_gateway/util"
)
// server lifecycle (start listening / shut down)
// listen serves HTTP on srv until it stops and reports the outcome on channel.
//
// A stop caused by Shutdown/Close (http.ErrServerClosed) is treated as normal
// and nothing is sent. Any other error from ListenAndServe is sent once. In
// both cases channel is closed afterwards, so a receiver can tell "clean
// stop" (closed, no value) from "failed" (one error, then closed).
// NOTE(review): the send blocks if channel is unbuffered and nobody is
// receiving yet — confirm the caller's channel capacity.
func (wsg *WebsocketGateway) listen(srv *http.Server, addr string, channel chan error) {
	wsg.logger.Info("Gateway listening.", "addr", addr)
	switch serveErr := srv.ListenAndServe(); {
	case serveErr == nil, errors.Is(serveErr, http.ErrServerClosed):
		// Expected termination: nothing to report.
	default:
		channel <- serveErr
	}
	close(channel)
}
// deafen stops srv and then closes every registered websocket connection.
//
// srv.Shutdown gets five seconds to let in-flight HTTP requests finish. The
// net/http docs state that Shutdown neither closes nor waits for hijacked
// connections such as WebSockets. That is why closeAll still has to run
// afterwards; it runs even if Shutdown reports an error.
//
// A Shutdown error (for example, the five-second context expiring) is logged
// instead of being silently discarded.
func (wsg *WebsocketGateway) deafen(srv *http.Server) {
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	wsg.logger.Info("Shutting down Websocket Gateway.")
	if err := srv.Shutdown(shutdownCtx); err != nil {
		wsg.logger.Warn("HTTP server did not shut down cleanly.", "err", err)
	}
	wsg.closeAll()
}
// responses
// flush JSON-encodes m and writes it to c as one websocket message.
//
// The write is given a five-second deadline. An error from setting that
// deadline is discarded; the WriteJSON error is returned to the caller.
func (wsg *WebsocketGateway) flush(c *websocket.Conn, m GatewayMessageOut) error {
	deadline := time.Now().Add(5 * time.Second)
	_ = c.SetWriteDeadline(deadline)
	writeErr := c.WriteJSON(m)
	return writeErr
}
// sendHttpError writes a JSON error body of the form
// {"code":<code>,"message":<message>} with HTTP status code.
//
// The Content-Type header is set before WriteHeader, because headers set after
// WriteHeader are not sent. Encoding errors are discarded: the status line has
// already been written at that point, so there is nothing useful left to do.
func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) {
	// Field order matches encoding/json's sorted map-key order ("code" before
	// "message"), so the bytes on the wire are the same as for a map payload.
	body := struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}{Code: code, Message: message}

	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(code)

	encoder := json.NewEncoder(w)
	_ = encoder.Encode(body)
}
// sendWebsocketPing sends an empty ping control frame on conn, allowing the
// write up to five seconds. Any error is discarded, so callers get no
// signal that the ping failed.
//
// WriteControl replaces SetWriteDeadline+WriteMessage. gorilla/websocket
// documents WriteControl as safe to call concurrently with the other write
// methods, so a keepalive ping cannot trip the library's single-writer rule
// when another goroutine is writing data. WriteControl also takes its
// deadline as an argument, so the connection's write deadline is left
// unchanged.
func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) {
	_ = conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
}
// sendWebsocketError writes a {"code","message"} JSON object to conn, with a
// five-second write deadline. When close is true, it then closes the
// connection via util.CloseConn.
//
// Both the deadline and the write errors are discarded; the close still
// happens if the write failed.
func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int, close bool) {
	payload := map[string]interface{}{
		"code":    code,
		"message": message,
	}
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	_ = conn.WriteJSON(payload)
	if !close {
		return
	}
	util.CloseConn(conn)
}
// sendWebsocketResponse JSON-encodes content and writes it to conn, with a
// five-second write deadline. It returns nil on success.
//
// If the write fails, it logs the error with the peer address, closes conn
// via util.CloseConnWithControlMessage, and returns the write error.
//
// The close code is 1011 (CloseInternalServerErr). The previous code, 1006
// (CloseAbnormalClosure), is reserved by RFC 6455 §7.4.1 for "connection
// dropped without a Close frame", and an endpoint must not send it in a
// Close frame.
// NOTE(review): this assumes util.CloseConnWithControlMessage sends a Close
// frame carrying this code — confirm against util.
func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content interface{}) error {
	_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
	if err := conn.WriteJSON(content); err != nil {
		wsg.logger.Error("Failed to respond to connection.", "remote", conn.RemoteAddr().String(), "err", err)
		util.CloseConnWithControlMessage(conn, websocket.CloseInternalServerErr, "Connection error.")
		return err
	}
	return nil
}
// util
// validateAndUpgradeConnection authenticates r and upgrades it to a websocket.
//
// If the API key check fails, it writes a 401 JSON error to w, logs a warning,
// and returns an "unauthorized" error without attempting the upgrade.
//
// If the upgrade fails, it logs the upgrader's error (previously dropped from
// the log line) and returns it. Per the gorilla/websocket docs, Upgrade has
// already replied to the client with an HTTP error in that case, so nothing
// more is written to w here.
func (wsg *WebsocketGateway) validateAndUpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
	if !wsg.validateApiKey(r) {
		wsg.sendHttpError(w, "Unauthorized", http.StatusUnauthorized)
		wsg.logger.Warn("Authorization failed", "remote", r.RemoteAddr)
		return nil, errors.New("unauthorized")
	}

	conn, err := wsg.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Log the cause so failed handshakes can be diagnosed from the logs.
		wsg.logger.Error("Upgrade error.", "remote", r.RemoteAddr, "err", err)
		return nil, err
	}
	return conn, nil
}
// validateApiKey reports whether r presents the gateway's API key.
//
// The key is taken from the "api_key" query parameter. The "X-API-Key" header
// is used only when that parameter is absent or empty. An empty key never
// authenticates, even if the configured key is also empty.
//
// The comparison uses subtle.ConstantTimeCompare. For equal-length inputs its
// running time does not depend on the contents, so a client cannot find the
// key byte by byte by timing responses. It returns immediately on a length
// mismatch, which reveals only the key's length.
func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool {
	apiKey := r.URL.Query().Get("api_key")
	if apiKey == "" {
		apiKey = r.Header.Get("X-API-Key")
	}
	if apiKey == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(apiKey), []byte(wsg.apiKey)) == 1
}
// loggingMiddleware wraps next and, after next returns, logs the request's
// remote address, URL path and elapsed time.
//
// The log line is written only after the wrapped handler returns. If that
// handler keeps serving an upgraded websocket, the entry appears when the
// handler finally returns. The query string is not logged.
func (wsg *WebsocketGateway) loggingMiddleware(next http.Handler) http.Handler {
	timed := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)
		wsg.logger.Info("Incoming HTTP request.",
			"remote", r.RemoteAddr,
			"path", r.URL.Path,
			"duration", elapsed,
		)
	}
	return http.HandlerFunc(timed)
}
// connections
// registerConn records conn in the registry according to typ and reports
// whether it was accepted.
//
// For typ "bot", the connection is refused (false) while the registry already
// holds a bot entry with a non-nil Conn; otherwise it is passed to
// RegisterBot. Any other typ is registered as a moderator connection keyed by
// channelId/serverId and always succeeds.
//
// NOTE(review): botMu is released before RegisterBot is called, so two bot
// handshakes arriving together could both pass the "already connected" check.
// Closing that gap needs RegisterBot's internals (it may take botMu itself) —
// confirm and fix inside Registry.
func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool {
	if typ != "bot" {
		wsg.registry.RegisterMod(channelId, serverId, conn)
		return true
	}

	botAlreadyConnected := func() bool {
		wsg.registry.botMu.Lock()
		defer wsg.registry.botMu.Unlock()
		return wsg.registry.bot != nil && wsg.registry.bot.Conn != nil
	}()
	if botAlreadyConnected {
		return false
	}

	wsg.registry.RegisterBot(conn)
	return true
}
// unregisterConn removes a connection from the registry. typ "bot" drops the
// bot connection; any other typ drops the moderator entry for channelId.
func (wsg *WebsocketGateway) unregisterConn(typ, channelId string) {
	switch typ {
	case "bot":
		wsg.registry.UnregisterBot()
	default:
		wsg.registry.UnregisterMod(channelId)
	}
}
// closeAll unregisters the bot connection and then every moderator channel
// that the registry's ForEach reports.
//
// Channel IDs are collected first and unregistered only after ForEach has
// returned. The registry is never modified from inside its own iteration
// callback, so this is safe whether or not ForEach holds a registry lock
// while calling back. Calling UnregisterMod from inside ForEach would
// deadlock if both used the same non-reentrant lock.
// NOTE(review): ForEach's locking is not visible from this file — confirm.
func (wsg *WebsocketGateway) closeAll() {
	wsg.logger.Info("Closing all websocket connections.")
	wsg.registry.UnregisterBot()

	var channelIDs []string
	wsg.registry.ForEach(func(channelID string) {
		channelIDs = append(channelIDs, channelID)
	})
	for _, channelID := range channelIDs {
		wsg.registry.UnregisterMod(channelID)
	}
}
// constructors
// NewUpgrader returns a websocket.Upgrader with 1 KiB read and write buffers.
//
// WARNING: CheckOrigin accepts every Origin header, so the API-key check is
// the only barrier against cross-site connections. The original note said
// "local by default; change for production" — presumably the gateway is only
// reachable locally; confirm, and restrict origins before exposing it.
func NewUpgrader() websocket.Upgrader {
	allowAnyOrigin := func(*http.Request) bool { return true }

	var upgrader websocket.Upgrader
	upgrader.ReadBufferSize = 1024
	upgrader.WriteBufferSize = 1024
	upgrader.CheckOrigin = allowAnyOrigin
	return upgrader
}
// NewRegistry returns an empty Registry with an initialised (non-nil) entries
// map. queueCap is stored on the registry as-is; it is presumably the per-
// channel queue size used when entries are created — confirm.
func NewRegistry(queueCap int) *Registry {
	reg := new(Registry)
	reg.entries = map[string]*ChannelEntry{}
	reg.queueCap = queueCap
	return reg
}
// NewBoundedQueue returns a BoundedQueue whose buffer is pre-allocated with
// exactly `cap` zero-valued GatewayMessageOut slots and whose capacity field
// equals that size. A non-positive cap falls back to 128.
func NewBoundedQueue(cap int) *BoundedQueue {
	size := 128
	if cap > 0 {
		size = cap
	}
	return &BoundedQueue{
		capacity: size,
		buf:      make([]GatewayMessageOut, size),
	}
}
// newChannelEntry builds a ChannelEntry for channel, backed by a fresh
// NewBoundedQueue(cap). A non-positive cap gets NewBoundedQueue's default.
func newChannelEntry(channel string, cap int) *ChannelEntry {
	queue := NewBoundedQueue(cap)
	entry := &ChannelEntry{Channel: channel, Queue: queue}
	return entry
}
// validation and helpers
// Validate checks that an inbound gateway message has all required fields.
// It returns an error for the first one that is empty or whitespace-only.
// Fields are checked in this order: ID, MsgID, Author.ID, Content. For
// messages with Type "mod", Destination.ID is also required. It returns nil
// when the message is complete.
//
// The Destination.ID error says "destination.channel_id" — presumably the
// field's JSON name; confirm against the struct tags.
func (m *GatewayMessageIn) Validate() error {
	blank := func(s string) bool { return strings.TrimSpace(s) == "" }

	switch {
	case blank(m.ID):
		return errors.New("id missing")
	case blank(m.MsgID):
		return errors.New("msg_id missing")
	case blank(m.Author.ID):
		return errors.New("author.id missing")
	case blank(m.Content):
		return errors.New("content missing")
	case m.Type == "mod" && blank(m.Destination.ID):
		return errors.New("destination.channel_id missing")
	}
	return nil
}
// Alive reports whether c wraps a connection. It is safe to call on a nil
// *ConnWrapper, which returns false.
func (c *ConnWrapper) Alive() bool {
	if c == nil {
		return false
	}
	return c.Conn != nil
}