From 667db6be83dd07e4c4f7dab0559a0cd8b4709795 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 10:45:58 +0100 Subject: [PATCH] working websocket implementation --- .gitignore | 1 + config.toml | 8 +- go.mod | 1 + go.sum | 2 + main.go | 22 +---- shared/controller.go | 27 ++++++ shared/structs.go | 12 +++ util/config/structs.go | 4 +- ws/structs.go | 54 ++++++++++++ ws/util.go | 92 ++++++++++++++++++++ ws/websocket.go | 186 +++++++++++++++++++++++++++++++++++++++++ 11 files changed, 386 insertions(+), 23 deletions(-) create mode 100644 shared/controller.go create mode 100644 shared/structs.go create mode 100644 ws/structs.go create mode 100644 ws/util.go create mode 100644 ws/websocket.go diff --git a/.gitignore b/.gitignore index c480041..bf9d731 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ go.work.sum .idea /.dea/ +/logs/ diff --git a/config.toml b/config.toml index c9db740..f0eacef 100644 --- a/config.toml +++ b/config.toml @@ -1,11 +1,13 @@ [log] level = "info" directory = "logs/" -rotation = 3 # in days +rotation = 3 # in days [gateway] -http_port = "" -websocket = "" +http_port = 8080 +websocket = "gateway" +body_size = 2 # in MB +queue_max = 8192 [database] host_dsn = "" diff --git a/go.mod b/go.mod index f2b8ec3..086d25e 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( ) require ( + github.com/gorilla/websocket v1.5.3 // indirect github.com/samber/lo v1.52.0 // indirect github.com/samber/slog-common v0.19.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index b3a66c9..41997cc 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= diff --git a/main.go b/main.go index 116a53b..f0f544f 100644 --- a/main.go +++ b/main.go @@ -2,33 +2,17 @@ package main import ( "flag" - "fmt" + "homestead/homestead_gateway/shared" "homestead/homestead_gateway/util/config" - "homestead/homestead_gateway/util/logger" - "log/slog" - "os" ) func main() { cfgPath := flag.String("config", "config.toml", "configuration file") - cfg, err := config.LoadConfig(*cfgPath) if err != nil { panic(err) } - l, closeFn, err := logger.New("my-service", cfg.Log) - if err != nil { - panic(err) - } - - defer func() { - if err := closeFn(); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "error closing logs: %v\n", err) - } - }() - - slog.SetDefault(l) - l.Info("started", "port", 8080) - l.Error("something broke", "err", err) + controller := shared.NewGatewayController(*cfg) + controller.Run() } diff --git a/shared/controller.go b/shared/controller.go new file mode 100644 index 0000000..b27ecdb --- /dev/null +++ b/shared/controller.go @@ -0,0 +1,27 @@ +package shared + +import ( + "homestead/homestead_gateway/util/config" + "homestead/homestead_gateway/util/logger" + "homestead/homestead_gateway/ws" +) + +func NewGatewayController(cfg config.Config) GatewayController { + wsl, wCloseFn, err := logger.New("Websocket", cfg.Log) + if err != nil { + panic(err) + } + //hsl, hCloseFn, err := logger.New("HttpServer", cfg.Log) + //if err != nil { + // panic(err) + //} + + return GatewayController{ + Websocket: ws.NewWsGateway(cfg.Gateway, wsl, wCloseFn), + HttpServer: HttpGateway{}, + } +} + +func (gc *GatewayController) Run() { + gc.Websocket.StartGatewayWithForwarder() +} diff --git a/shared/structs.go b/shared/structs.go new file mode 100644 index 0000000..26357fb --- /dev/null +++ b/shared/structs.go @@ -0,0 +1,12 @@ +package shared + +import ( + "homestead/homestead_gateway/ws" +) + +type GatewayController struct { + Websocket *ws.WebsocketGateway + HttpServer HttpGateway +} + +type HttpGateway struct{} diff --git a/util/config/structs.go b/util/config/structs.go index 0491f08..4700c84 100644 --- a/util/config/structs.go +++ b/util/config/structs.go @@ -9,8 +9,10 @@ type Config struct { } type GatewayConfig struct { - HttpPort string `toml:"http_port"` + HttpPort int `toml:"http_port"` Websocket string `toml:"websocket"` + BodySize int `toml:"body_size"` + QueueSize int `toml:"queue_max"` } type LogConfig struct { diff --git a/ws/structs.go b/ws/structs.go new file mode 100644 index 0000000..0669da2 --- /dev/null +++ b/ws/structs.go @@ -0,0 +1,54 @@ +package ws + +import ( + "log/slog" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type WebsocketGateway struct { + port int + apiKey string + bodySizeBytes int64 + + upgrader websocket.Upgrader + modConnsMu sync.Mutex + modConns map[*websocket.Conn]struct{} + + botConnsMu sync.Mutex + botConns map[*websocket.Conn]*sync.Mutex + outgoingCh chan GatewayMessageOut + + logger *slog.Logger + closefn func() error +} + +type User struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type GatewayMessageIn struct { + MsgID string `json:"msg_id"` + Server string `json:"server"` + User User `json:"user"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) +} + +type GatewayMessageOut struct { + Type string `json:"type"` // "message" + Payload GatewayMessageIn `json:"payload"` + ForwardedBy string `json:"forwarded_by"` // "ws"/"http" + ForwardedAt time.Time `json:"forwarded_at"` +} + +type BotAck struct { + Type string `json:"type"` // "acknowledge" + MsgID string `json:"msg_id"` + Status string `json:"status,omitempty"` // "queued"/"sent" +} diff --git a/ws/util.go b/ws/util.go new file mode 100644 index 0000000..904932d --- /dev/null +++ b/ws/util.go @@ -0,0 +1,92 @@ +package ws + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os/signal" + "syscall" + "time" + + "github.com/gorilla/websocket" +) + +func (g *WebsocketGateway) StartGatewayWithForwarder() { + go func() { + for out := range g.Outgoing() { + b, _ := json.Marshal(out) + // use Info because these are normal forward events + g.logger.Info("forwarder -> bot", "msg", string(b)) + } + }() + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + if err := g.Serve(ctx, fmt.Sprintf(":%d", g.port)); err != nil { + g.logger.Error("gateway serve error", err) + } + close(g.outgoingCh) +} + +func (g *WebsocketGateway) Outgoing() <-chan GatewayMessageOut { + return g.outgoingCh +} + +func (g *WebsocketGateway) OutgoingLen() int { + return len(g.outgoingCh) +} + +// util + +func (g *WebsocketGateway) validateApiKey(r *http.Request) bool { + apiKey := r.URL.Query().Get("api_key") + if apiKey == "" { + apiKey = r.Header.Get("X-API-Key") + } + + return !(apiKey == "" || apiKey != g.apiKey) +} + +func writeJSONSafe(c *websocket.Conn, v interface{}) error { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteJSON(v); err != nil { + // caller handles logging + return err + } + return nil +} + +func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + logger.Info("http request", "remote", r.RemoteAddr, "method", r.Method, "path", r.URL.Path, "duration", time.Since(start)) + }) +} + +// connections + +func (g *WebsocketGateway) registerConn(c *websocket.Conn) { + g.modConnsMu.Lock() + g.modConns[c] = struct{}{} + g.modConnsMu.Unlock() +} + +func (g *WebsocketGateway) unregisterConn(c *websocket.Conn) { + g.modConnsMu.Lock() + delete(g.modConns, c) + g.modConnsMu.Unlock() +} + +func (g *WebsocketGateway) closeAll() { + g.modConnsMu.Lock() + defer g.modConnsMu.Unlock() + g.logger.Info("closing websocket connections") + for c := range g.modConns { + _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"), time.Now().Add(time.Second)) + _ = c.Close() + } +} diff --git a/ws/websocket.go b/ws/websocket.go new file mode 100644 index 0000000..0797a12 --- /dev/null +++ b/ws/websocket.go @@ -0,0 +1,186 @@ +package ws + +import ( + "context" + "encoding/json" + "errors" + "homestead/homestead_gateway/util/config" + "log/slog" + "net" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" +) + +func (m *GatewayMessageIn) Validate() error { + if strings.TrimSpace(m.MsgID) == "" { + return errors.New("msg_id missing") + } + if strings.TrimSpace(m.Server) == "" { + return errors.New("server missing") + } + if strings.TrimSpace(m.User.ID) == "" { + return errors.New("user.mod_uid missing") + } + if strings.TrimSpace(m.Content) == "" { + return errors.New("content missing") + } + return nil +} + +func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { + return &WebsocketGateway{ + logger: logger, + closefn: closefn, + apiKey: cfg.Websocket, + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // local by default; change for production + }, + }, + outgoingCh: make(chan GatewayMessageOut, cfg.QueueSize), + modConns: make(map[*websocket.Conn]struct{}), + bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, + port: cfg.HttpPort, + } +} + +// Serve starts the HTTP server and /ws endpoint and blocks until ctx cancelled or server fails. +func (g *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { + mux := http.NewServeMux() + mux.HandleFunc("/ws", g.handleWS) + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, _ = w.Write([]byte("ok")) + }) + + srv := &http.Server{ + Addr: listenAddr, + Handler: loggingMiddleware(g.logger, mux), + BaseContext: func(l net.Listener) context.Context { return ctx }, + } + + errCh := make(chan error, 1) + go func() { + g.logger.Info("ws gateway listening", "addr", listenAddr) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errCh <- err + } + close(errCh) + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + g.logger.Info("shutting down http server") + _ = srv.Shutdown(shutdownCtx) + g.closeAll() + return nil + case err := <-errCh: + return err + } +} + +func (g *WebsocketGateway) handleWS(w http.ResponseWriter, r *http.Request) { + if !g.validateApiKey(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + g.logger.Warn("ws auth failed", "remote", r.RemoteAddr) + return + } + + conn, err := g.upgrader.Upgrade(w, r, nil) + if err != nil { + g.logger.Error("ws upgrade error", err, "remote", r.RemoteAddr) + return + } + + g.registerConn(conn) + g.logger.Info("ws connected", "remote", conn.RemoteAddr().String()) + + // Configure read limits & pong handler + if g.bodySizeBytes > 0 { + conn.SetReadLimit(g.bodySizeBytes) + } else { + conn.SetReadLimit(1 << 20) // sensible default 1MiB + } + + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + go g.readLoop(conn) +} + +func (g *WebsocketGateway) readLoop(c *websocket.Conn) { + defer func() { + g.unregisterConn(c) + _ = c.Close() + g.logger.Info("ws disconnected", "remote", c.RemoteAddr().String()) + }() + + pingTicker := time.NewTicker(30 * time.Second) + defer pingTicker.Stop() + + for { + // Read one message (blocks until message arrives) + typ, data, err := c.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + g.logger.Warn("unexpected ws close", "err", err) + } else { + g.logger.Debug("ws read error", "err", err) + } + return + } + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { + continue + } + + var in GatewayMessageIn + if err := json.Unmarshal(data, &in); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) + g.logger.Warn("invalid json from client", "remote", c.RemoteAddr().String(), "err", err) + continue + } + in.ReceivedAt = time.Now().UTC() + if err := in.Validate(); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) + g.logger.Warn("message validation failed", "remote", c.RemoteAddr().String(), "err", err) + continue + } + + out := GatewayMessageOut{ + Type: "message", + Payload: in, + ForwardedAt: time.Now().UTC(), + } + + // Non-blocking enqueue with backpressure + select { + case g.outgoingCh <- out: + _ = writeJSONSafe(c, map[string]string{"status": "queued"}) + g.logger.Debug("enqueued message", "msg_id", in.MsgID, "server", in.Server) + default: + _ = writeJSONSafe(c, map[string]string{"error": "gateway busy"}) + g.logger.Warn("outgoing queue full", "msg_id", in.MsgID, "remote", c.RemoteAddr().String()) + } + + // also handle pings periodically (so client sees ping frequently) + select { + case <-pingTicker.C: + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { + g.logger.Debug("write ping failed", "err", err) + return + } + default: + } + } +}