93 lines
2.1 KiB
Go
93 lines
2.1 KiB
Go
package ws
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"os/signal"
	"syscall"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
func (g *WebsocketGateway) StartGatewayWithForwarder() {
|
|
go func() {
|
|
for out := range g.Outgoing() {
|
|
b, _ := json.Marshal(out)
|
|
// use Info because these are normal forward events
|
|
g.logger.Info("forwarder -> bot", "msg", string(b))
|
|
}
|
|
}()
|
|
|
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
|
defer stop()
|
|
|
|
if err := g.Serve(ctx, fmt.Sprintf(":%d", g.port)); err != nil {
|
|
g.logger.Error("gateway serve error", err)
|
|
}
|
|
close(g.outgoingCh)
|
|
}
|
|
|
|
func (g *WebsocketGateway) Outgoing() <-chan GatewayMessageOut {
|
|
return g.outgoingCh
|
|
}
|
|
|
|
func (g *WebsocketGateway) OutgoingLen() int {
|
|
return len(g.outgoingCh)
|
|
}
|
|
|
|
// util
|
|
|
|
func (g *WebsocketGateway) validateApiKey(r *http.Request) bool {
|
|
apiKey := r.URL.Query().Get("api_key")
|
|
if apiKey == "" {
|
|
apiKey = r.Header.Get("X-API-Key")
|
|
}
|
|
|
|
return !(apiKey == "" || apiKey != g.apiKey)
|
|
}
|
|
|
|
func writeJSONSafe(c *websocket.Conn, v interface{}) error {
|
|
_ = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
if err := c.WriteJSON(v); err != nil {
|
|
// caller handles logging
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loggingMiddleware wraps next so that each request is logged at Info level
// after next returns. The log line records the remote address, the method,
// the URL path (without the query string) and the elapsed wall time.
//
// w is handed to next unwrapped, so any optional interfaces the underlying
// ResponseWriter implements stay reachable by next.
func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)
		logger.Info("http request",
			"remote", r.RemoteAddr,
			"method", r.Method,
			"path", r.URL.Path,
			"duration", elapsed,
		)
	}
	return http.HandlerFunc(logged)
}
|
|
|
|
// connections
|
|
|
|
func (g *WebsocketGateway) registerConn(c *websocket.Conn) {
|
|
g.modConnsMu.Lock()
|
|
g.modConns[c] = struct{}{}
|
|
g.modConnsMu.Unlock()
|
|
}
|
|
|
|
func (g *WebsocketGateway) unregisterConn(c *websocket.Conn) {
|
|
g.modConnsMu.Lock()
|
|
delete(g.modConns, c)
|
|
g.modConnsMu.Unlock()
|
|
}
|
|
|
|
func (g *WebsocketGateway) closeAll() {
|
|
g.modConnsMu.Lock()
|
|
defer g.modConnsMu.Unlock()
|
|
g.logger.Info("closing websocket connections")
|
|
for c := range g.modConns {
|
|
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"), time.Now().Add(time.Second))
|
|
_ = c.Close()
|
|
}
|
|
}
|