214 lines
5.2 KiB
Go
214 lines
5.2 KiB
Go
package ws
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
// server lifecycle (listen / deafen)
|
|
|
|
func (wsg *WebsocketGateway) listen(srv *http.Server, addr string, channel chan error) {
|
|
wsg.logger.Info("Gateway listening.", "addr", addr)
|
|
|
|
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
channel <- err
|
|
}
|
|
|
|
close(channel)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) deafen(srv *http.Server) {
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
wsg.logger.Info("Shutting down Websocket Gateway.")
|
|
|
|
_ = srv.Shutdown(shutdownCtx)
|
|
wsg.closeAll()
|
|
}
|
|
|
|
// responses
|
|
|
|
func (wsg *WebsocketGateway) flush(c *websocket.Conn, m GatewayMessageOut) error {
|
|
_ = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
return c.WriteJSON(m)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(code)
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{"message": message, "code": code})
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
_ = conn.WriteMessage(websocket.PingMessage, nil)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int, close bool) {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
_ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code})
|
|
if close {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content interface{}) error {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
|
|
if err := conn.WriteJSON(content); err != nil {
|
|
wsg.logger.Error("Failed to respond to connection.", "remote", conn.RemoteAddr().String(), "err", err)
|
|
_ = conn.Close()
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// util
|
|
|
|
func (wsg *WebsocketGateway) validateAndUpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
|
if !wsg.validateApiKey(r) {
|
|
wsg.sendHttpError(w, "Unauthorized", 401)
|
|
wsg.logger.Warn("Authorization failed", "remote", r.RemoteAddr)
|
|
return nil, errors.New("unauthorized")
|
|
}
|
|
|
|
conn, err := wsg.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
wsg.logger.Error("Upgrade error.", "remote", r.RemoteAddr)
|
|
return nil, err
|
|
}
|
|
|
|
return conn, err
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool {
|
|
apiKey := r.URL.Query().Get("api_key")
|
|
if apiKey == "" {
|
|
apiKey = r.Header.Get("X-API-Key")
|
|
}
|
|
|
|
return !(apiKey == "" || apiKey != wsg.apiKey)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) loggingMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
next.ServeHTTP(w, r)
|
|
wsg.logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start))
|
|
})
|
|
}
|
|
|
|
// connections
|
|
|
|
func closeConn(conn *websocket.Conn) {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
|
_ = conn.WriteControl(
|
|
websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."),
|
|
time.Now().Add(time.Second),
|
|
)
|
|
_ = conn.Close()
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool {
|
|
if typ == "bot" {
|
|
wsg.registry.botMu.Lock()
|
|
if wsg.registry.bot != nil && wsg.registry.bot.Conn != nil {
|
|
wsg.registry.botMu.Unlock()
|
|
return false
|
|
}
|
|
wsg.registry.botMu.Unlock()
|
|
|
|
wsg.registry.RegisterBot(conn)
|
|
return true
|
|
}
|
|
|
|
wsg.registry.RegisterMod(channelId, serverId, conn)
|
|
return true
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) unregisterConn(typ, channelId string) {
|
|
if typ == "bot" {
|
|
wsg.registry.UnregisterBot()
|
|
return
|
|
}
|
|
|
|
wsg.registry.UnregisterMod(channelId)
|
|
}
|
|
|
|
func (wsg *WebsocketGateway) closeAll() {
|
|
wsg.logger.Info("Closing all websocket connections.")
|
|
|
|
wsg.registry.UnregisterBot()
|
|
|
|
wsg.registry.ForEach(func(channelID string) {
|
|
wsg.registry.UnregisterMod(channelID)
|
|
})
|
|
}
|
|
|
|
// constructors
|
|
|
|
func NewUpgrader() websocket.Upgrader {
|
|
return websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true // local by default; change for production
|
|
},
|
|
}
|
|
}
|
|
|
|
func NewRegistry(queueCap int) *Registry {
|
|
return &Registry{
|
|
entries: make(map[string]*ChannelEntry),
|
|
queueCap: queueCap,
|
|
}
|
|
}
|
|
|
|
func NewBoundedQueue(cap int) *BoundedQueue {
|
|
if cap <= 0 {
|
|
cap = 128
|
|
}
|
|
return &BoundedQueue{buf: make([]GatewayMessageOut, cap), capacity: cap}
|
|
}
|
|
|
|
func newChannelEntry(channel string, cap int) *ChannelEntry {
|
|
return &ChannelEntry{
|
|
Channel: channel,
|
|
Queue: NewBoundedQueue(cap),
|
|
}
|
|
}
|
|
|
|
// message validation and connection helpers
|
|
|
|
func (m *GatewayMessageIn) Validate() error {
|
|
if strings.TrimSpace(m.ID) == "" {
|
|
return errors.New("id missing")
|
|
}
|
|
if strings.TrimSpace(m.MsgID) == "" {
|
|
return errors.New("msg_id missing")
|
|
}
|
|
if strings.TrimSpace(m.Author.ID) == "" {
|
|
return errors.New("author.id missing")
|
|
}
|
|
if strings.TrimSpace(m.Content) == "" {
|
|
return errors.New("content missing")
|
|
}
|
|
|
|
if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" {
|
|
return errors.New("destination.channel_id missing")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Alive reports whether c is non-nil and wraps a non-nil Conn. It is safe to
// call on a nil *ConnWrapper.
func (c *ConnWrapper) Alive() bool {
	if c == nil {
		return false
	}
	return c.Conn != nil
}
|