From 675e11427881506b012110441e3d2a02ca70ff67 Mon Sep 17 00:00:00 2001 From: Overlord Date: Sun, 30 Nov 2025 22:00:47 +0100 Subject: [PATCH 01/15] init, eod commit --- .gitignore | 18 +--- config.toml | 14 +++ go.mod | 14 +++ go.sum | 10 +++ main.go | 34 ++++++++ util/config/structs.go | 27 ++++++ util/config/toml.go | 21 +++++ util/logger/log.go | 191 +++++++++++++++++++++++++++++++++++++++++ util/logger/structs.go | 27 ++++++ 9 files changed, 340 insertions(+), 16 deletions(-) create mode 100644 config.toml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 util/config/structs.go create mode 100644 util/config/toml.go create mode 100644 util/logger/log.go create mode 100644 util/logger/structs.go diff --git a/.gitignore b/.gitignore index 5b90e79..c480041 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,13 @@ -# ---> Go -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib - -# Test binary, built with `go test -c` *.test - -# Output of the go coverage tool, specifically when used with LiteIDE *.out - -# Dependency directories (remove the comment below to include it) -# vendor/ - -# Go workspace file go.work go.work.sum - -# env file .env +.idea +/.dea/ diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..c9db740 --- /dev/null +++ b/config.toml @@ -0,0 +1,14 @@ +[log] +level = "info" +directory = "logs/" +rotation = 3 # in days + +[gateway] +http_port = "" +websocket = "" + +[database] +host_dsn = "" +username = "" +password = "" +database = "" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f2b8ec3 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module homestead/homestead_gateway + +go 1.25.4 + +require ( + github.com/pelletier/go-toml/v2 v2.2.4 + github.com/samber/slog-multi v1.6.0 +) + +require ( + github.com/samber/lo v1.52.0 // indirect + github.com/samber/slog-common v0.19.0 // indirect + golang.org/x/text v0.22.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b3a66c9 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= +github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= +github.com/samber/slog-common v0.19.0 h1:fNcZb8B2uOLooeYwFpAlKjkQTUafdjfqKcwcC89G9YI= +github.com/samber/slog-common v0.19.0/go.mod h1:dTz+YOU76aH007YUU0DffsXNsGFQRQllPQh9XyNoA3M= +github.com/samber/slog-multi v1.6.0 h1:i1uBY+aaln6ljwdf7Nrt4Sys8Kk6htuYuXDHWJsHtZg= +github.com/samber/slog-multi v1.6.0/go.mod h1:qTqzmKdPpT0h4PFsTN5rYRgLwom1v+fNGuIrl1Xnnts= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= diff --git a/main.go b/main.go new file mode 100644 index 0000000..116a53b --- /dev/null +++ b/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "flag" + "fmt" + "homestead/homestead_gateway/util/config" + "homestead/homestead_gateway/util/logger" + "log/slog" + "os" +) + +func main() { + cfgPath := flag.String("config", "config.toml", "configuration file") + + cfg, err := config.LoadConfig(*cfgPath) + if err != nil { + panic(err) + } + + l, closeFn, err := logger.New("my-service", cfg.Log) + if err != nil { + panic(err) + } + + defer func() { + if err := closeFn(); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error closing logs: %v\n", err) + } + }() + + slog.SetDefault(l) + l.Info("started", "port", 8080) + l.Error("something broke", "err", err) +} diff --git a/util/config/structs.go b/util/config/structs.go new file mode 100644 index 0000000..0491f08 --- /dev/null +++ b/util/config/structs.go @@ -0,0 +1,27 @@ +package config + +import "log/slog" + +type Config struct { + Log LogConfig `toml:"log"` + Gateway GatewayConfig `toml:"gateway"` + Database DatabaseConfig `toml:"database"` +} + +type GatewayConfig struct { + HttpPort string `toml:"http_port"` + Websocket string `toml:"websocket"` +} + +type LogConfig struct { + Level slog.Level `toml:"level"` + Directory string `toml:"directory"` + Rotation int `toml:"rotation"` +} + +type DatabaseConfig struct { + HostDSN string `toml:"host_dsn"` + Username string `toml:"username"` + Password string `toml:"password"` + Database string `toml:"database"` +} diff --git a/util/config/toml.go b/util/config/toml.go new file mode 100644 index 0000000..a3351bf --- /dev/null +++ b/util/config/toml.go @@ -0,0 +1,21 @@ +package config + +import ( + "fmt" + "os" + + "github.com/pelletier/go-toml/v2" +) + +func LoadConfig(path string) (*Config, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open config: %w", err) + } + + var cfg Config + if err = toml.NewDecoder(file).Decode(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} diff --git a/util/logger/log.go b/util/logger/log.go new file mode 100644 index 0000000..72f7a71 --- /dev/null +++ b/util/logger/log.go @@ -0,0 +1,191 @@ +package logger + +import ( + "context" + "fmt" + "homestead/homestead_gateway/util/config" + "log/slog" + "os" + "path/filepath" + "time" + + slogmulti "github.com/samber/slog-multi" +) + +// New creates a logger identified by id. +// Returns: *slog.Logger, closeFunc, error. +// Call closeFunc() on shutdown to close open files. +func New(id string, cfg config.LogConfig) (*slog.Logger, func() error, error) { + if cfg.Directory == "" { + cfg.Directory = "logs" + } + if cfg.Rotation <= 0 { + cfg.Rotation = 7 + } + + console := slog.NewTextHandler(&prefixWriter{inner: os.Stderr, prefix: []byte("[" + id + "] "), startLine: true}, &slog.HandlerOptions{AddSource: true}) + router := newFileRouter(cfg.Directory, cfg.Rotation, id) + root := slogmulti.Fanout(console, router) + + return slog.New(root), router.CloseFiles, nil +} + +func (p *prefixWriter) Write(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + + totalWritten := 0 + if p.startLine { + n, err := p.inner.Write(p.prefix) + totalWritten += n + if err != nil { + return totalWritten, err + } + p.startLine = false + } + + n, err := p.inner.Write(b) + totalWritten += n + if err != nil { + return totalWritten, err + } + + if len(b) > 0 && b[len(b)-1] == '\n' { + p.startLine = true + } + return totalWritten, nil +} + +func newFileRouter(baseDir string, rotationDays int, id string) *fileRouter { + return &fileRouter{ + handlers: make(map[string]slog.Handler), + files: make(map[string]*os.File), + baseDir: baseDir, + rotationDays: rotationDays, + id: id, + dirTimeLayout: "2006-01-02", // e.g. 2025-11-30 + } +} + +//goland:noinspection GoUnusedParameter +func (r *fileRouter) Enabled(ctx context.Context, lvl slog.Level) bool { + // Conservatively true; actual handler will decide. + return true +} + +func (r *fileRouter) Handle(ctx context.Context, rec slog.Record) error { + now := time.Now() + dirName := now.Format(r.dirTimeLayout) + dirPath := filepath.Join(r.baseDir, dirName) + filePath := filepath.Join(dirPath, r.id+".log") + + h, err := r.getHandler(dirPath, filePath) + if err != nil { + return fmt.Errorf("file router get handler: %w", err) + } + + return h.Handle(ctx, rec) +} + +func (r *fileRouter) WithAttrs(attrs []slog.Attr) slog.Handler { + return r +} + +func (r *fileRouter) WithGroup(name string) slog.Handler { + return r +} + +// getHandler returns a text handler for the given file, creating dir/file and cleaning up old dirs if needed. +func (r *fileRouter) getHandler(dirPath, filePath string) (slog.Handler, error) { + r.mu.RLock() + h, ok := r.handlers[filePath] + r.mu.RUnlock() + if ok { + return h, nil + } + + r.mu.Lock() + defer r.mu.Unlock() + if h, ok = r.handlers[filePath]; ok { + return h, nil + } + + _, statErr := os.Stat(dirPath) + dirExisted := statErr == nil + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return nil, fmt.Errorf("mkdir %s: %w", dirPath, err) + } + + if !dirExisted { + if err := r.cleanupOldDirs(); err != nil { + // don't fail logging just because cleanup failed; report to stderr and continue. + _, _ = fmt.Fprintf(os.Stderr, "logger: cleanupOldDirs error: %v\n", err) + } + } + + f, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return nil, fmt.Errorf("open log file %s: %w", filePath, err) + } + + textHandler := slog.NewTextHandler(f, &slog.HandlerOptions{ + AddSource: false, + }) + + r.files[filePath] = f + r.handlers[filePath] = textHandler + return textHandler, nil +} + +// cleanupOldDirs scans r.baseDir for directories matching the dirTimeLayout and deletes any whose +// day-start time is older than rotationDays from now. +func (r *fileRouter) cleanupOldDirs() error { + entries, err := os.ReadDir(r.baseDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read baseDir: %w", err) + } + + now := time.Now() + cutoff := now.AddDate(0, 0, -r.rotationDays) + + for _, e := range entries { + if !e.IsDir() { + continue + } + name := e.Name() + t, err := time.ParseInLocation(r.dirTimeLayout, name, time.Local) + if err != nil { + continue + } + + if t.Before(cutoff) { + path := filepath.Join(r.baseDir, name) + + if err := os.RemoveAll(path); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "logger: failed to remove old log dir %s: %v\n", path, err) + } + } + } + + return nil +} + +// CloseFiles closes open files +func (r *fileRouter) CloseFiles() error { + r.mu.Lock() + defer r.mu.Unlock() + + var firstErr error + for p, f := range r.files { + if err := f.Close(); err != nil && firstErr == nil { + firstErr = err + } + delete(r.files, p) + delete(r.handlers, p) + } + return firstErr +} diff --git a/util/logger/structs.go b/util/logger/structs.go new file mode 100644 index 0000000..b9b8f2f --- /dev/null +++ b/util/logger/structs.go @@ -0,0 +1,27 @@ +package logger + +import ( + "io" + "log/slog" + "os" + "sync" +) + +type fileRouter struct { + mu sync.RWMutex + handlers map[string]slog.Handler // map[filePath]handler + files map[string]*os.File // map[filePath]*os.File to close later + baseDir string + rotationDays int + id string + dirTimeLayout string // "2006-01-02" - daily dirs +} + +// prefixWriter - writes a prefix at the start of each new line. +// It is safe for concurrent use. +type prefixWriter struct { + inner io.Writer + prefix []byte + mu sync.Mutex + startLine bool +} From 667db6be83dd07e4c4f7dab0559a0cd8b4709795 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 10:45:58 +0100 Subject: [PATCH 02/15] working websocket implementation --- .gitignore | 1 + config.toml | 8 +- go.mod | 1 + go.sum | 2 + main.go | 22 +---- shared/controller.go | 27 ++++++ shared/structs.go | 12 +++ util/config/structs.go | 4 +- ws/structs.go | 54 ++++++++++++ ws/util.go | 92 ++++++++++++++++++++ ws/websocket.go | 186 +++++++++++++++++++++++++++++++++++++++++ 11 files changed, 386 insertions(+), 23 deletions(-) create mode 100644 shared/controller.go create mode 100644 shared/structs.go create mode 100644 ws/structs.go create mode 100644 ws/util.go create mode 100644 ws/websocket.go diff --git a/.gitignore b/.gitignore index c480041..bf9d731 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ go.work.sum .idea /.dea/ +/logs/ diff --git a/config.toml b/config.toml index c9db740..f0eacef 100644 --- a/config.toml +++ b/config.toml @@ -1,11 +1,13 @@ [log] level = "info" directory = "logs/" -rotation = 3 # in days +rotation = 3 # in days [gateway] -http_port = "" -websocket = "" +http_port = 8080 +websocket = "gateway" +body_size = 2 # in MB +queue_max = 8192 [database] host_dsn = "" diff --git a/go.mod b/go.mod index f2b8ec3..086d25e 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( ) require ( + github.com/gorilla/websocket v1.5.3 // indirect github.com/samber/lo v1.52.0 // indirect github.com/samber/slog-common v0.19.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index b3a66c9..41997cc 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= diff --git a/main.go b/main.go index 116a53b..f0f544f 100644 --- a/main.go +++ b/main.go @@ -2,33 +2,17 @@ package main import ( "flag" - "fmt" + "homestead/homestead_gateway/shared" "homestead/homestead_gateway/util/config" - "homestead/homestead_gateway/util/logger" - "log/slog" - "os" ) func main() { cfgPath := flag.String("config", "config.toml", "configuration file") - cfg, err := config.LoadConfig(*cfgPath) if err != nil { panic(err) } - l, closeFn, err := logger.New("my-service", cfg.Log) - if err != nil { - panic(err) - } - - defer func() { - if err := closeFn(); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "error closing logs: %v\n", err) - } - }() - - slog.SetDefault(l) - l.Info("started", "port", 8080) - l.Error("something broke", "err", err) + controller := shared.NewGatewayController(*cfg) + controller.Run() } diff --git a/shared/controller.go b/shared/controller.go new file mode 100644 index 0000000..b27ecdb --- /dev/null +++ b/shared/controller.go @@ -0,0 +1,27 @@ +package shared + +import ( + "homestead/homestead_gateway/util/config" + "homestead/homestead_gateway/util/logger" + "homestead/homestead_gateway/ws" +) + +func NewGatewayController(cfg config.Config) GatewayController { + wsl, wCloseFn, err := logger.New("Websocket", cfg.Log) + if err != nil { + panic(err) + } + //hsl, hCloseFn, err := logger.New("HttpServer", cfg.Log) + //if err != nil { + // panic(err) + //} + + return GatewayController{ + Websocket: ws.NewWsGateway(cfg.Gateway, wsl, wCloseFn), + HttpServer: HttpGateway{}, + } +} + +func (gc *GatewayController) Run() { + gc.Websocket.StartGatewayWithForwarder() +} diff --git a/shared/structs.go b/shared/structs.go new file mode 100644 index 0000000..26357fb --- /dev/null +++ b/shared/structs.go @@ -0,0 +1,12 @@ +package shared + +import ( + "homestead/homestead_gateway/ws" +) + +type GatewayController struct { + Websocket *ws.WebsocketGateway + HttpServer HttpGateway +} + +type HttpGateway struct{} diff --git a/util/config/structs.go b/util/config/structs.go index 0491f08..4700c84 100644 --- a/util/config/structs.go +++ b/util/config/structs.go @@ -9,8 +9,10 @@ type Config struct { } type GatewayConfig struct { - HttpPort string `toml:"http_port"` + HttpPort int `toml:"http_port"` Websocket string `toml:"websocket"` + BodySize int `toml:"body_size"` + QueueSize int `toml:"queue_max"` } type LogConfig struct { diff --git a/ws/structs.go b/ws/structs.go new file mode 100644 index 0000000..0669da2 --- /dev/null +++ b/ws/structs.go @@ -0,0 +1,54 @@ +package ws + +import ( + "log/slog" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type WebsocketGateway struct { + port int + apiKey string + bodySizeBytes int64 + + upgrader websocket.Upgrader + modConnsMu sync.Mutex + modConns map[*websocket.Conn]struct{} + + botConnsMu sync.Mutex + botConns map[*websocket.Conn]*sync.Mutex + outgoingCh chan GatewayMessageOut + + logger *slog.Logger + closefn func() error +} + +type User struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type GatewayMessageIn struct { + MsgID string `json:"msg_id"` + Server string `json:"server"` + User User `json:"user"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) +} + +type GatewayMessageOut struct { + Type string `json:"type"` // "message" + Payload GatewayMessageIn `json:"payload"` + ForwardedBy string `json:"forwarded_by"` // "ws"/"http" + ForwardedAt time.Time `json:"forwarded_at"` +} + +type BotAck struct { + Type string `json:"type"` // "acknowledge" + MsgID string `json:"msg_id"` + Status string `json:"status,omitempty"` // "queued"/"sent" +} diff --git a/ws/util.go b/ws/util.go new file mode 100644 index 0000000..904932d --- /dev/null +++ b/ws/util.go @@ -0,0 +1,92 @@ +package ws + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os/signal" + "syscall" + "time" + + "github.com/gorilla/websocket" +) + +func (g *WebsocketGateway) StartGatewayWithForwarder() { + go func() { + for out := range g.Outgoing() { + b, _ := json.Marshal(out) + // use Info because these are normal forward events + g.logger.Info("forwarder -> bot", "msg", string(b)) + } + }() + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + if err := g.Serve(ctx, fmt.Sprintf(":%d", g.port)); err != nil { + g.logger.Error("gateway serve error", err) + } + close(g.outgoingCh) +} + +func (g *WebsocketGateway) Outgoing() <-chan GatewayMessageOut { + return g.outgoingCh +} + +func (g *WebsocketGateway) OutgoingLen() int { + return len(g.outgoingCh) +} + +// util + +func (g *WebsocketGateway) validateApiKey(r *http.Request) bool { + apiKey := r.URL.Query().Get("api_key") + if apiKey == "" { + apiKey = r.Header.Get("X-API-Key") + } + + return !(apiKey == "" || apiKey != g.apiKey) +} + +func writeJSONSafe(c *websocket.Conn, v interface{}) error { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteJSON(v); err != nil { + // caller handles logging + return err + } + return nil +} + +func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + logger.Info("http request", "remote", r.RemoteAddr, "method", r.Method, "path", r.URL.Path, "duration", time.Since(start)) + }) +} + +// connections + +func (g *WebsocketGateway) registerConn(c *websocket.Conn) { + g.modConnsMu.Lock() + g.modConns[c] = struct{}{} + g.modConnsMu.Unlock() +} + +func (g *WebsocketGateway) unregisterConn(c *websocket.Conn) { + g.modConnsMu.Lock() + delete(g.modConns, c) + g.modConnsMu.Unlock() +} + +func (g *WebsocketGateway) closeAll() { + g.modConnsMu.Lock() + defer g.modConnsMu.Unlock() + g.logger.Info("closing websocket connections") + for c := range g.modConns { + _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"), time.Now().Add(time.Second)) + _ = c.Close() + } +} diff --git a/ws/websocket.go b/ws/websocket.go new file mode 100644 index 0000000..0797a12 --- /dev/null +++ b/ws/websocket.go @@ -0,0 +1,186 @@ +package ws + +import ( + "context" + "encoding/json" + "errors" + "homestead/homestead_gateway/util/config" + "log/slog" + "net" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" +) + +func (m *GatewayMessageIn) Validate() error { + if strings.TrimSpace(m.MsgID) == "" { + return errors.New("msg_id missing") + } + if strings.TrimSpace(m.Server) == "" { + return errors.New("server missing") + } + if strings.TrimSpace(m.User.ID) == "" { + return errors.New("user.mod_uid missing") + } + if strings.TrimSpace(m.Content) == "" { + return errors.New("content missing") + } + return nil +} + +func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { + return &WebsocketGateway{ + logger: logger, + closefn: closefn, + apiKey: cfg.Websocket, + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // local by default; change for production + }, + }, + outgoingCh: make(chan GatewayMessageOut, cfg.QueueSize), + modConns: make(map[*websocket.Conn]struct{}), + bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, + port: cfg.HttpPort, + } +} + +// Serve starts the HTTP server and /ws endpoint and blocks until ctx cancelled or server fails. +func (g *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { + mux := http.NewServeMux() + mux.HandleFunc("/ws", g.handleWS) + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, _ = w.Write([]byte("ok")) + }) + + srv := &http.Server{ + Addr: listenAddr, + Handler: loggingMiddleware(g.logger, mux), + BaseContext: func(l net.Listener) context.Context { return ctx }, + } + + errCh := make(chan error, 1) + go func() { + g.logger.Info("ws gateway listening", "addr", listenAddr) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errCh <- err + } + close(errCh) + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + g.logger.Info("shutting down http server") + _ = srv.Shutdown(shutdownCtx) + g.closeAll() + return nil + case err := <-errCh: + return err + } +} + +func (g *WebsocketGateway) handleWS(w http.ResponseWriter, r *http.Request) { + if !g.validateApiKey(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + g.logger.Warn("ws auth failed", "remote", r.RemoteAddr) + return + } + + conn, err := g.upgrader.Upgrade(w, r, nil) + if err != nil { + g.logger.Error("ws upgrade error", err, "remote", r.RemoteAddr) + return + } + + g.registerConn(conn) + g.logger.Info("ws connected", "remote", conn.RemoteAddr().String()) + + // Configure read limits & pong handler + if g.bodySizeBytes > 0 { + conn.SetReadLimit(g.bodySizeBytes) + } else { + conn.SetReadLimit(1 << 20) // sensible default 1MiB + } + + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + go g.readLoop(conn) +} + +func (g *WebsocketGateway) readLoop(c *websocket.Conn) { + defer func() { + g.unregisterConn(c) + _ = c.Close() + g.logger.Info("ws disconnected", "remote", c.RemoteAddr().String()) + }() + + pingTicker := time.NewTicker(30 * time.Second) + defer pingTicker.Stop() + + for { + // Read one message (blocks until message arrives) + typ, data, err := c.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + g.logger.Warn("unexpected ws close", "err", err) + } else { + g.logger.Debug("ws read error", "err", err) + } + return + } + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { + continue + } + + var in GatewayMessageIn + if err := json.Unmarshal(data, &in); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) + g.logger.Warn("invalid json from client", "remote", c.RemoteAddr().String(), "err", err) + continue + } + in.ReceivedAt = time.Now().UTC() + if err := in.Validate(); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) + g.logger.Warn("message validation failed", "remote", c.RemoteAddr().String(), "err", err) + continue + } + + out := GatewayMessageOut{ + Type: "message", + Payload: in, + ForwardedAt: time.Now().UTC(), + } + + // Non-blocking enqueue with backpressure + select { + case g.outgoingCh <- out: + _ = writeJSONSafe(c, map[string]string{"status": "queued"}) + g.logger.Debug("enqueued message", "msg_id", in.MsgID, "server", in.Server) + default: + _ = writeJSONSafe(c, map[string]string{"error": "gateway busy"}) + g.logger.Warn("outgoing queue full", "msg_id", in.MsgID, "remote", c.RemoteAddr().String()) + } + + // also handle pings periodically (so client sees ping frequently) + select { + case <-pingTicker.C: + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { + g.logger.Debug("write ping failed", "err", err) + return + } + default: + } + } +} From aac6db39c24f01876d8781894dd53234ac47ab47 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 10:47:37 +0100 Subject: [PATCH 03/15] updated GatewayMessageOut with dest --- ws/structs.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ws/structs.go b/ws/structs.go index 0669da2..2bf768c 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -30,6 +30,10 @@ type User struct { Name string `json:"name"` } +type Destination struct { + Channel string `json:"channel"` +} + type GatewayMessageIn struct { MsgID string `json:"msg_id"` Server string `json:"server"` @@ -43,6 +47,7 @@ type GatewayMessageIn struct { type GatewayMessageOut struct { Type string `json:"type"` // "message" Payload GatewayMessageIn `json:"payload"` + Destination Destination `json:"destination"` ForwardedBy string `json:"forwarded_by"` // "ws"/"http" ForwardedAt time.Time `json:"forwarded_at"` } From 4a668493c4274fc8078119891b7d1ad2f6c7b694 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 13:40:58 +0100 Subject: [PATCH 04/15] temp push --- config.toml | 4 +- {shared => controller}/controller.go | 15 +- {shared => controller}/structs.go | 2 +- main.go | 24 +++- sim.go | 191 ++++++++++++++++++++++++++ util/logger/log.go | 2 +- ws/handlers.go | 71 ++++++++++ ws/structs.go | 87 +++++++++--- ws/util.go | 44 ++---- ws/validate.go | 38 ++++++ ws/websocket.go | 197 ++++++++++++++++++++------- 11 files changed, 559 insertions(+), 116 deletions(-) rename {shared => controller}/controller.go (57%) rename {shared => controller}/structs.go (89%) create mode 100644 sim.go create mode 100644 ws/handlers.go create mode 100644 ws/validate.go diff --git a/config.toml b/config.toml index f0eacef..b41f343 100644 --- a/config.toml +++ b/config.toml @@ -1,10 +1,10 @@ [log] -level = "info" +level = "debug" directory = "logs/" rotation = 3 # in days [gateway] -http_port = 8080 +http_port = 3333 websocket = "gateway" body_size = 2 # in MB queue_max = 8192 diff --git a/shared/controller.go b/controller/controller.go similarity index 57% rename from shared/controller.go rename to controller/controller.go index b27ecdb..285aa66 100644 --- a/shared/controller.go +++ b/controller/controller.go @@ -1,4 +1,4 @@ -package shared +package controller import ( "homestead/homestead_gateway/util/config" @@ -11,17 +11,16 @@ func NewGatewayController(cfg config.Config) GatewayController { if err != nil { panic(err) } - //hsl, hCloseFn, err := logger.New("HttpServer", cfg.Log) - //if err != nil { - // panic(err) - //} + + modHandler := ws.NewLoggingModHandler(wsl) + botHandler := ws.NewLoggingBotHandler(wsl) return GatewayController{ - Websocket: ws.NewWsGateway(cfg.Gateway, wsl, wCloseFn), + Websocket: ws.NewWsGateway(cfg.Gateway, wsl, wCloseFn, modHandler, botHandler), HttpServer: HttpGateway{}, } } -func (gc *GatewayController) Run() { - gc.Websocket.StartGatewayWithForwarder() +func (gc *GatewayController) Run() error { + return gc.Websocket.Start() } diff --git a/shared/structs.go b/controller/structs.go similarity index 89% rename from shared/structs.go rename to controller/structs.go index 26357fb..e56093b 100644 --- a/shared/structs.go +++ b/controller/structs.go @@ -1,4 +1,4 @@ -package shared +package controller import ( "homestead/homestead_gateway/ws" diff --git a/main.go b/main.go index f0f544f..351e983 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,10 @@ package main import ( "flag" - "homestead/homestead_gateway/shared" + "fmt" + "homestead/homestead_gateway/controller" "homestead/homestead_gateway/util/config" + "homestead/homestead_gateway/util/logger" ) func main() { @@ -13,6 +15,22 @@ func main() { panic(err) } - controller := shared.NewGatewayController(*cfg) - controller.Run() + l, c, e := logger.New("yomama", cfg.Log) + if e != nil { + panic(e) + } + defer c() + + l.Debug("debug") + l.Info("info") + l.Warn("warn") + l.Error("error") + + panic(fmt.Sprintf("%+v", cfg.Log)) + + ctrl := controller.NewGatewayController(*cfg) + err = ctrl.Run() + if err != nil { + panic(err) + } } diff --git a/sim.go b/sim.go new file mode 100644 index 0000000..4ea529d --- /dev/null +++ b/sim.go @@ -0,0 +1,191 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/url" + "time" + + "github.com/gorilla/websocket" +) + +type MessageEnvelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +type ModHandshake struct { + ServerID string `json:"server_id"` +} + +type BotHandshake struct { + BotID string `json:"bot_id"` +} + +type User struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type ModMessage struct { + MsgID string `json:"msg_id"` + Server string `json:"server"` + User User `json:"user"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` +} + +type BotMessage struct { + MsgID string `json:"msg_id"` + ChannelID string `json:"channel_id"` + Author string `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` +} + +func simulateMod() { + u := url.URL{Scheme: "ws", Host: "localhost:3333", Path: "/ws"} + u.RawQuery = "api_key=gateway" + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("mod dial error: %v", err) + } + defer conn.Close() + + fmt.Println("[MOD] Connected") + + // Send handshake + handshake := ModHandshake{ServerID: "survival_server"} + hsData, _ := json.Marshal(handshake) + envelope := MessageEnvelope{ + Type: "mod", + Data: hsData, + } + envData, _ := json.Marshal(envelope) + + if err := conn.WriteMessage(websocket.TextMessage, envData); err != nil { + log.Fatalf("mod handshake write error: %v", err) + } + fmt.Println("[MOD] Sent handshake") + + // Read handshake response + _, resp, err := conn.ReadMessage() + if err != nil { + log.Fatalf("mod handshake response error: %v", err) + } + fmt.Printf("[MOD] Handshake response: %s\n", string(resp)) + + // Send a few messages + for i := 1; i <= 3; i++ { + msg := ModMessage{ + MsgID: fmt.Sprintf("mod_msg_%d", i), + Server: "survival_server", + User: User{ID: "player_123", Name: "Steve"}, + Content: fmt.Sprintf("Message %d from Minecraft!", i), + Meta: map[string]interface{}{ + "coordinates": map[string]int{"x": 100 + i*10, "y": 64, "z": 200}, + }, + Ts: time.Now().UTC().Format(time.RFC3339), + } + + msgData, _ := json.Marshal(msg) + if err := conn.WriteMessage(websocket.TextMessage, msgData); err != nil { + log.Fatalf("mod message write error: %v", err) + } + fmt.Printf("[MOD] Sent message: %s\n", msg.Content) + + // Read response + _, resp, err := conn.ReadMessage() + if err != nil { + log.Fatalf("mod response error: %v", err) + } + fmt.Printf("[MOD] Response: %s\n", string(resp)) + + time.Sleep(1 * time.Second) + } + + fmt.Println("[MOD] Closing connection") +} + +func simulateBot() { + time.Sleep(2 * time.Second) // Let mod connect first + + u := url.URL{Scheme: "ws", Host: "localhost:3333", Path: "/ws"} + u.RawQuery = "api_key=gateway" + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("bot dial error: %v", err) + } + defer conn.Close() + + fmt.Println("[BOT] Connected") + + // Send handshake + handshake := BotHandshake{BotID: "discord_bot_1"} + hsData, _ := json.Marshal(handshake) + envelope := MessageEnvelope{ + Type: "bot", + Data: hsData, + } + envData, _ := json.Marshal(envelope) + + if err := conn.WriteMessage(websocket.TextMessage, envData); err != nil { + log.Fatalf("bot handshake write error: %v", err) + } + fmt.Println("[BOT] Sent handshake") + + // Read handshake response + _, resp, err := conn.ReadMessage() + if err != nil { + log.Fatalf("bot handshake response error: %v", err) + } + fmt.Printf("[BOT] Handshake response: %s\n", string(resp)) + + // Send a few messages + for i := 1; i <= 3; i++ { + msg := BotMessage{ + MsgID: fmt.Sprintf("bot_msg_%d", i), + ChannelID: "987654321", + Author: "DiscordUser#1234", + Content: fmt.Sprintf("Message %d from Discord!", i), + Meta: map[string]interface{}{ + "reactions": []string{"👍", "❤️"}, + }, + Ts: time.Now().UTC().Format(time.RFC3339), + } + + msgData, _ := json.Marshal(msg) + if err := conn.WriteMessage(websocket.TextMessage, msgData); err != nil { + log.Fatalf("bot message write error: %v", err) + } + fmt.Printf("[BOT] Sent message: %s\n", msg.Content) + + // Read response + _, resp, err := conn.ReadMessage() + if err != nil { + log.Fatalf("bot response error: %v", err) + } + fmt.Printf("[BOT] Response: %s\n", string(resp)) + + time.Sleep(1 * time.Second) + } + + fmt.Println("[BOT] Closing connection") +} + +func main() { + fmt.Println("Starting WebSocket client simulator...") + fmt.Println("Connecting to ws://localhost:3333/ws with api_key=test_key") + + go simulateMod() + go simulateBot() + + // Let them run + time.Sleep(15 * time.Second) + fmt.Println("Simulator finished") +} diff --git a/util/logger/log.go b/util/logger/log.go index 72f7a71..75e5397 100644 --- a/util/logger/log.go +++ b/util/logger/log.go @@ -23,7 +23,7 @@ func New(id string, cfg config.LogConfig) (*slog.Logger, func() error, error) { cfg.Rotation = 7 } - console := slog.NewTextHandler(&prefixWriter{inner: os.Stderr, prefix: []byte("[" + id + "] "), startLine: true}, &slog.HandlerOptions{AddSource: true}) + console := slog.NewTextHandler(&prefixWriter{inner: os.Stderr, prefix: []byte("[" + id + "] "), startLine: true}, &slog.HandlerOptions{AddSource: true, Level: cfg.Level}) router := newFileRouter(cfg.Directory, cfg.Rotation, id) root := slogmulti.Fanout(console, router) diff --git a/ws/handlers.go b/ws/handlers.go new file mode 100644 index 0000000..238efbd --- /dev/null +++ b/ws/handlers.go @@ -0,0 +1,71 @@ +// handlers.go +package ws + +import ( + "context" + "encoding/json" + "log/slog" + "time" +) + +type LoggingModHandler struct { + logger *slog.Logger +} + +type LoggingBotHandler struct { + logger *slog.Logger +} + +func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { + return &LoggingModHandler{logger: logger} +} + +func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { + return &LoggingBotHandler{logger: logger} +} + +func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) error { + // For now, just log and pretend it's being forwarded + // TODO: Look up channel_id from database using server + // TODO: Forward to bot connection(s) + + fwd := ForwardedModMessage{ + Type: "mod", + ServerID: msg.Server, + ChannelID: "TODO", // will come from database lookup + User: msg.User, + Content: msg.Content, + Meta: msg.Meta, + Ts: msg.Ts, + ReceivedAt: msg.ReceivedAt, + ForwardedAt: time.Now().UTC(), + } + + b, _ := json.Marshal(fwd) + h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) + + return nil +} + +func (h *LoggingBotHandler) Handle(ctx context.Context, msg GatewayBotMessageIn) error { + // For now, just log and pretend it's being forwarded + // TODO: Look up server_id from database using channel_id + // TODO: Forward to mod connection(s) + + fwd := ForwardedBotMessage{ + Type: "bot", + ServerID: "TODO", // will come from database lookup + ChannelID: msg.ChannelID, + Author: msg.Author, + Content: msg.Content, + Meta: msg.Meta, + Ts: msg.Ts, + ReceivedAt: msg.ReceivedAt, + ForwardedAt: time.Now().UTC(), + } + + b, _ := json.Marshal(fwd) + h.logger.Debug("forwarding bot message", "msg_id", msg.MsgID, "channel", msg.ChannelID, "payload", string(b)) + + return nil +} diff --git a/ws/structs.go b/ws/structs.go index 2bf768c..bba61be 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -1,6 +1,8 @@ package ws import ( + "context" + "encoding/json" "log/slog" "sync" "time" @@ -13,13 +15,12 @@ type WebsocketGateway struct { apiKey string bodySizeBytes int64 - upgrader websocket.Upgrader - modConnsMu sync.Mutex - modConns map[*websocket.Conn]struct{} + upgrader websocket.Upgrader + connsMu sync.Mutex + conns map[*websocket.Conn]connMetadata - botConnsMu sync.Mutex - botConns map[*websocket.Conn]*sync.Mutex - outgoingCh chan GatewayMessageOut + modHandler ModHandler + botHandler BotHandler logger *slog.Logger closefn func() error @@ -30,11 +31,8 @@ type User struct { Name string `json:"name"` } -type Destination struct { - Channel string `json:"channel"` -} - -type GatewayMessageIn struct { +// GatewayModMessageIn : Mod -> Gateway -> Bot +type GatewayModMessageIn struct { MsgID string `json:"msg_id"` Server string `json:"server"` User User `json:"user"` @@ -44,12 +42,41 @@ type GatewayMessageIn struct { ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) } -type GatewayMessageOut struct { - Type string `json:"type"` // "message" - Payload GatewayMessageIn `json:"payload"` - Destination Destination `json:"destination"` - ForwardedBy string `json:"forwarded_by"` // "ws"/"http" - ForwardedAt time.Time `json:"forwarded_at"` +// GatewayBotMessageIn : Bot -> Gateway -> Mod +type GatewayBotMessageIn struct { + MsgID string `json:"msg_id"` + ChannelID string `json:"channel_id"` + Author string `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` +} + +// ForwardedModMessage : Gateway -> Bot +type ForwardedModMessage struct { + Type string `json:"type"` // "mod" + ServerID string `json:"server_id"` + ChannelID string `json:"channel_id"` + User User `json:"user"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` + ReceivedAt time.Time `json:"received_at"` + ForwardedAt time.Time `json:"forwarded_at"` +} + +// ForwardedBotMessage : Gateway -> Mod +type ForwardedBotMessage struct { + Type string `json:"type"` // "bot" + ServerID string `json:"server_id"` + ChannelID string `json:"channel_id"` + Author string `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` + ReceivedAt time.Time `json:"received_at"` + ForwardedAt time.Time `json:"forwarded_at"` } type BotAck struct { @@ -57,3 +84,29 @@ type BotAck struct { MsgID string `json:"msg_id"` Status string `json:"status,omitempty"` // "queued"/"sent" } + +type MessageEnvelope struct { + Type string `json:"type"` // "mod" or "bot" + Data json.RawMessage `json:"data"` +} + +type ModHandler interface { + Handle(ctx context.Context, msg GatewayModMessageIn) error +} + +type BotHandler interface { + Handle(ctx context.Context, msg GatewayBotMessageIn) error +} + +type ModHandshake struct { + ServerID string `json:"server_id"` +} + +type BotHandshake struct { + BotID string `json:"bot_id"` +} + +type connMetadata struct { + connType string // "mod" or "bot" + id string // server_id or bot_id for logging +} diff --git a/ws/util.go b/ws/util.go index 904932d..aeeed09 100644 --- a/ws/util.go +++ b/ws/util.go @@ -2,7 +2,6 @@ package ws import ( "context" - "encoding/json" "fmt" "log/slog" "net/http" @@ -13,30 +12,11 @@ import ( "github.com/gorilla/websocket" ) -func (g *WebsocketGateway) StartGatewayWithForwarder() { - go func() { - for out := range g.Outgoing() { - b, _ := json.Marshal(out) - // use Info because these are normal forward events - g.logger.Info("forwarder -> bot", "msg", string(b)) - } - }() - +func (g *WebsocketGateway) Start() error { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - if err := g.Serve(ctx, fmt.Sprintf(":%d", g.port)); err != nil { - g.logger.Error("gateway serve error", err) - } - close(g.outgoingCh) -} - -func (g *WebsocketGateway) Outgoing() <-chan GatewayMessageOut { - return g.outgoingCh -} - -func (g *WebsocketGateway) OutgoingLen() int { - return len(g.outgoingCh) + return g.Serve(ctx, fmt.Sprintf(":%d", g.port)) } // util @@ -69,23 +49,23 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { // connections -func (g *WebsocketGateway) registerConn(c *websocket.Conn) { - g.modConnsMu.Lock() - g.modConns[c] = struct{}{} - g.modConnsMu.Unlock() +func (g *WebsocketGateway) registerConn(c *websocket.Conn, meta connMetadata) { + g.connsMu.Lock() + g.conns[c] = meta + g.connsMu.Unlock() } func (g *WebsocketGateway) unregisterConn(c *websocket.Conn) { - g.modConnsMu.Lock() - delete(g.modConns, c) - g.modConnsMu.Unlock() + g.connsMu.Lock() + delete(g.conns, c) + g.connsMu.Unlock() } func (g *WebsocketGateway) closeAll() { - g.modConnsMu.Lock() - defer g.modConnsMu.Unlock() + g.connsMu.Lock() + defer g.connsMu.Unlock() g.logger.Info("closing websocket connections") - for c := range g.modConns { + for c := range g.conns { _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"), time.Now().Add(time.Second)) _ = c.Close() } diff --git a/ws/validate.go b/ws/validate.go new file mode 100644 index 0000000..f8feff6 --- /dev/null +++ b/ws/validate.go @@ -0,0 +1,38 @@ +package ws + +import ( + "errors" + "strings" +) + +func (m *GatewayModMessageIn) Validate() error { + if strings.TrimSpace(m.MsgID) == "" { + return errors.New("msg_id missing") + } + if strings.TrimSpace(m.Server) == "" { + return errors.New("server missing") + } + if strings.TrimSpace(m.User.ID) == "" { + return errors.New("user.id missing") + } + if strings.TrimSpace(m.Content) == "" { + return errors.New("content missing") + } + return nil +} + +func (m *GatewayBotMessageIn) Validate() error { + if strings.TrimSpace(m.MsgID) == "" { + return errors.New("msg_id missing") + } + if strings.TrimSpace(m.ChannelID) == "" { + return errors.New("channel_id missing") + } + if strings.TrimSpace(m.Author) == "" { + return errors.New("author missing") + } + if strings.TrimSpace(m.Content) == "" { + return errors.New("content missing") + } + return nil +} diff --git a/ws/websocket.go b/ws/websocket.go index 0797a12..7cc062b 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -8,29 +8,12 @@ import ( "log/slog" "net" "net/http" - "strings" "time" "github.com/gorilla/websocket" ) -func (m *GatewayMessageIn) Validate() error { - if strings.TrimSpace(m.MsgID) == "" { - return errors.New("msg_id missing") - } - if strings.TrimSpace(m.Server) == "" { - return errors.New("server missing") - } - if strings.TrimSpace(m.User.ID) == "" { - return errors.New("user.mod_uid missing") - } - if strings.TrimSpace(m.Content) == "" { - return errors.New("content missing") - } - return nil -} - -func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { +func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { return &WebsocketGateway{ logger: logger, closefn: closefn, @@ -42,14 +25,14 @@ func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() return true // local by default; change for production }, }, - outgoingCh: make(chan GatewayMessageOut, cfg.QueueSize), - modConns: make(map[*websocket.Conn]struct{}), + conns: make(map[*websocket.Conn]connMetadata), bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, port: cfg.HttpPort, + modHandler: modH, + botHandler: botH, } } -// Serve starts the HTTP server and /ws endpoint and blocks until ctx cancelled or server fails. func (g *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { mux := http.NewServeMux() mux.HandleFunc("/ws", g.handleWS) @@ -67,7 +50,7 @@ func (g *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { errCh := make(chan error, 1) go func() { g.logger.Info("ws gateway listening", "addr", listenAddr) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { errCh <- err } close(errCh) @@ -99,9 +82,6 @@ func (g *WebsocketGateway) handleWS(w http.ResponseWriter, r *http.Request) { return } - g.registerConn(conn) - g.logger.Info("ws connected", "remote", conn.RemoteAddr().String()) - // Configure read limits & pong handler if g.bodySizeBytes > 0 { conn.SetReadLimit(g.bodySizeBytes) @@ -115,69 +95,182 @@ func (g *WebsocketGateway) handleWS(w http.ResponseWriter, r *http.Request) { return nil }) - go g.readLoop(conn) + // First message must be a handshake identifying the connection type + typ, data, err := conn.ReadMessage() + if err != nil { + g.logger.Error("failed to read handshake", "err", err, "remote", conn.RemoteAddr().String()) + _ = conn.Close() + return + } + + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { + g.logger.Warn("invalid handshake message type", "remote", conn.RemoteAddr().String()) + _ = writeJSONSafe(conn, map[string]string{"error": "first message must be handshake"}) + _ = conn.Close() + return + } + + var envelope MessageEnvelope + if err := json.Unmarshal(data, &envelope); err != nil { + g.logger.Warn("invalid handshake json", "err", err, "remote", conn.RemoteAddr().String()) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid handshake: " + err.Error()}) + _ = conn.Close() + return + } + + meta := connMetadata{connType: envelope.Type} + + // Validate handshake based on type + switch envelope.Type { + case "mod": + var hs ModHandshake + if err := json.Unmarshal(envelope.Data, &hs); err != nil { + g.logger.Warn("invalid mod handshake", "err", err, "remote", conn.RemoteAddr().String()) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid mod handshake: " + err.Error()}) + _ = conn.Close() + return + } + meta.id = hs.ServerID + g.registerConn(conn, meta) + g.logger.Info("mod connected", "server_id", hs.ServerID, "remote", conn.RemoteAddr().String()) + go g.modReadLoop(conn, meta) + + case "bot": + var hs BotHandshake + if err := json.Unmarshal(envelope.Data, &hs); err != nil { + g.logger.Warn("invalid bot handshake", "err", err, "remote", conn.RemoteAddr().String()) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid bot handshake: " + err.Error()}) + _ = conn.Close() + return + } + meta.id = hs.BotID + g.registerConn(conn, meta) + g.logger.Info("bot connected", "bot_id", hs.BotID, "remote", conn.RemoteAddr().String()) + go g.botReadLoop(conn, meta) + + default: + g.logger.Warn("unknown connection type", "type", envelope.Type, "remote", conn.RemoteAddr().String()) + _ = writeJSONSafe(conn, map[string]string{"error": "unknown connection type: " + envelope.Type}) + _ = conn.Close() + return + } } -func (g *WebsocketGateway) readLoop(c *websocket.Conn) { +func (g *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connMetadata) { defer func() { g.unregisterConn(c) _ = c.Close() - g.logger.Info("ws disconnected", "remote", c.RemoteAddr().String()) + g.logger.Info("mod disconnected", "server_id", meta.id, "remote", c.RemoteAddr().String()) }() pingTicker := time.NewTicker(30 * time.Second) defer pingTicker.Stop() for { - // Read one message (blocks until message arrives) typ, data, err := c.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - g.logger.Warn("unexpected ws close", "err", err) + g.logger.Warn("unexpected mod close", "server_id", meta.id, "err", err) } else { - g.logger.Debug("ws read error", "err", err) + g.logger.Debug("mod read error", "server_id", meta.id, "err", err) } return } + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { continue } - var in GatewayMessageIn - if err := json.Unmarshal(data, &in); err != nil { + var msg GatewayModMessageIn + if err := json.Unmarshal(data, &msg); err != nil { _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - g.logger.Warn("invalid json from client", "remote", c.RemoteAddr().String(), "err", err) + g.logger.Warn("invalid json from mod", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) continue } - in.ReceivedAt = time.Now().UTC() - if err := in.Validate(); err != nil { + + msg.ReceivedAt = time.Now().UTC() + if err := msg.Validate(); err != nil { _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - g.logger.Warn("message validation failed", "remote", c.RemoteAddr().String(), "err", err) + g.logger.Warn("mod message validation failed", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) continue } - out := GatewayMessageOut{ - Type: "message", - Payload: in, - ForwardedAt: time.Now().UTC(), + // Handle the message (forward to bot, enrich, etc.) + if err := g.modHandler.Handle(context.Background(), msg); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) + g.logger.Error("mod handler error", "server_id", meta.id, "err", err) + continue } - // Non-blocking enqueue with backpressure - select { - case g.outgoingCh <- out: - _ = writeJSONSafe(c, map[string]string{"status": "queued"}) - g.logger.Debug("enqueued message", "msg_id", in.MsgID, "server", in.Server) - default: - _ = writeJSONSafe(c, map[string]string{"error": "gateway busy"}) - g.logger.Warn("outgoing queue full", "msg_id", in.MsgID, "remote", c.RemoteAddr().String()) - } + _ = writeJSONSafe(c, map[string]string{"status": "ok"}) - // also handle pings periodically (so client sees ping frequently) + // Handle pings select { case <-pingTicker.C: _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - g.logger.Debug("write ping failed", "err", err) + g.logger.Debug("write ping failed", "server_id", meta.id, "err", err) + return + } + default: + } + } +} + +func (g *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connMetadata) { + defer func() { + g.unregisterConn(c) + _ = c.Close() + g.logger.Info("bot disconnected", "bot_id", meta.id, "remote", c.RemoteAddr().String()) + }() + + pingTicker := time.NewTicker(30 * time.Second) + defer pingTicker.Stop() + + for { + typ, data, err := c.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + g.logger.Warn("unexpected bot close", "bot_id", meta.id, "err", err) + } else { + g.logger.Debug("bot read error", "bot_id", meta.id, "err", err) + } + return + } + + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { + continue + } + + var msg GatewayBotMessageIn + if err := json.Unmarshal(data, &msg); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) + g.logger.Warn("invalid json from bot", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + continue + } + + msg.ReceivedAt = time.Now().UTC() + if err := msg.Validate(); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) + g.logger.Warn("bot message validation failed", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + continue + } + + // Handle the message (forward to mod, enrich, etc.) + if err := g.botHandler.Handle(context.Background(), msg); err != nil { + _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) + g.logger.Error("bot handler error", "bot_id", meta.id, "err", err) + continue + } + + _ = writeJSONSafe(c, map[string]string{"status": "ok"}) + + // Handle pings + select { + case <-pingTicker.C: + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { + g.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) return } default: From 46aef47e2194e25c8e6a2bc88fa1d5cf45ac0fdf Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 18:09:10 +0100 Subject: [PATCH 05/15] re-formatted, better structure --- README.md | 14 +++- main.go | 15 ---- sim.go | 2 +- util/cache/cache.go | 70 ++++++++++++++++ util/cache/structs.go | 16 ++++ ws/handlers.go | 70 ---------------- ws/structs.go | 81 ++++++++++-------- ws/util.go | 104 ++++++++++++++++++----- ws/websocket.go | 191 ++++++++++++++++++++++-------------------- 9 files changed, 328 insertions(+), 235 deletions(-) create mode 100644 util/cache/cache.go create mode 100644 util/cache/structs.go diff --git a/README.md b/README.md index e4f35b4..f03f694 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,15 @@ # HomesteadGateway -Gateway between multiple HomesteadRelay's and the HomesteadToGo Bot. \ No newline at end of file +Gateway between multiple HomesteadRelay's and the HomesteadToGo Bot. + + +## dev notes + +perhaps drop database, instead + +``` +Mod -> websocket /register { server_id, channel_id } // grabbed from mod config +Bot -> websocket /ready { channel_id } // ready if Mod with fitting channel_id has called /register +Gateway -> memory cache (server_id -> channel_id; channel_id -> server_id) // mem enough +Mod/Bot -> websocket /ws { ... } -> Bot/Mod // sync +``` diff --git a/main.go b/main.go index 351e983..fa5670f 100644 --- a/main.go +++ b/main.go @@ -2,10 +2,8 @@ package main import ( "flag" - "fmt" "homestead/homestead_gateway/controller" "homestead/homestead_gateway/util/config" - "homestead/homestead_gateway/util/logger" ) func main() { @@ -15,19 +13,6 @@ func main() { panic(err) } - l, c, e := logger.New("yomama", cfg.Log) - if e != nil { - panic(e) - } - defer c() - - l.Debug("debug") - l.Info("info") - l.Warn("warn") - l.Error("error") - - panic(fmt.Sprintf("%+v", cfg.Log)) - ctrl := controller.NewGatewayController(*cfg) err = ctrl.Run() if err != nil { diff --git a/sim.go b/sim.go index 4ea529d..a8257e6 100644 --- a/sim.go +++ b/sim.go @@ -180,7 +180,7 @@ func simulateBot() { func main() { fmt.Println("Starting WebSocket client simulator...") - fmt.Println("Connecting to ws://localhost:3333/ws with api_key=test_key") + fmt.Println("Connecting to ws://localhost:3333/ws with api_key=gateway") go simulateMod() go simulateBot() diff --git a/util/cache/cache.go b/util/cache/cache.go new file mode 100644 index 0000000..9af7ca3 --- /dev/null +++ b/util/cache/cache.go @@ -0,0 +1,70 @@ +package cache + +import "sync" + +type Cache struct { + mu sync.RWMutex + s2c map[string]string + c2s map[string]string +} + +func NewCache() *Cache { + return &Cache{ + s2c: make(map[string]string), + c2s: make(map[string]string), + } +} + +// Set creates or overwrites the pair a -> b and b -> a. +// It ensures any previous mappings involving a or b are removed first. +func (c *Cache) Set(serverId, channelId string) { + c.mu.Lock() + defer c.mu.Unlock() + + if old, ok := c.s2c[serverId]; ok && old != channelId { + delete(c.c2s, old) + } + + if old, ok := c.c2s[channelId]; ok && old != serverId { + delete(c.s2c, old) + } + + c.s2c[serverId] = channelId + c.c2s[channelId] = serverId +} + +func (c *Cache) GetByServerId(serverId string) (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + cId, ok := c.s2c[serverId] + return cId, ok +} + +func (c *Cache) GetByChannelId(channelId string) (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + sId, ok := c.c2s[channelId] + return sId, ok +} + +func (c *Cache) RemoveByServerId(serverId string) { + c.mu.RLock() + defer c.mu.RUnlock() + + if channelId, ok := c.s2c[serverId]; ok { + delete(c.s2c, serverId) + delete(c.c2s, channelId) + } +} + +func (c *Cache) RemoveByChannelId(channelId string) { + c.mu.RLock() + defer c.mu.RUnlock() + + if serverId, ok := c.s2c[channelId]; ok { + delete(c.c2s, channelId) + delete(c.s2c, serverId) + } +} diff --git a/util/cache/structs.go b/util/cache/structs.go new file mode 100644 index 0000000..b5ce12f --- /dev/null +++ b/util/cache/structs.go @@ -0,0 +1,16 @@ +package cache + +import ( + "sync" +) + +type ShardedCache struct { + shards []*shard + shardMask uint32 +} + +type shard struct { + mu sync.RWMutex + s2c map[string]string // serverId -> channelId + c2s map[string]string // channelId -> serverId +} diff --git a/ws/handlers.go b/ws/handlers.go index 238efbd..9859295 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -1,71 +1 @@ -// handlers.go package ws - -import ( - "context" - "encoding/json" - "log/slog" - "time" -) - -type LoggingModHandler struct { - logger *slog.Logger -} - -type LoggingBotHandler struct { - logger *slog.Logger -} - -func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { - return &LoggingModHandler{logger: logger} -} - -func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { - return &LoggingBotHandler{logger: logger} -} - -func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) error { - // For now, just log and pretend it's being forwarded - // TODO: Look up channel_id from database using server - // TODO: Forward to bot connection(s) - - fwd := ForwardedModMessage{ - Type: "mod", - ServerID: msg.Server, - ChannelID: "TODO", // will come from database lookup - User: msg.User, - Content: msg.Content, - Meta: msg.Meta, - Ts: msg.Ts, - ReceivedAt: msg.ReceivedAt, - ForwardedAt: time.Now().UTC(), - } - - b, _ := json.Marshal(fwd) - h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) - - return nil -} - -func (h *LoggingBotHandler) Handle(ctx context.Context, msg GatewayBotMessageIn) error { - // For now, just log and pretend it's being forwarded - // TODO: Look up server_id from database using channel_id - // TODO: Forward to mod connection(s) - - fwd := ForwardedBotMessage{ - Type: "bot", - ServerID: "TODO", // will come from database lookup - ChannelID: msg.ChannelID, - Author: msg.Author, - Content: msg.Content, - Meta: msg.Meta, - Ts: msg.Ts, - ReceivedAt: msg.ReceivedAt, - ForwardedAt: time.Now().UTC(), - } - - b, _ := json.Marshal(fwd) - h.logger.Debug("forwarding bot message", "msg_id", msg.MsgID, "channel", msg.ChannelID, "payload", string(b)) - - return nil -} diff --git a/ws/structs.go b/ws/structs.go index bba61be..bb70593 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -3,6 +3,7 @@ package ws import ( "context" "encoding/json" + "homestead/homestead_gateway/util/cache" "log/slog" "sync" "time" @@ -17,48 +18,58 @@ type WebsocketGateway struct { upgrader websocket.Upgrader connsMu sync.Mutex - conns map[*websocket.Conn]connMetadata + conns map[*websocket.Conn]connectionMetaData + cache cache.Cache modHandler ModHandler botHandler BotHandler logger *slog.Logger - closefn func() error + closeFn func() error } -type User struct { +type MinecraftUser struct { ID string `json:"id"` Name string `json:"name"` } +type DiscordUser struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type Destination struct { + ChannelID string `json:"channel_id"` +} + // GatewayModMessageIn : Mod -> Gateway -> Bot type GatewayModMessageIn struct { - MsgID string `json:"msg_id"` - Server string `json:"server"` - User User `json:"user"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` - ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) + MsgID string `json:"msg_id"` + Server string `json:"server"` + Destination Destination `json:"destination"` + Author MinecraftUser `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) } // GatewayBotMessageIn : Bot -> Gateway -> Mod type GatewayBotMessageIn struct { MsgID string `json:"msg_id"` ChannelID string `json:"channel_id"` - Author string `json:"author"` + Author DiscordUser `json:"author"` Content string `json:"content"` Meta map[string]interface{} `json:"meta,omitempty"` Ts string `json:"ts,omitempty"` - ReceivedAt time.Time `json:"-"` + ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from bot) } -// ForwardedModMessage : Gateway -> Bot -type ForwardedModMessage struct { +// GatewayModMessageOut : Gateway -> Bot +type GatewayModMessageOut struct { Type string `json:"type"` // "mod" - ServerID string `json:"server_id"` ChannelID string `json:"channel_id"` - User User `json:"user"` + Author MinecraftUser `json:"author"` Content string `json:"content"` Meta map[string]interface{} `json:"meta,omitempty"` Ts string `json:"ts,omitempty"` @@ -66,12 +77,11 @@ type ForwardedModMessage struct { ForwardedAt time.Time `json:"forwarded_at"` } -// ForwardedBotMessage : Gateway -> Mod -type ForwardedBotMessage struct { +// GatewayBotMessageOut : Gateway -> Mod +type GatewayBotMessageOut struct { Type string `json:"type"` // "bot" - ServerID string `json:"server_id"` ChannelID string `json:"channel_id"` - Author string `json:"author"` + Author DiscordUser `json:"author"` Content string `json:"content"` Meta map[string]interface{} `json:"meta,omitempty"` Ts string `json:"ts,omitempty"` @@ -79,25 +89,16 @@ type ForwardedBotMessage struct { ForwardedAt time.Time `json:"forwarded_at"` } -type BotAck struct { - Type string `json:"type"` // "acknowledge" - MsgID string `json:"msg_id"` - Status string `json:"status,omitempty"` // "queued"/"sent" +type GatewayAck struct { + Status string `json:"status"` + Type string `json:"type"` } -type MessageEnvelope struct { +type Handshake struct { Type string `json:"type"` // "mod" or "bot" Data json.RawMessage `json:"data"` } -type ModHandler interface { - Handle(ctx context.Context, msg GatewayModMessageIn) error -} - -type BotHandler interface { - Handle(ctx context.Context, msg GatewayBotMessageIn) error -} - type ModHandshake struct { ServerID string `json:"server_id"` } @@ -106,7 +107,15 @@ type BotHandshake struct { BotID string `json:"bot_id"` } -type connMetadata struct { - connType string // "mod" or "bot" - id string // server_id or bot_id for logging +type ModHandler interface { + Handle(ctx context.Context, msg GatewayModMessageIn) error +} + +type BotHandler interface { + Handle(ctx context.Context, msg GatewayBotMessageIn) error +} + +type connectionMetaData struct { + connectionType string // "mod" or "bot" + id string // server_id or bot_id for logging } diff --git a/ws/util.go b/ws/util.go index aeeed09..4822779 100644 --- a/ws/util.go +++ b/ws/util.go @@ -2,32 +2,88 @@ package ws import ( "context" - "fmt" + "encoding/json" + "errors" "log/slog" "net/http" - "os/signal" - "syscall" "time" "github.com/gorilla/websocket" ) -func (g *WebsocketGateway) Start() error { - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer stop() +// (de-)register - return g.Serve(ctx, fmt.Sprintf(":%d", g.port)) +func (wsg *WebsocketGateway) listen(srv *http.Server, addr string, channel chan error) { + wsg.logger.Info("Gateway listening.", "addr", addr) + + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + channel <- err + } + + close(channel) +} + +func (wsg *WebsocketGateway) deafen(srv *http.Server) { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + wsg.logger.Info("Shutting down Websocket Gateway.") + + _ = srv.Shutdown(shutdownCtx) + wsg.closeAll() +} + +// responses + +func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"message": message, "code": code}) +} + +func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int) { + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code}) + _ = conn.Close() +} + +func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content interface{}) error { + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + + if err := conn.WriteJSON(content); err != nil { + wsg.logger.Error("Failed to respond to connection.", "remote", conn.RemoteAddr().String(), "err", err) + _ = conn.Close() + return err + } + + return nil } // util -func (g *WebsocketGateway) validateApiKey(r *http.Request) bool { +func (wsg *WebsocketGateway) validateAndUpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { + if !wsg.validateApiKey(r) { + wsg.sendHttpError(w, "Unauthorized", 401) + wsg.logger.Warn("Authorization failed", "remote", r.RemoteAddr) + return nil, errors.New("unauthorized") + } + + conn, err := wsg.upgrader.Upgrade(w, r, nil) + if err != nil { + wsg.logger.Error("Upgrade error.", "remote", r.RemoteAddr) + return nil, err + } + + return conn, err +} + +func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool { apiKey := r.URL.Query().Get("api_key") if apiKey == "" { apiKey = r.Header.Get("X-API-Key") } - return !(apiKey == "" || apiKey != g.apiKey) + return !(apiKey == "" || apiKey != wsg.apiKey) } func writeJSONSafe(c *websocket.Conn, v interface{}) error { @@ -49,24 +105,26 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { // connections -func (g *WebsocketGateway) registerConn(c *websocket.Conn, meta connMetadata) { - g.connsMu.Lock() - g.conns[c] = meta - g.connsMu.Unlock() +func (wsg *WebsocketGateway) registerConn(c *websocket.Conn, meta connectionMetaData) { + wsg.connsMu.Lock() + wsg.conns[c] = meta + wsg.connsMu.Unlock() } -func (g *WebsocketGateway) unregisterConn(c *websocket.Conn) { - g.connsMu.Lock() - delete(g.conns, c) - g.connsMu.Unlock() +func (wsg *WebsocketGateway) unregisterConn(c *websocket.Conn) { + wsg.connsMu.Lock() + delete(wsg.conns, c) + wsg.connsMu.Unlock() } -func (g *WebsocketGateway) closeAll() { - g.connsMu.Lock() - defer g.connsMu.Unlock() - g.logger.Info("closing websocket connections") - for c := range g.conns { - _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"), time.Now().Add(time.Second)) +func (wsg *WebsocketGateway) closeAll() { + wsg.connsMu.Lock() + defer wsg.connsMu.Unlock() + + wsg.logger.Info("Closing all websocket connections.") + + for c := range wsg.conns { + _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) _ = c.Close() } } diff --git a/ws/websocket.go b/ws/websocket.go index 7cc062b..8c91a20 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -3,11 +3,13 @@ package ws import ( "context" "encoding/json" - "errors" + "fmt" "homestead/homestead_gateway/util/config" "log/slog" "net" "net/http" + "os/signal" + "syscall" "time" "github.com/gorilla/websocket" @@ -16,7 +18,7 @@ import ( func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { return &WebsocketGateway{ logger: logger, - closefn: closefn, + closeFn: closefn, apiKey: cfg.Websocket, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, @@ -25,7 +27,7 @@ func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() return true // local by default; change for production }, }, - conns: make(map[*websocket.Conn]connMetadata), + conns: make(map[*websocket.Conn]connectionMetaData), bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, port: cfg.HttpPort, modHandler: modH, @@ -33,58 +35,48 @@ func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() } } -func (g *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { +func (wsg *WebsocketGateway) Start() error { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + return wsg.Serve(ctx, fmt.Sprintf(":%d", wsg.port)) +} + +func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { mux := http.NewServeMux() - mux.HandleFunc("/ws", g.handleWS) - mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - _, _ = w.Write([]byte("ok")) - }) + mux.HandleFunc("/push", wsg.handlePush) + mux.HandleFunc("/ready", wsg.handleReady) + mux.HandleFunc("/health", wsg.handleHealth) + mux.HandleFunc("/register", wsg.handleRegister) srv := &http.Server{ Addr: listenAddr, - Handler: loggingMiddleware(g.logger, mux), + Handler: loggingMiddleware(wsg.logger, mux), BaseContext: func(l net.Listener) context.Context { return ctx }, } - errCh := make(chan error, 1) - go func() { - g.logger.Info("ws gateway listening", "addr", listenAddr) - if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errCh <- err - } - close(errCh) - }() + + go wsg.listen(srv, listenAddr, errCh) select { case <-ctx.Done(): - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - g.logger.Info("shutting down http server") - _ = srv.Shutdown(shutdownCtx) - g.closeAll() + wsg.deafen(srv) return nil case err := <-errCh: return err } } -func (g *WebsocketGateway) handleWS(w http.ResponseWriter, r *http.Request) { - if !g.validateApiKey(r) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - g.logger.Warn("ws auth failed", "remote", r.RemoteAddr) - return - } +// - conn, err := g.upgrader.Upgrade(w, r, nil) +func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) { + conn, err := wsg.validateAndUpgradeConnection(w, r) if err != nil { - g.logger.Error("ws upgrade error", err, "remote", r.RemoteAddr) return } - // Configure read limits & pong handler - if g.bodySizeBytes > 0 { - conn.SetReadLimit(g.bodySizeBytes) + if wsg.bodySizeBytes > 0 { + conn.SetReadLimit(wsg.bodySizeBytes) } else { conn.SetReadLimit(1 << 20) // sensible default 1MiB } @@ -95,72 +87,93 @@ func (g *WebsocketGateway) handleWS(w http.ResponseWriter, r *http.Request) { return nil }) - // First message must be a handshake identifying the connection type typ, data, err := conn.ReadMessage() if err != nil { - g.logger.Error("failed to read handshake", "err", err, "remote", conn.RemoteAddr().String()) - _ = conn.Close() + wsg.sendWebsocketError(conn, "Internal Server Error", 500) + wsg.logger.Error("Failed to read handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } if typ != websocket.TextMessage && typ != websocket.BinaryMessage { - g.logger.Warn("invalid handshake message type", "remote", conn.RemoteAddr().String()) - _ = writeJSONSafe(conn, map[string]string{"error": "first message must be handshake"}) - _ = conn.Close() + wsg.sendWebsocketError(conn, "First message must be a handshake.", 400) + wsg.logger.Warn("Invalid handshake message type.", "remote", conn.RemoteAddr().String()) return } - var envelope MessageEnvelope - if err := json.Unmarshal(data, &envelope); err != nil { - g.logger.Warn("invalid handshake json", "err", err, "remote", conn.RemoteAddr().String()) - _ = writeJSONSafe(conn, map[string]string{"error": "invalid handshake: " + err.Error()}) - _ = conn.Close() + var handshake Handshake + if err := json.Unmarshal(data, &handshake); err != nil { + wsg.sendWebsocketError(conn, "Malformed handshake.", 400) + wsg.logger.Warn("Malformed handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta := connMetadata{connType: envelope.Type} + meta := connectionMetaData{connectionType: handshake.Type} - // Validate handshake based on type - switch envelope.Type { + switch handshake.Type { case "mod": - var hs ModHandshake - if err := json.Unmarshal(envelope.Data, &hs); err != nil { - g.logger.Warn("invalid mod handshake", "err", err, "remote", conn.RemoteAddr().String()) - _ = writeJSONSafe(conn, map[string]string{"error": "invalid mod handshake: " + err.Error()}) - _ = conn.Close() + var mhs ModHandshake + + if err := json.Unmarshal(handshake.Data, &mhs); err != nil { + wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400) + wsg.logger.Warn("Malformed mod handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta.id = hs.ServerID - g.registerConn(conn, meta) - g.logger.Info("mod connected", "server_id", hs.ServerID, "remote", conn.RemoteAddr().String()) - go g.modReadLoop(conn, meta) + + meta.id = mhs.ServerID + + if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { + return + } + + wsg.registerConn(conn, meta) + wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) + + go wsg.modReadLoop(conn, meta) // replace with external handler mayhaps case "bot": - var hs BotHandshake - if err := json.Unmarshal(envelope.Data, &hs); err != nil { - g.logger.Warn("invalid bot handshake", "err", err, "remote", conn.RemoteAddr().String()) - _ = writeJSONSafe(conn, map[string]string{"error": "invalid bot handshake: " + err.Error()}) - _ = conn.Close() + var bhs BotHandshake + + if err := json.Unmarshal(handshake.Data, &bhs); err != nil { + wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400) + wsg.logger.Warn("Malformed bot handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta.id = hs.BotID - g.registerConn(conn, meta) - g.logger.Info("bot connected", "bot_id", hs.BotID, "remote", conn.RemoteAddr().String()) - go g.botReadLoop(conn, meta) + + meta.id = bhs.BotID + + if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { + return + } + + wsg.registerConn(conn, meta) + wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) + + go wsg.botReadLoop(conn, meta) // replace with external handler mayhaps default: - g.logger.Warn("unknown connection type", "type", envelope.Type, "remote", conn.RemoteAddr().String()) - _ = writeJSONSafe(conn, map[string]string{"error": "unknown connection type: " + envelope.Type}) - _ = conn.Close() + wsg.sendWebsocketError(conn, "Unknown handshake.", 400) + wsg.logger.Warn("Unknown connection type.", "remote", conn.RemoteAddr().String(), "type", handshake.Type) return } } -func (g *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connMetadata) { +func (wsg *WebsocketGateway) handleReady(w http.ResponseWriter, r *http.Request) {} + +func (wsg *WebsocketGateway) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "healthy"}) +} + +func (wsg *WebsocketGateway) handleRegister(w http.ResponseWriter, r *http.Request) {} + +// + +func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaData) { defer func() { - g.unregisterConn(c) + wsg.unregisterConn(c) _ = c.Close() - g.logger.Info("mod disconnected", "server_id", meta.id, "remote", c.RemoteAddr().String()) + wsg.logger.Info("mod disconnected", "server_id", meta.id, "remote", c.RemoteAddr().String()) }() pingTicker := time.NewTicker(30 * time.Second) @@ -170,9 +183,9 @@ func (g *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connMetadata) { typ, data, err := c.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - g.logger.Warn("unexpected mod close", "server_id", meta.id, "err", err) + wsg.logger.Warn("unexpected mod close", "server_id", meta.id, "err", err) } else { - g.logger.Debug("mod read error", "server_id", meta.id, "err", err) + wsg.logger.Debug("mod read error", "server_id", meta.id, "err", err) } return } @@ -184,21 +197,21 @@ func (g *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connMetadata) { var msg GatewayModMessageIn if err := json.Unmarshal(data, &msg); err != nil { _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - g.logger.Warn("invalid json from mod", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + wsg.logger.Warn("invalid json from mod", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) continue } msg.ReceivedAt = time.Now().UTC() if err := msg.Validate(); err != nil { _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - g.logger.Warn("mod message validation failed", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + wsg.logger.Warn("mod message validation failed", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) continue } // Handle the message (forward to bot, enrich, etc.) - if err := g.modHandler.Handle(context.Background(), msg); err != nil { + if err := wsg.modHandler.Handle(context.Background(), msg); err != nil { _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) - g.logger.Error("mod handler error", "server_id", meta.id, "err", err) + wsg.logger.Error("mod handler error", "server_id", meta.id, "err", err) continue } @@ -209,7 +222,7 @@ func (g *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connMetadata) { case <-pingTicker.C: _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - g.logger.Debug("write ping failed", "server_id", meta.id, "err", err) + wsg.logger.Debug("write ping failed", "server_id", meta.id, "err", err) return } default: @@ -217,11 +230,11 @@ func (g *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connMetadata) { } } -func (g *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connMetadata) { +func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaData) { defer func() { - g.unregisterConn(c) + wsg.unregisterConn(c) _ = c.Close() - g.logger.Info("bot disconnected", "bot_id", meta.id, "remote", c.RemoteAddr().String()) + wsg.logger.Info("bot disconnected", "bot_id", meta.id, "remote", c.RemoteAddr().String()) }() pingTicker := time.NewTicker(30 * time.Second) @@ -231,9 +244,9 @@ func (g *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connMetadata) { typ, data, err := c.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - g.logger.Warn("unexpected bot close", "bot_id", meta.id, "err", err) + wsg.logger.Warn("unexpected bot close", "bot_id", meta.id, "err", err) } else { - g.logger.Debug("bot read error", "bot_id", meta.id, "err", err) + wsg.logger.Debug("bot read error", "bot_id", meta.id, "err", err) } return } @@ -245,21 +258,21 @@ func (g *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connMetadata) { var msg GatewayBotMessageIn if err := json.Unmarshal(data, &msg); err != nil { _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - g.logger.Warn("invalid json from bot", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + wsg.logger.Warn("invalid json from bot", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) continue } msg.ReceivedAt = time.Now().UTC() if err := msg.Validate(); err != nil { _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - g.logger.Warn("bot message validation failed", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + wsg.logger.Warn("bot message validation failed", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) continue } // Handle the message (forward to mod, enrich, etc.) - if err := g.botHandler.Handle(context.Background(), msg); err != nil { + if err := wsg.botHandler.Handle(context.Background(), msg); err != nil { _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) - g.logger.Error("bot handler error", "bot_id", meta.id, "err", err) + wsg.logger.Error("bot handler error", "bot_id", meta.id, "err", err) continue } @@ -270,7 +283,7 @@ func (g *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connMetadata) { case <-pingTicker.C: _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - g.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) + wsg.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) return } default: From 925dc319e8a822e735a9ee5ade3ad654d2d84596 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 19:04:35 +0100 Subject: [PATCH 06/15] working Gateway for incoming Mod/Bot ws-conn's; simulated Mod client via sim.go --- sim.go | 314 ++++++++++++++++++++++++------------------------ ws/handlers.go | 106 ++++++++++++++++ ws/temp.go | 70 +++++++++++ ws/validate.go | 6 +- ws/websocket.go | 146 ++++------------------ 5 files changed, 359 insertions(+), 283 deletions(-) create mode 100644 ws/temp.go diff --git a/sim.go b/sim.go index a8257e6..2309c1a 100644 --- a/sim.go +++ b/sim.go @@ -2,15 +2,23 @@ package main import ( "encoding/json" - "fmt" "log" "net/url" + "os" + "os/signal" + "syscall" "time" "github.com/gorilla/websocket" ) -type MessageEnvelope struct { +const ( + gatewayURL = "ws://localhost:3333/push" + apiKey = "gateway" + serverID = "test-server-001" +) + +type Handshake struct { Type string `json:"type"` Data json.RawMessage `json:"data"` } @@ -19,173 +27,163 @@ type ModHandshake struct { ServerID string `json:"server_id"` } -type BotHandshake struct { - BotID string `json:"bot_id"` +type GatewayAck struct { + Status string `json:"status"` + Type string `json:"type"` } -type User struct { +type MinecraftUser struct { ID string `json:"id"` Name string `json:"name"` } -type ModMessage struct { - MsgID string `json:"msg_id"` - Server string `json:"server"` - User User `json:"user"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` +type Destination struct { + ChannelID string `json:"channel_id"` } -type BotMessage struct { - MsgID string `json:"msg_id"` - ChannelID string `json:"channel_id"` - Author string `json:"author"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` -} - -func simulateMod() { - u := url.URL{Scheme: "ws", Host: "localhost:3333", Path: "/ws"} - u.RawQuery = "api_key=gateway" - - conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) - if err != nil { - log.Fatalf("mod dial error: %v", err) - } - defer conn.Close() - - fmt.Println("[MOD] Connected") - - // Send handshake - handshake := ModHandshake{ServerID: "survival_server"} - hsData, _ := json.Marshal(handshake) - envelope := MessageEnvelope{ - Type: "mod", - Data: hsData, - } - envData, _ := json.Marshal(envelope) - - if err := conn.WriteMessage(websocket.TextMessage, envData); err != nil { - log.Fatalf("mod handshake write error: %v", err) - } - fmt.Println("[MOD] Sent handshake") - - // Read handshake response - _, resp, err := conn.ReadMessage() - if err != nil { - log.Fatalf("mod handshake response error: %v", err) - } - fmt.Printf("[MOD] Handshake response: %s\n", string(resp)) - - // Send a few messages - for i := 1; i <= 3; i++ { - msg := ModMessage{ - MsgID: fmt.Sprintf("mod_msg_%d", i), - Server: "survival_server", - User: User{ID: "player_123", Name: "Steve"}, - Content: fmt.Sprintf("Message %d from Minecraft!", i), - Meta: map[string]interface{}{ - "coordinates": map[string]int{"x": 100 + i*10, "y": 64, "z": 200}, - }, - Ts: time.Now().UTC().Format(time.RFC3339), - } - - msgData, _ := json.Marshal(msg) - if err := conn.WriteMessage(websocket.TextMessage, msgData); err != nil { - log.Fatalf("mod message write error: %v", err) - } - fmt.Printf("[MOD] Sent message: %s\n", msg.Content) - - // Read response - _, resp, err := conn.ReadMessage() - if err != nil { - log.Fatalf("mod response error: %v", err) - } - fmt.Printf("[MOD] Response: %s\n", string(resp)) - - time.Sleep(1 * time.Second) - } - - fmt.Println("[MOD] Closing connection") -} - -func simulateBot() { - time.Sleep(2 * time.Second) // Let mod connect first - - u := url.URL{Scheme: "ws", Host: "localhost:3333", Path: "/ws"} - u.RawQuery = "api_key=gateway" - - conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) - if err != nil { - log.Fatalf("bot dial error: %v", err) - } - defer conn.Close() - - fmt.Println("[BOT] Connected") - - // Send handshake - handshake := BotHandshake{BotID: "discord_bot_1"} - hsData, _ := json.Marshal(handshake) - envelope := MessageEnvelope{ - Type: "bot", - Data: hsData, - } - envData, _ := json.Marshal(envelope) - - if err := conn.WriteMessage(websocket.TextMessage, envData); err != nil { - log.Fatalf("bot handshake write error: %v", err) - } - fmt.Println("[BOT] Sent handshake") - - // Read handshake response - _, resp, err := conn.ReadMessage() - if err != nil { - log.Fatalf("bot handshake response error: %v", err) - } - fmt.Printf("[BOT] Handshake response: %s\n", string(resp)) - - // Send a few messages - for i := 1; i <= 3; i++ { - msg := BotMessage{ - MsgID: fmt.Sprintf("bot_msg_%d", i), - ChannelID: "987654321", - Author: "DiscordUser#1234", - Content: fmt.Sprintf("Message %d from Discord!", i), - Meta: map[string]interface{}{ - "reactions": []string{"👍", "❤️"}, - }, - Ts: time.Now().UTC().Format(time.RFC3339), - } - - msgData, _ := json.Marshal(msg) - if err := conn.WriteMessage(websocket.TextMessage, msgData); err != nil { - log.Fatalf("bot message write error: %v", err) - } - fmt.Printf("[BOT] Sent message: %s\n", msg.Content) - - // Read response - _, resp, err := conn.ReadMessage() - if err != nil { - log.Fatalf("bot response error: %v", err) - } - fmt.Printf("[BOT] Response: %s\n", string(resp)) - - time.Sleep(1 * time.Second) - } - - fmt.Println("[BOT] Closing connection") +type GatewayModMessageIn struct { + MsgID string `json:"msg_id"` + Server string `json:"server"` + Destination Destination `json:"destination"` + Author MinecraftUser `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts string `json:"ts,omitempty"` } func main() { - fmt.Println("Starting WebSocket client simulator...") - fmt.Println("Connecting to ws://localhost:3333/ws with api_key=gateway") + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) - go simulateMod() - go simulateBot() + // Build WebSocket URL with API key + u, err := url.Parse(gatewayURL) + if err != nil { + log.Fatalf("Failed to parse URL: %v", err) + } + q := u.Query() + q.Set("api_key", apiKey) + u.RawQuery = q.Encode() - // Let them run - time.Sleep(15 * time.Second) - fmt.Println("Simulator finished") + log.Printf("Connecting to %s", u.String()) + + // Connect to WebSocket + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + log.Println("Connected to gateway") + + // Set up ping handler - respond to pings from server + conn.SetPingHandler(func(appData string) error { + log.Println("Received ping from server, sending pong") + err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) + if err != nil { + log.Printf("Failed to send pong: %v", err) + return err + } + return nil + }) + + // Send handshake + modHS := ModHandshake{ServerID: serverID} + modHSData, err := json.Marshal(modHS) + if err != nil { + log.Fatalf("Failed to marshal mod handshake: %v", err) + } + + handshake := Handshake{ + Type: "mod", + Data: modHSData, + } + + if err := conn.WriteJSON(handshake); err != nil { + log.Fatalf("Failed to send handshake: %v", err) + } + + log.Println("Handshake sent") + + // Read acknowledgment + var ack GatewayAck + if err := conn.ReadJSON(&ack); err != nil { + log.Fatalf("Failed to read acknowledgment: %v", err) + } + + log.Printf("Received acknowledgment: status=%s, type=%s", ack.Status, ack.Type) + + // Channel for incoming messages + done := make(chan struct{}) + + // Read loop - handles incoming messages and processes control frames + go func() { + defer close(done) + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + log.Printf("WebSocket error: %v", err) + } else { + log.Printf("Connection closed: %v", err) + } + return + } + + // Only log text/binary messages (ping/pong handled by handlers) + if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { + log.Printf("Received from server: %s", string(message)) + } + } + }() + + // Optional: Send a test message after connecting + time.Sleep(2 * time.Second) + testMsg := GatewayModMessageIn{ + MsgID: "test-msg-001", + Server: serverID, + Destination: Destination{ + ChannelID: "123456789", + }, + Author: MinecraftUser{ + ID: "player-uuid-123", + Name: "TestPlayer", + }, + Content: "Hello from simulated mod!", + Ts: time.Now().UTC().Format(time.RFC3339), + } + + if err := conn.WriteJSON(testMsg); err != nil { + log.Printf("Failed to send test message: %v", err) + } else { + log.Println("Sent test message to gateway") + } + + log.Println("Connection established. Responding to pings. Press Ctrl+C to disconnect.") + + // Wait for interrupt or connection close + for { + select { + case <-done: + log.Println("Connection closed") + return + case <-interrupt: + log.Println("Interrupt received, closing connection...") + + // Send close message + err := conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) + if err != nil { + log.Printf("Write close error: %v", err) + return + } + + select { + case <-done: + case <-time.After(time.Second): + } + return + } + } } diff --git a/ws/handlers.go b/ws/handlers.go index 9859295..c387568 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -1 +1,107 @@ package ws + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) { + conn, err := wsg.validateAndUpgradeConnection(w, r) + if err != nil { + return + } + + if wsg.bodySizeBytes > 0 { + conn.SetReadLimit(wsg.bodySizeBytes) + } else { + conn.SetReadLimit(1 << 20) // sensible default 1MiB + } + + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + typ, data, err := conn.ReadMessage() + if err != nil { + wsg.sendWebsocketError(conn, "Internal Server Error", 500) + wsg.logger.Error("Failed to read handshake.", "remote", conn.RemoteAddr().String(), "err", err) + return + } + + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { + wsg.sendWebsocketError(conn, "First message must be a handshake.", 400) + wsg.logger.Warn("Invalid handshake message type.", "remote", conn.RemoteAddr().String()) + return + } + + var handshake Handshake + if err := json.Unmarshal(data, &handshake); err != nil { + wsg.sendWebsocketError(conn, "Malformed handshake.", 400) + wsg.logger.Warn("Malformed handshake.", "remote", conn.RemoteAddr().String(), "err", err) + return + } + + meta := connectionMetaData{connectionType: handshake.Type} + + switch handshake.Type { + case "mod": + var mhs ModHandshake + + if err := json.Unmarshal(handshake.Data, &mhs); err != nil { + wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400) + wsg.logger.Warn("Malformed mod handshake.", "remote", conn.RemoteAddr().String(), "err", err) + return + } + + meta.id = mhs.ServerID + + if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { + return + } + + wsg.registerConn(conn, meta) + wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) + + go wsg.modReadLoop(conn, meta) // replace with external handler mayhaps + + case "bot": + var bhs BotHandshake + + if err := json.Unmarshal(handshake.Data, &bhs); err != nil { + wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400) + wsg.logger.Warn("Malformed bot handshake.", "remote", conn.RemoteAddr().String(), "err", err) + return + } + + meta.id = bhs.BotID + + if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { + return + } + + wsg.registerConn(conn, meta) + wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) + + go wsg.botReadLoop(conn, meta) // replace with external handler mayhaps + + default: + wsg.sendWebsocketError(conn, "Unknown handshake.", 400) + wsg.logger.Warn("Unknown connection type.", "remote", conn.RemoteAddr().String(), "type", handshake.Type) + return + } +} + +func (wsg *WebsocketGateway) handleReady(w http.ResponseWriter, r *http.Request) {} + +func (wsg *WebsocketGateway) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "healthy"}) +} + +func (wsg *WebsocketGateway) handleRegister(w http.ResponseWriter, r *http.Request) {} diff --git a/ws/temp.go b/ws/temp.go new file mode 100644 index 0000000..d40f5c4 --- /dev/null +++ b/ws/temp.go @@ -0,0 +1,70 @@ +package ws + +import ( + "context" + "encoding/json" + "log/slog" + "time" +) + +type LoggingModHandler struct { + logger *slog.Logger +} + +type LoggingBotHandler struct { + logger *slog.Logger +} + +func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { + return &LoggingModHandler{logger: logger} +} + +func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { + return &LoggingBotHandler{logger: logger} +} + +func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) error { + // For now, just log and pretend it's being forwarded + // TODO: Look up channel_id from database using server + // TODO: Forward to bot connection(s) + + fwd := GatewayModMessageOut{ + Type: "mod", + ChannelID: "TODO", // will come from database lookup + Author: msg.Author, + Content: msg.Content, + Meta: msg.Meta, + Ts: msg.Ts, + ReceivedAt: msg.ReceivedAt, + ForwardedAt: time.Now().UTC(), + } + + b, _ := json.Marshal(fwd) + h.logger.Info("received mod message", "msg_id", msg.MsgID, "server", msg.Server, "Author", msg.Author.Name, "content", msg.Content) + h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) + + return nil +} + +func (h *LoggingBotHandler) Handle(ctx context.Context, msg GatewayBotMessageIn) error { + // For now, just log and pretend it's being forwarded + // TODO: Look up server_id from database using channel_id + // TODO: Forward to mod connection(s) + + fwd := GatewayBotMessageOut{ + Type: "bot", + ChannelID: msg.ChannelID, + Author: msg.Author, + Content: msg.Content, + Meta: msg.Meta, + Ts: msg.Ts, + ReceivedAt: msg.ReceivedAt, + ForwardedAt: time.Now().UTC(), + } + + b, _ := json.Marshal(fwd) + h.logger.Info("received bot message", "msg_id", msg.MsgID, "channel", msg.ChannelID, "author", msg.Author, "content", msg.Content) + h.logger.Debug("forwarding bot message", "msg_id", msg.MsgID, "channel", msg.ChannelID, "payload", string(b)) + + return nil +} diff --git a/ws/validate.go b/ws/validate.go index f8feff6..6e494b3 100644 --- a/ws/validate.go +++ b/ws/validate.go @@ -12,8 +12,8 @@ func (m *GatewayModMessageIn) Validate() error { if strings.TrimSpace(m.Server) == "" { return errors.New("server missing") } - if strings.TrimSpace(m.User.ID) == "" { - return errors.New("user.id missing") + if strings.TrimSpace(m.Author.ID) == "" { + return errors.New("author.id missing") } if strings.TrimSpace(m.Content) == "" { return errors.New("content missing") @@ -28,7 +28,7 @@ func (m *GatewayBotMessageIn) Validate() error { if strings.TrimSpace(m.ChannelID) == "" { return errors.New("channel_id missing") } - if strings.TrimSpace(m.Author) == "" { + if strings.TrimSpace(m.Author.ID) == "" { return errors.New("author missing") } if strings.TrimSpace(m.Content) == "" { diff --git a/ws/websocket.go b/ws/websocket.go index 8c91a20..4422c4b 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -69,106 +69,6 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // -func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) { - conn, err := wsg.validateAndUpgradeConnection(w, r) - if err != nil { - return - } - - if wsg.bodySizeBytes > 0 { - conn.SetReadLimit(wsg.bodySizeBytes) - } else { - conn.SetReadLimit(1 << 20) // sensible default 1MiB - } - - _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - conn.SetPongHandler(func(appData string) error { - _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - return nil - }) - - typ, data, err := conn.ReadMessage() - if err != nil { - wsg.sendWebsocketError(conn, "Internal Server Error", 500) - wsg.logger.Error("Failed to read handshake.", "remote", conn.RemoteAddr().String(), "err", err) - return - } - - if typ != websocket.TextMessage && typ != websocket.BinaryMessage { - wsg.sendWebsocketError(conn, "First message must be a handshake.", 400) - wsg.logger.Warn("Invalid handshake message type.", "remote", conn.RemoteAddr().String()) - return - } - - var handshake Handshake - if err := json.Unmarshal(data, &handshake); err != nil { - wsg.sendWebsocketError(conn, "Malformed handshake.", 400) - wsg.logger.Warn("Malformed handshake.", "remote", conn.RemoteAddr().String(), "err", err) - return - } - - meta := connectionMetaData{connectionType: handshake.Type} - - switch handshake.Type { - case "mod": - var mhs ModHandshake - - if err := json.Unmarshal(handshake.Data, &mhs); err != nil { - wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400) - wsg.logger.Warn("Malformed mod handshake.", "remote", conn.RemoteAddr().String(), "err", err) - return - } - - meta.id = mhs.ServerID - - if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { - return - } - - wsg.registerConn(conn, meta) - wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) - - go wsg.modReadLoop(conn, meta) // replace with external handler mayhaps - - case "bot": - var bhs BotHandshake - - if err := json.Unmarshal(handshake.Data, &bhs); err != nil { - wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400) - wsg.logger.Warn("Malformed bot handshake.", "remote", conn.RemoteAddr().String(), "err", err) - return - } - - meta.id = bhs.BotID - - if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { - return - } - - wsg.registerConn(conn, meta) - wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) - - go wsg.botReadLoop(conn, meta) // replace with external handler mayhaps - - default: - wsg.sendWebsocketError(conn, "Unknown handshake.", 400) - wsg.logger.Warn("Unknown connection type.", "remote", conn.RemoteAddr().String(), "type", handshake.Type) - return - } -} - -func (wsg *WebsocketGateway) handleReady(w http.ResponseWriter, r *http.Request) {} - -func (wsg *WebsocketGateway) handleHealth(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(200) - _ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "healthy"}) -} - -func (wsg *WebsocketGateway) handleRegister(w http.ResponseWriter, r *http.Request) {} - -// - func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaData) { defer func() { wsg.unregisterConn(c) @@ -179,6 +79,18 @@ func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaD pingTicker := time.NewTicker(30 * time.Second) defer pingTicker.Stop() + // Send pings in a separate goroutine + go func() { + for range pingTicker.C { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { + wsg.logger.Debug("write ping failed", "server_id", meta.id, "err", err) + return + } + wsg.logger.Debug("sent ping to mod", "server_id", meta.id) + } + }() + for { typ, data, err := c.ReadMessage() if err != nil { @@ -216,17 +128,6 @@ func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaD } _ = writeJSONSafe(c, map[string]string{"status": "ok"}) - - // Handle pings - select { - case <-pingTicker.C: - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "server_id", meta.id, "err", err) - return - } - default: - } } } @@ -240,6 +141,18 @@ func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaD pingTicker := time.NewTicker(30 * time.Second) defer pingTicker.Stop() + // Send pings in a separate goroutine + go func() { + for range pingTicker.C { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { + wsg.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) + return + } + wsg.logger.Debug("sent ping to bot", "bot_id", meta.id) + } + }() + for { typ, data, err := c.ReadMessage() if err != nil { @@ -277,16 +190,5 @@ func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaD } _ = writeJSONSafe(c, map[string]string{"status": "ok"}) - - // Handle pings - select { - case <-pingTicker.C: - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) - return - } - default: - } } } From 2372da942a05200195e85fe23952d8ccd2bd630b Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 21:12:46 +0100 Subject: [PATCH 07/15] updated cache --- util/cache/cache.go | 99 ++++++++++++++++++++++++++----------------- util/cache/structs.go | 14 +++--- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/util/cache/cache.go b/util/cache/cache.go index 9af7ca3..57d682c 100644 --- a/util/cache/cache.go +++ b/util/cache/cache.go @@ -1,70 +1,91 @@ package cache -import "sync" - -type Cache struct { - mu sync.RWMutex - s2c map[string]string - c2s map[string]string -} - func NewCache() *Cache { - return &Cache{ + c := &Cache{} + c.state.Store(&mappings{ s2c: make(map[string]string), c2s: make(map[string]string), - } + }) + + return c } -// Set creates or overwrites the pair a -> b and b -> a. -// It ensures any previous mappings involving a or b are removed first. func (c *Cache) Set(serverId, channelId string) { c.mu.Lock() defer c.mu.Unlock() - if old, ok := c.s2c[serverId]; ok && old != channelId { - delete(c.c2s, old) + current := c.state.Load() + next := current.clone() + + if oldCh, ok := next.s2c[serverId]; ok && oldCh != channelId { + delete(next.c2s, oldCh) + } + if oldSrv, ok := next.c2s[channelId]; ok && oldSrv != serverId { + delete(next.s2c, oldSrv) } - if old, ok := c.c2s[channelId]; ok && old != serverId { - delete(c.s2c, old) - } + next.s2c[serverId] = channelId + next.c2s[channelId] = serverId - c.s2c[serverId] = channelId - c.c2s[channelId] = serverId + c.state.Store(next) } func (c *Cache) GetByServerId(serverId string) (string, bool) { - c.mu.RLock() - defer c.mu.RUnlock() - - cId, ok := c.s2c[serverId] - return cId, ok + m := c.state.Load() + val, ok := m.s2c[serverId] + return val, ok } func (c *Cache) GetByChannelId(channelId string) (string, bool) { - c.mu.RLock() - defer c.mu.RUnlock() - - sId, ok := c.c2s[channelId] - return sId, ok + m := c.state.Load() + val, ok := m.c2s[channelId] + return val, ok } func (c *Cache) RemoveByServerId(serverId string) { - c.mu.RLock() - defer c.mu.RUnlock() + c.mu.Lock() + defer c.mu.Unlock() - if channelId, ok := c.s2c[serverId]; ok { - delete(c.s2c, serverId) - delete(c.c2s, channelId) + current := c.state.Load() + if _, ok := current.s2c[serverId]; !ok { + return } + + next := current.clone() + if channelId, ok := next.s2c[serverId]; ok { + delete(next.s2c, serverId) + delete(next.c2s, channelId) + } + c.state.Store(next) } func (c *Cache) RemoveByChannelId(channelId string) { - c.mu.RLock() - defer c.mu.RUnlock() + c.mu.Lock() + defer c.mu.Unlock() - if serverId, ok := c.s2c[channelId]; ok { - delete(c.c2s, channelId) - delete(c.s2c, serverId) + current := c.state.Load() + if _, ok := current.c2s[channelId]; !ok { + return } + + next := current.clone() + if serverId, ok := next.c2s[channelId]; ok { + delete(next.c2s, channelId) + delete(next.s2c, serverId) + } + c.state.Store(next) +} + +func (m *mappings) clone() *mappings { + newM := &mappings{ + s2c: make(map[string]string, len(m.s2c)), + c2s: make(map[string]string, len(m.c2s)), + } + for k, v := range m.s2c { + newM.s2c[k] = v + } + for k, v := range m.c2s { + newM.c2s[k] = v + } + return newM } diff --git a/util/cache/structs.go b/util/cache/structs.go index b5ce12f..e484c32 100644 --- a/util/cache/structs.go +++ b/util/cache/structs.go @@ -2,15 +2,15 @@ package cache import ( "sync" + "sync/atomic" ) -type ShardedCache struct { - shards []*shard - shardMask uint32 +type Cache struct { + mu sync.Mutex + state atomic.Pointer[mappings] } -type shard struct { - mu sync.RWMutex - s2c map[string]string // serverId -> channelId - c2s map[string]string // channelId -> serverId +type mappings struct { + s2c map[string]string + c2s map[string]string } From 8f7db6256bc0a149e22885de6ce089d393a22748 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 1 Dec 2025 22:30:36 +0100 Subject: [PATCH 08/15] eod commit --- controller/controller.go | 2 +- main.go | 6 +++ util/cache/conn_cache.go | 88 +++++++++++++++++++++++++++++++++++ util/cache/structs.go | 21 +++++++++ ws/handlers.go | 7 +-- ws/structs.go | 18 +++----- ws/temp.go | 7 +-- ws/util.go | 26 +++++------ ws/websocket.go | 99 +++++++++++++++++++--------------------- 9 files changed, 190 insertions(+), 84 deletions(-) create mode 100644 util/cache/conn_cache.go diff --git a/controller/controller.go b/controller/controller.go index 285aa66..86873b4 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -16,7 +16,7 @@ func NewGatewayController(cfg config.Config) GatewayController { botHandler := ws.NewLoggingBotHandler(wsl) return GatewayController{ - Websocket: ws.NewWsGateway(cfg.Gateway, wsl, wCloseFn, modHandler, botHandler), + Websocket: ws.NewWebsocketGateway(cfg.Gateway, wsl, wCloseFn, modHandler, botHandler), HttpServer: HttpGateway{}, } } diff --git a/main.go b/main.go index fa5670f..f868478 100644 --- a/main.go +++ b/main.go @@ -19,3 +19,9 @@ func main() { panic(err) } } + +/** TODO + +- queue for messages, both ways (Ack "queued" instead of "completed"), filled queue drops oldest entry + +*/ diff --git a/util/cache/conn_cache.go b/util/cache/conn_cache.go new file mode 100644 index 0000000..a49a6fe --- /dev/null +++ b/util/cache/conn_cache.go @@ -0,0 +1,88 @@ +package cache + +import ( + "github.com/gorilla/websocket" +) + +func NewConnectionCache() *ConnectionCache { + c := &ConnectionCache{} + c.state.Store(&connectionMappings{ + id2c: make(map[string]*conn), + }) + + return c +} + +func (c *ConnectionCache) Set(id string, connection *websocket.Conn, meta *ConnectionMetaData) { + c.mu.Lock() + defer c.mu.Unlock() + + current := c.state.Load() + next := current.clone() + + next.id2c[id] = &conn{ + connection: connection, + meta: meta, + } + + c.state.Store(next) +} + +func (c *ConnectionCache) GetById(id string) (*websocket.Conn, *ConnectionMetaData, bool) { + m := c.state.Load() + if conn, ok := m.id2c[id]; ok { + return conn.connection, conn.meta, true + } + + return nil, nil, false +} + +func (c *ConnectionCache) RemoveById(id string) { + c.mu.Lock() + defer c.mu.Unlock() + + current := c.state.Load() + if _, ok := current.id2c[id]; !ok { + return + } + + next := current.clone() + delete(next.id2c, id) + + c.state.Store(next) +} + +func (c *ConnectionCache) Range(fn func(id string, connection *websocket.Conn, meta *ConnectionMetaData)) { + m := c.state.Load() + for id, conn := range m.id2c { + fn(id, conn.connection, conn.meta) + } +} + +func (c *ConnectionCache) GetAllConnections() []*websocket.Conn { + m := c.state.Load() + conns := make([]*websocket.Conn, 0, len(m.id2c)) + for _, conn := range m.id2c { + conns = append(conns, conn.connection) + } + return conns +} + +func (c *ConnectionCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.state.Store(&connectionMappings{ + id2c: make(map[string]*conn), + }) +} + +func (cm *connectionMappings) clone() *connectionMappings { + newCM := &connectionMappings{ + id2c: make(map[string]*conn, len(cm.id2c)), + } + for k, v := range cm.id2c { + newCM.id2c[k] = v + } + return newCM +} diff --git a/util/cache/structs.go b/util/cache/structs.go index e484c32..37aac37 100644 --- a/util/cache/structs.go +++ b/util/cache/structs.go @@ -3,6 +3,8 @@ package cache import ( "sync" "sync/atomic" + + "github.com/gorilla/websocket" ) type Cache struct { @@ -10,7 +12,26 @@ type Cache struct { state atomic.Pointer[mappings] } +type ConnectionCache struct { + mu sync.Mutex + state atomic.Pointer[connectionMappings] +} + type mappings struct { s2c map[string]string c2s map[string]string } + +type connectionMappings struct { + id2c map[string]*conn +} + +type ConnectionMetaData struct { + ConnectionType string // "mod" or "bot" + ID string // server_id or bot_id for logging +} + +type conn struct { + connection *websocket.Conn + meta *ConnectionMetaData +} diff --git a/ws/handlers.go b/ws/handlers.go index c387568..5b0c696 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -2,6 +2,7 @@ package ws import ( "encoding/json" + "homestead/homestead_gateway/util/cache" "net/http" "time" @@ -46,7 +47,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - meta := connectionMetaData{connectionType: handshake.Type} + meta := cache.ConnectionMetaData{ConnectionType: handshake.Type} switch handshake.Type { case "mod": @@ -58,7 +59,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - meta.id = mhs.ServerID + meta.ID = mhs.ServerID if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { return @@ -78,7 +79,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - meta.id = bhs.BotID + meta.ID = bhs.BotID if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { return diff --git a/ws/structs.go b/ws/structs.go index bb70593..1bbe67e 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -1,11 +1,9 @@ package ws import ( - "context" "encoding/json" "homestead/homestead_gateway/util/cache" "log/slog" - "sync" "time" "github.com/gorilla/websocket" @@ -17,9 +15,10 @@ type WebsocketGateway struct { bodySizeBytes int64 upgrader websocket.Upgrader - connsMu sync.Mutex - conns map[*websocket.Conn]connectionMetaData - cache cache.Cache + + cache *cache.Cache + botConn *websocket.Conn + conns *cache.ConnectionCache modHandler ModHandler botHandler BotHandler @@ -108,14 +107,9 @@ type BotHandshake struct { } type ModHandler interface { - Handle(ctx context.Context, msg GatewayModMessageIn) error + Handle(conn *websocket.Conn, msg GatewayModMessageIn) error } type BotHandler interface { - Handle(ctx context.Context, msg GatewayBotMessageIn) error -} - -type connectionMetaData struct { - connectionType string // "mod" or "bot" - id string // server_id or bot_id for logging + Handle(conn *websocket.Conn, msg GatewayBotMessageIn) error } diff --git a/ws/temp.go b/ws/temp.go index d40f5c4..c62057b 100644 --- a/ws/temp.go +++ b/ws/temp.go @@ -1,10 +1,11 @@ package ws import ( - "context" "encoding/json" "log/slog" "time" + + "github.com/gorilla/websocket" ) type LoggingModHandler struct { @@ -23,7 +24,7 @@ func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { return &LoggingBotHandler{logger: logger} } -func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) error { +func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayModMessageIn) error { // For now, just log and pretend it's being forwarded // TODO: Look up channel_id from database using server // TODO: Forward to bot connection(s) @@ -46,7 +47,7 @@ func (h *LoggingModHandler) Handle(ctx context.Context, msg GatewayModMessageIn) return nil } -func (h *LoggingBotHandler) Handle(ctx context.Context, msg GatewayBotMessageIn) error { +func (h *LoggingBotHandler) Handle(conn *websocket.Conn, msg GatewayBotMessageIn) error { // For now, just log and pretend it's being forwarded // TODO: Look up server_id from database using channel_id // TODO: Forward to mod connection(s) diff --git a/ws/util.go b/ws/util.go index 4822779..f695de3 100644 --- a/ws/util.go +++ b/ws/util.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "homestead/homestead_gateway/util/cache" "log/slog" "net/http" "time" @@ -41,6 +42,11 @@ func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string _ = json.NewEncoder(w).Encode(map[string]interface{}{"message": message, "code": code}) } +func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) { + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _ = conn.WriteMessage(websocket.PingMessage, nil) +} + func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int) { _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) _ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code}) @@ -105,26 +111,20 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { // connections -func (wsg *WebsocketGateway) registerConn(c *websocket.Conn, meta connectionMetaData) { - wsg.connsMu.Lock() - wsg.conns[c] = meta - wsg.connsMu.Unlock() +func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData) { + wsg.conns.Set(meta.ID, conn, &meta) } -func (wsg *WebsocketGateway) unregisterConn(c *websocket.Conn) { - wsg.connsMu.Lock() - delete(wsg.conns, c) - wsg.connsMu.Unlock() +func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData) { + wsg.conns.RemoveById(meta.ID) + _ = conn.Close() } func (wsg *WebsocketGateway) closeAll() { - wsg.connsMu.Lock() - defer wsg.connsMu.Unlock() - wsg.logger.Info("Closing all websocket connections.") - for c := range wsg.conns { + wsg.conns.Range(func(id string, c *websocket.Conn, meta *cache.ConnectionMetaData) { _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) _ = c.Close() - } + }) } diff --git a/ws/websocket.go b/ws/websocket.go index 4422c4b..a5ba800 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "homestead/homestead_gateway/util/cache" "homestead/homestead_gateway/util/config" "log/slog" "net" @@ -15,11 +16,16 @@ import ( "github.com/gorilla/websocket" ) -func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { +func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - apiKey: cfg.Websocket, + logger: logger, + closeFn: closefn, + modHandler: modH, + botHandler: botH, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + cache: cache.NewCache(), + conns: cache.NewConnectionCache(), upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -27,11 +33,7 @@ func NewWsGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() return true // local by default; change for production }, }, - conns: make(map[*websocket.Conn]connectionMetaData), bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, - port: cfg.HttpPort, - modHandler: modH, - botHandler: botH, } } @@ -69,35 +71,29 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // -func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaData) { +func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { defer func() { - wsg.unregisterConn(c) - _ = c.Close() - wsg.logger.Info("mod disconnected", "server_id", meta.id, "remote", c.RemoteAddr().String()) + wsg.unregisterConn(conn, meta) + wsg.logger.Info("Mod-Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) }() - pingTicker := time.NewTicker(30 * time.Second) - defer pingTicker.Stop() + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() - // Send pings in a separate goroutine go func() { - for range pingTicker.C { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "server_id", meta.id, "err", err) - return - } - wsg.logger.Debug("sent ping to mod", "server_id", meta.id) + for range ticker.C { + wsg.sendWebsocketPing(conn) } }() for { - typ, data, err := c.ReadMessage() + typ, data, err := conn.ReadMessage() + if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("unexpected mod close", "server_id", meta.id, "err", err) + wsg.logger.Warn("unexpected mod close", "server_id", meta.ID, "err", err) } else { - wsg.logger.Debug("mod read error", "server_id", meta.id, "err", err) + wsg.logger.Debug("mod read error", "server_id", meta.ID, "err", err) } return } @@ -108,34 +104,33 @@ func (wsg *WebsocketGateway) modReadLoop(c *websocket.Conn, meta connectionMetaD var msg GatewayModMessageIn if err := json.Unmarshal(data, &msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - wsg.logger.Warn("invalid json from mod", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) + wsg.logger.Warn("invalid json from mod", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } msg.ReceivedAt = time.Now().UTC() if err := msg.Validate(); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - wsg.logger.Warn("mod message validation failed", "server_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) + wsg.logger.Warn("mod message validation failed", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } // Handle the message (forward to bot, enrich, etc.) - if err := wsg.modHandler.Handle(context.Background(), msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) - wsg.logger.Error("mod handler error", "server_id", meta.id, "err", err) + if err := wsg.modHandler.Handle(conn, msg); err != nil { + _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) + wsg.logger.Error("mod handler error", "server_id", meta.ID, "err", err) continue } - _ = writeJSONSafe(c, map[string]string{"status": "ok"}) + _ = writeJSONSafe(conn, map[string]string{"status": "completed"}) // or "queued" } } -func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaData) { +func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { defer func() { - wsg.unregisterConn(c) - _ = c.Close() - wsg.logger.Info("bot disconnected", "bot_id", meta.id, "remote", c.RemoteAddr().String()) + wsg.unregisterConn(conn, meta) + wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) }() pingTicker := time.NewTicker(30 * time.Second) @@ -144,22 +139,22 @@ func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaD // Send pings in a separate goroutine go func() { for range pingTicker.C { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "bot_id", meta.id, "err", err) + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + wsg.logger.Debug("write ping failed", "bot_id", meta.ID, "err", err) return } - wsg.logger.Debug("sent ping to bot", "bot_id", meta.id) + wsg.logger.Debug("sent ping to bot", "bot_id", meta.ID) } }() for { - typ, data, err := c.ReadMessage() + typ, data, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("unexpected bot close", "bot_id", meta.id, "err", err) + wsg.logger.Warn("unexpected bot close", "bot_id", meta.ID, "err", err) } else { - wsg.logger.Debug("bot read error", "bot_id", meta.id, "err", err) + wsg.logger.Debug("bot read error", "bot_id", meta.ID, "err", err) } return } @@ -170,25 +165,25 @@ func (wsg *WebsocketGateway) botReadLoop(c *websocket.Conn, meta connectionMetaD var msg GatewayBotMessageIn if err := json.Unmarshal(data, &msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "invalid json: " + err.Error()}) - wsg.logger.Warn("invalid json from bot", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) + wsg.logger.Warn("invalid json from bot", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } msg.ReceivedAt = time.Now().UTC() if err := msg.Validate(); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": err.Error()}) - wsg.logger.Warn("bot message validation failed", "bot_id", meta.id, "remote", c.RemoteAddr().String(), "err", err) + _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) + wsg.logger.Warn("bot message validation failed", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) continue } // Handle the message (forward to mod, enrich, etc.) - if err := wsg.botHandler.Handle(context.Background(), msg); err != nil { - _ = writeJSONSafe(c, map[string]string{"error": "handler error: " + err.Error()}) - wsg.logger.Error("bot handler error", "bot_id", meta.id, "err", err) + if err := wsg.botHandler.Handle(conn, msg); err != nil { + _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) + wsg.logger.Error("bot handler error", "bot_id", meta.ID, "err", err) continue } - _ = writeJSONSafe(c, map[string]string{"status": "ok"}) + _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) } } From 7a30adc2e857db39e8dd1ebd8e870cc9e2bf7e64 Mon Sep 17 00:00:00 2001 From: Overlord Date: Tue, 2 Dec 2025 10:34:45 +0100 Subject: [PATCH 09/15] temp push --- main.go | 2 -- sim.go | 4 +++- ws/handlers.go | 14 +++++++------- ws/structs.go | 6 +++--- ws/temp.go | 7 +++++++ ws/util.go | 29 ++++++++++++++++++++++++++--- ws/websocket.go | 36 +++++++++++++++++++++++++++--------- 7 files changed, 73 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index f868478..1be8242 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,5 @@ func main() { } /** TODO - - queue for messages, both ways (Ack "queued" instead of "completed"), filled queue drops oldest entry - */ diff --git a/sim.go b/sim.go index 2309c1a..e534411 100644 --- a/sim.go +++ b/sim.go @@ -1,3 +1,5 @@ +//go:build sim + package main import ( @@ -13,7 +15,7 @@ import ( ) const ( - gatewayURL = "ws://localhost:3333/push" + gatewayURL = "ws://localhost:3333/sync" apiKey = "gateway" serverID = "test-server-001" ) diff --git a/ws/handlers.go b/ws/handlers.go index 5b0c696..838241f 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -9,7 +9,7 @@ import ( "github.com/gorilla/websocket" ) -func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) { +func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) { conn, err := wsg.validateAndUpgradeConnection(w, r) if err != nil { return @@ -65,7 +65,7 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) return } - wsg.registerConn(conn, meta) + wsg.registerConn(conn, meta, "mod") wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) go wsg.modReadLoop(conn, meta) // replace with external handler mayhaps @@ -81,11 +81,15 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) meta.ID = bhs.BotID + if ok := wsg.registerConn(conn, meta, "bot"); !ok { + wsg.sendWebsocketError(conn, "Bot already connected.", 409) + return + } + if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { return } - wsg.registerConn(conn, meta) wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) go wsg.botReadLoop(conn, meta) // replace with external handler mayhaps @@ -97,12 +101,8 @@ func (wsg *WebsocketGateway) handlePush(w http.ResponseWriter, r *http.Request) } } -func (wsg *WebsocketGateway) handleReady(w http.ResponseWriter, r *http.Request) {} - func (wsg *WebsocketGateway) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) _ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "healthy"}) } - -func (wsg *WebsocketGateway) handleRegister(w http.ResponseWriter, r *http.Request) {} diff --git a/ws/structs.go b/ws/structs.go index 1bbe67e..9cbd361 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -16,9 +16,9 @@ type WebsocketGateway struct { upgrader websocket.Upgrader - cache *cache.Cache - botConn *websocket.Conn - conns *cache.ConnectionCache + cache *cache.Cache + bot *websocket.Conn + conns *cache.ConnectionCache modHandler ModHandler botHandler BotHandler diff --git a/ws/temp.go b/ws/temp.go index c62057b..8a64b07 100644 --- a/ws/temp.go +++ b/ws/temp.go @@ -40,6 +40,13 @@ func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayModMessageIn ForwardedAt: time.Now().UTC(), } + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + + if err := conn.WriteJSON(fwd); err != nil { + _ = conn.Close() + return err + } + b, _ := json.Marshal(fwd) h.logger.Info("received mod message", "msg_id", msg.MsgID, "server", msg.Server, "Author", msg.Author.Name, "content", msg.Content) h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) diff --git a/ws/util.go b/ws/util.go index f695de3..376dda7 100644 --- a/ws/util.go +++ b/ws/util.go @@ -105,17 +105,33 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) - logger.Info("http request", "remote", r.RemoteAddr, "method", r.Method, "path", r.URL.Path, "duration", time.Since(start)) + logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start)) }) } // connections -func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData) { +func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) bool { + if typ == "bot" { + if wsg.bot != nil { + return false + } + + wsg.bot = conn + return true + } + wsg.conns.Set(meta.ID, conn, &meta) + return true } -func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData) { +func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) { + if typ == "bot" { + _ = wsg.bot.Close() + wsg.bot = nil + return + } + wsg.conns.RemoveById(meta.ID) _ = conn.Close() } @@ -123,8 +139,15 @@ func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.Con func (wsg *WebsocketGateway) closeAll() { wsg.logger.Info("Closing all websocket connections.") + if wsg.bot != nil { + _ = wsg.bot.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) + _ = wsg.bot.Close() + } + wsg.conns.Range(func(id string, c *websocket.Conn, meta *cache.ConnectionMetaData) { _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) _ = c.Close() }) + + wsg.conns.Clear() } diff --git a/ws/websocket.go b/ws/websocket.go index a5ba800..a9f1cd8 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -46,10 +46,8 @@ func (wsg *WebsocketGateway) Start() error { func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error { mux := http.NewServeMux() - mux.HandleFunc("/push", wsg.handlePush) - mux.HandleFunc("/ready", wsg.handleReady) + mux.HandleFunc("/sync", wsg.handleSync) mux.HandleFunc("/health", wsg.handleHealth) - mux.HandleFunc("/register", wsg.handleRegister) srv := &http.Server{ Addr: listenAddr, @@ -73,8 +71,8 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { defer func() { - wsg.unregisterConn(conn, meta) - wsg.logger.Info("Mod-Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) + wsg.unregisterConn(conn, meta, "mod") + wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) }() ticker := time.NewTicker(30 * time.Second) @@ -91,9 +89,7 @@ func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.Connec if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("unexpected mod close", "server_id", meta.ID, "err", err) - } else { - wsg.logger.Debug("mod read error", "server_id", meta.ID, "err", err) + wsg.logger.Warn("Mod-Client unexpectedly closed the connection.", "err", err) } return } @@ -129,7 +125,7 @@ func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.Connec func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { defer func() { - wsg.unregisterConn(conn, meta) + wsg.unregisterConn(conn, meta, "bot") wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) }() @@ -187,3 +183,25 @@ func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.Connec _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) } } + +func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) { + defer func() { + wsg.unregisterConn(conn, meta, typ) + wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String()) + }() + + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + go func() { + for range ticker.C { + wsg.sendWebsocketPing(conn) + } + }() + + switch typ { + case "mod": + + case "bot": + } +} From de03c1fe3d8bc821215254e20653d40cd2492c79 Mon Sep 17 00:00:00 2001 From: Overlord Date: Sun, 7 Dec 2025 00:48:09 +0100 Subject: [PATCH 10/15] almost done, queue fix next --- controller/controller.go | 5 +- main.go | 4 - sim.go | 33 ++-- util/queue/structs.go | 66 +++++++ ws/handlers.go | 18 +- ws/structs.go | 81 +++------ ws/temp.go | 121 +++++-------- ws/util.go | 6 +- ws/validate.go | 25 +-- ws/websocket.go | 374 +++++++++++++++++++++++++-------------- 10 files changed, 414 insertions(+), 319 deletions(-) create mode 100644 util/queue/structs.go diff --git a/controller/controller.go b/controller/controller.go index 86873b4..2584fe1 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -12,11 +12,8 @@ func NewGatewayController(cfg config.Config) GatewayController { panic(err) } - modHandler := ws.NewLoggingModHandler(wsl) - botHandler := ws.NewLoggingBotHandler(wsl) - return GatewayController{ - Websocket: ws.NewWebsocketGateway(cfg.Gateway, wsl, wCloseFn, modHandler, botHandler), + Websocket: ws.NewWebsocketGateway(cfg.Gateway, wsl, wCloseFn), HttpServer: HttpGateway{}, } } diff --git a/main.go b/main.go index 1be8242..fa5670f 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,3 @@ func main() { panic(err) } } - -/** TODO -- queue for messages, both ways (Ack "queued" instead of "completed"), filled queue drops oldest entry -*/ diff --git a/sim.go b/sim.go index e534411..6530720 100644 --- a/sim.go +++ b/sim.go @@ -34,23 +34,24 @@ type GatewayAck struct { Type string `json:"type"` } -type MinecraftUser struct { +type User struct { ID string `json:"id"` Name string `json:"name"` } type Destination struct { - ChannelID string `json:"channel_id"` + ID string `json:"channel_id,omitempty"` } -type GatewayModMessageIn struct { - MsgID string `json:"msg_id"` - Server string `json:"server"` - Destination Destination `json:"destination"` - Author MinecraftUser `json:"author"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` +type GatewayMessageIn struct { + ID string `json:"id"` // where am I from (channel_id or server_id) + MsgID string `json:"msg_id"` // msg id + Destination Destination `json:"destination,omitempty"` // where do I wanna go (channel_id or empty if from Bot) + Author User `json:"author"` // who sent the message + Content string `json:"content"` // message content + Meta map[string]interface{} `json:"meta,omitempty"` // additional metadata + Ts time.Time `json:"ts,omitempty"` // timestamp + ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) } func main() { @@ -140,18 +141,18 @@ func main() { // Optional: Send a test message after connecting time.Sleep(2 * time.Second) - testMsg := GatewayModMessageIn{ - MsgID: "test-msg-001", - Server: serverID, + testMsg := GatewayMessageIn{ + MsgID: "test-msg-001", + ID: serverID, Destination: Destination{ - ChannelID: "123456789", + ID: "123456789", }, - Author: MinecraftUser{ + Author: User{ ID: "player-uuid-123", Name: "TestPlayer", }, Content: "Hello from simulated mod!", - Ts: time.Now().UTC().Format(time.RFC3339), + Ts: time.Now().UTC(), } if err := conn.WriteJSON(testMsg); err != nil { diff --git a/util/queue/structs.go b/util/queue/structs.go new file mode 100644 index 0000000..0588606 --- /dev/null +++ b/util/queue/structs.go @@ -0,0 +1,66 @@ +package queue + +import ( + "sync" +) + +// Queue is a thin wrapper around a buffered channel. +type Queue[T any] struct { + ch chan T + closedMu sync.Mutex + closed bool + closedCh chan struct{} +} + +func NewQueue[T any](capacity int) *Queue[T] { + if capacity <= 0 { + panic("capacity > 0 required") + } + + return &Queue[T]{ + ch: make(chan T, capacity), + closedCh: make(chan struct{}), + } +} + +// Cap returns capacity. +func (q *Queue[T]) Cap() int { return cap(q.ch) } + +// Len returns current length (snapshot). +func (q *Queue[T]) Len() int { return len(q.ch) } + +// Enqueue returns immediately: true if enqueued, false otherwise. +func (q *Queue[T]) Enqueue(v T) bool { + select { + case q.ch <- v: + return true + default: + return false + } +} + +// Dequeue returns immediately (item, true) or (zero, false) if empty. +func (q *Queue[T]) Dequeue() (T, bool) { + var zero T + select { + case v := <-q.ch: + return v, true + default: + return zero, false + } +} + +// Close closes the queue. Further Enqueue attempts return ErrClosed. Consumers drain until channel empty then see ErrClosed. +func (q *Queue[T]) Close() { + q.closedMu.Lock() + + if q.closed { + q.closedMu.Unlock() + return + } + + q.closed = true + close(q.closedCh) + close(q.ch) + q.closedMu.Unlock() +} diff --git a/ws/handlers.go b/ws/handlers.go index 838241f..d05b927 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -29,20 +29,20 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) typ, data, err := conn.ReadMessage() if err != nil { - wsg.sendWebsocketError(conn, "Internal Server Error", 500) + wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) wsg.logger.Error("Failed to read handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } if typ != websocket.TextMessage && typ != websocket.BinaryMessage { - wsg.sendWebsocketError(conn, "First message must be a handshake.", 400) + wsg.sendWebsocketError(conn, "Initial message must be a handshake.", 400, true) wsg.logger.Warn("Invalid handshake message type.", "remote", conn.RemoteAddr().String()) return } var handshake Handshake if err := json.Unmarshal(data, &handshake); err != nil { - wsg.sendWebsocketError(conn, "Malformed handshake.", 400) + wsg.sendWebsocketError(conn, "Malformed handshake.", 400, true) wsg.logger.Warn("Malformed handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } @@ -54,7 +54,7 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) var mhs ModHandshake if err := json.Unmarshal(handshake.Data, &mhs); err != nil { - wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400) + wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400, true) wsg.logger.Warn("Malformed mod handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } @@ -68,13 +68,13 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) wsg.registerConn(conn, meta, "mod") wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) - go wsg.modReadLoop(conn, meta) // replace with external handler mayhaps + go wsg.read(conn, meta, "mod") case "bot": var bhs BotHandshake if err := json.Unmarshal(handshake.Data, &bhs); err != nil { - wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400) + wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400, true) wsg.logger.Warn("Malformed bot handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } @@ -82,7 +82,7 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) meta.ID = bhs.BotID if ok := wsg.registerConn(conn, meta, "bot"); !ok { - wsg.sendWebsocketError(conn, "Bot already connected.", 409) + wsg.sendWebsocketError(conn, "Bot already connected.", 409, true) return } @@ -92,10 +92,10 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) - go wsg.botReadLoop(conn, meta) // replace with external handler mayhaps + go wsg.read(conn, meta, "bot") default: - wsg.sendWebsocketError(conn, "Unknown handshake.", 400) + wsg.sendWebsocketError(conn, "Unknown handshake.", 400, true) wsg.logger.Warn("Unknown connection type.", "remote", conn.RemoteAddr().String(), "type", handshake.Type) return } diff --git a/ws/structs.go b/ws/structs.go index 9cbd361..96ea5b0 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -3,6 +3,7 @@ package ws import ( "encoding/json" "homestead/homestead_gateway/util/cache" + "homestead/homestead_gateway/util/queue" "log/slog" "time" @@ -17,73 +18,43 @@ type WebsocketGateway struct { upgrader websocket.Upgrader cache *cache.Cache + queue *queue.Queue[GatewayMessageOut] + bot *websocket.Conn conns *cache.ConnectionCache - modHandler ModHandler - botHandler BotHandler - logger *slog.Logger closeFn func() error } -type MinecraftUser struct { - ID string `json:"id"` - Name string `json:"name"` -} - -type DiscordUser struct { +type User struct { ID string `json:"id"` Name string `json:"name"` } type Destination struct { - ChannelID string `json:"channel_id"` + ID string `json:"channel_id,omitempty"` } -// GatewayModMessageIn : Mod -> Gateway -> Bot -type GatewayModMessageIn struct { - MsgID string `json:"msg_id"` - Server string `json:"server"` - Destination Destination `json:"destination"` - Author MinecraftUser `json:"author"` +type GatewayMessageIn struct { + Type string + ID string `json:"id"` // where am I from (channel_id or server_id) + MsgID string `json:"msg_id"` // msg id + Destination Destination `json:"destination,omitempty"` // where do I wanna go (channel_id or empty if from Bot) + Author User `json:"author"` // who sent the message + Content string `json:"content"` // message content + Meta map[string]interface{} `json:"meta,omitempty"` // additional metadata + Ts time.Time `json:"ts,omitempty"` // timestamp + ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) +} + +type GatewayMessageOut struct { + Type string `json:"type"` // "mod"|"bot" + ID string `json:"channel_id,omitempty"` // message.Destination.ID + Author User `json:"author"` Content string `json:"content"` Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` - ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) -} - -// GatewayBotMessageIn : Bot -> Gateway -> Mod -type GatewayBotMessageIn struct { - MsgID string `json:"msg_id"` - ChannelID string `json:"channel_id"` - Author DiscordUser `json:"author"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` - ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from bot) -} - -// GatewayModMessageOut : Gateway -> Bot -type GatewayModMessageOut struct { - Type string `json:"type"` // "mod" - ChannelID string `json:"channel_id"` - Author MinecraftUser `json:"author"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` - ReceivedAt time.Time `json:"received_at"` - ForwardedAt time.Time `json:"forwarded_at"` -} - -// GatewayBotMessageOut : Gateway -> Mod -type GatewayBotMessageOut struct { - Type string `json:"type"` // "bot" - ChannelID string `json:"channel_id"` - Author DiscordUser `json:"author"` - Content string `json:"content"` - Meta map[string]interface{} `json:"meta,omitempty"` - Ts string `json:"ts,omitempty"` + Ts time.Time `json:"ts,omitempty"` ReceivedAt time.Time `json:"received_at"` ForwardedAt time.Time `json:"forwarded_at"` } @@ -105,11 +76,3 @@ type ModHandshake struct { type BotHandshake struct { BotID string `json:"bot_id"` } - -type ModHandler interface { - Handle(conn *websocket.Conn, msg GatewayModMessageIn) error -} - -type BotHandler interface { - Handle(conn *websocket.Conn, msg GatewayBotMessageIn) error -} diff --git a/ws/temp.go b/ws/temp.go index 8a64b07..86ca50b 100644 --- a/ws/temp.go +++ b/ws/temp.go @@ -1,78 +1,47 @@ package ws -import ( - "encoding/json" - "log/slog" - "time" - - "github.com/gorilla/websocket" -) - -type LoggingModHandler struct { - logger *slog.Logger -} - -type LoggingBotHandler struct { - logger *slog.Logger -} - -func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { - return &LoggingModHandler{logger: logger} -} - -func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { - return &LoggingBotHandler{logger: logger} -} - -func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayModMessageIn) error { - // For now, just log and pretend it's being forwarded - // TODO: Look up channel_id from database using server - // TODO: Forward to bot connection(s) - - fwd := GatewayModMessageOut{ - Type: "mod", - ChannelID: "TODO", // will come from database lookup - Author: msg.Author, - Content: msg.Content, - Meta: msg.Meta, - Ts: msg.Ts, - ReceivedAt: msg.ReceivedAt, - ForwardedAt: time.Now().UTC(), - } - - _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) - - if err := conn.WriteJSON(fwd); err != nil { - _ = conn.Close() - return err - } - - b, _ := json.Marshal(fwd) - h.logger.Info("received mod message", "msg_id", msg.MsgID, "server", msg.Server, "Author", msg.Author.Name, "content", msg.Content) - h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) - - return nil -} - -func (h *LoggingBotHandler) Handle(conn *websocket.Conn, msg GatewayBotMessageIn) error { - // For now, just log and pretend it's being forwarded - // TODO: Look up server_id from database using channel_id - // TODO: Forward to mod connection(s) - - fwd := GatewayBotMessageOut{ - Type: "bot", - ChannelID: msg.ChannelID, - Author: msg.Author, - Content: msg.Content, - Meta: msg.Meta, - Ts: msg.Ts, - ReceivedAt: msg.ReceivedAt, - ForwardedAt: time.Now().UTC(), - } - - b, _ := json.Marshal(fwd) - h.logger.Info("received bot message", "msg_id", msg.MsgID, "channel", msg.ChannelID, "author", msg.Author, "content", msg.Content) - h.logger.Debug("forwarding bot message", "msg_id", msg.MsgID, "channel", msg.ChannelID, "payload", string(b)) - - return nil -} +//type LoggingModHandler struct { +// logger *slog.Logger +//} +// +//type LoggingBotHandler struct { +// logger *slog.Logger +//} +// +//func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { +// return &LoggingModHandler{logger: logger} +//} +// +//func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { +// return &LoggingBotHandler{logger: logger} +//} +// +//func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayMessageIn) error { +// // For now, just log and pretend it's being forwarded +// // TODO: Look up channel_id from database using server +// // TODO: Forward to bot connection(s) +// +// fwd := GatewayMessageOut{ +// Type: "mod", +// ChannelID: "TODO", // will come from database lookup +// Author: msg.Author, +// Content: msg.Content, +// Meta: msg.Meta, +// Ts: msg.Ts, +// ReceivedAt: msg.ReceivedAt, +// ForwardedAt: time.Now().UTC(), +// } +// +// _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) +// +// if err := conn.WriteJSON(fwd); err != nil { +// _ = conn.Close() +// return err +// } +// +// b, _ := json.Marshal(fwd) +// h.logger.Info("received mod message", "msg_id", msg.MsgID, "server", msg.Server, "Author", msg.Author.Name, "content", msg.Content) +// h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) +// +// return nil +//} diff --git a/ws/util.go b/ws/util.go index 376dda7..6003517 100644 --- a/ws/util.go +++ b/ws/util.go @@ -47,10 +47,12 @@ func (wsg *WebsocketGateway) sendWebsocketPing(conn *websocket.Conn) { _ = conn.WriteMessage(websocket.PingMessage, nil) } -func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int) { +func (wsg *WebsocketGateway) sendWebsocketError(conn *websocket.Conn, message string, code int, close bool) { _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) _ = conn.WriteJSON(map[string]interface{}{"message": message, "code": code}) - _ = conn.Close() + if close { + _ = conn.Close() + } } func (wsg *WebsocketGateway) sendWebsocketResponse(conn *websocket.Conn, content interface{}) error { diff --git a/ws/validate.go b/ws/validate.go index 6e494b3..1431e47 100644 --- a/ws/validate.go +++ b/ws/validate.go @@ -5,34 +5,23 @@ import ( "strings" ) -func (m *GatewayModMessageIn) Validate() error { +func (m *GatewayMessageIn) Validate() error { + if strings.TrimSpace(m.ID) == "" { + return errors.New("id missing") + } if strings.TrimSpace(m.MsgID) == "" { return errors.New("msg_id missing") } - if strings.TrimSpace(m.Server) == "" { - return errors.New("server missing") - } if strings.TrimSpace(m.Author.ID) == "" { return errors.New("author.id missing") } if strings.TrimSpace(m.Content) == "" { return errors.New("content missing") } - return nil -} -func (m *GatewayBotMessageIn) Validate() error { - if strings.TrimSpace(m.MsgID) == "" { - return errors.New("msg_id missing") - } - if strings.TrimSpace(m.ChannelID) == "" { - return errors.New("channel_id missing") - } - if strings.TrimSpace(m.Author.ID) == "" { - return errors.New("author missing") - } - if strings.TrimSpace(m.Content) == "" { - return errors.New("content missing") + if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" { + return errors.New("destination.channel_id missing") } + return nil } diff --git a/ws/websocket.go b/ws/websocket.go index a9f1cd8..c9e8d64 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -6,6 +6,7 @@ import ( "fmt" "homestead/homestead_gateway/util/cache" "homestead/homestead_gateway/util/config" + "homestead/homestead_gateway/util/queue" "log/slog" "net" "net/http" @@ -16,16 +17,15 @@ import ( "github.com/gorilla/websocket" ) -func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error, modH ModHandler, botH BotHandler) *WebsocketGateway { +func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - modHandler: modH, - botHandler: botH, - port: cfg.HttpPort, - apiKey: cfg.Websocket, - cache: cache.NewCache(), - conns: cache.NewConnectionCache(), + logger: logger, + closeFn: closefn, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + cache: cache.NewCache(), + queue: queue.NewQueue[GatewayMessageOut](32), + conns: cache.NewConnectionCache(), upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -69,139 +69,251 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // -func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { +//func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { +// defer func() { +// wsg.unregisterConn(conn, meta, "mod") +// wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) +// }() +// +// ticker := time.NewTicker(30 * time.Second) +// defer ticker.Stop() +// +// go func() { +// for range ticker.C { +// wsg.sendWebsocketPing(conn) +// } +// }() +// +// for { +// typ, data, err := conn.ReadMessage() +// +// if err != nil { +// if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { +// wsg.logger.Warn("Mod-Client unexpectedly closed the connection.", "err", err) +// } +// return +// } +// +// if typ != websocket.TextMessage && typ != websocket.BinaryMessage { +// continue +// } +// +// var msg GatewayModMessageIn +// if err := json.Unmarshal(data, &msg); err != nil { +// _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) +// wsg.logger.Warn("invalid json from mod", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) +// continue +// } +// +// msg.ReceivedAt = time.Now().UTC() +// if err := msg.Validate(); err != nil { +// _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) +// wsg.logger.Warn("mod message validation failed", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) +// continue +// } +// +// // Handle the message (forward to bot, enrich, etc.) +// if err := wsg.modHandler.Handle(conn, msg); err != nil { +// _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) +// wsg.logger.Error("mod handler error", "server_id", meta.ID, "err", err) +// continue +// } +// +// _ = writeJSONSafe(conn, map[string]string{"status": "completed"}) // or "queued" +// } +//} +// +//func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { +// defer func() { +// wsg.unregisterConn(conn, meta, "bot") +// wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) +// }() +// +// pingTicker := time.NewTicker(30 * time.Second) +// defer pingTicker.Stop() +// +// // Send pings in a separate goroutine +// go func() { +// for range pingTicker.C { +// _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) +// if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { +// wsg.logger.Debug("write ping failed", "bot_id", meta.ID, "err", err) +// return +// } +// wsg.logger.Debug("sent ping to bot", "bot_id", meta.ID) +// } +// }() +// +// for { +// typ, data, err := conn.ReadMessage() +// if err != nil { +// if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { +// wsg.logger.Warn("unexpected bot close", "bot_id", meta.ID, "err", err) +// } else { +// wsg.logger.Debug("bot read error", "bot_id", meta.ID, "err", err) +// } +// return +// } +// +// if typ != websocket.TextMessage && typ != websocket.BinaryMessage { +// continue +// } +// +// var msg GatewayBotMessageIn +// if err := json.Unmarshal(data, &msg); err != nil { +// _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) +// wsg.logger.Warn("invalid json from bot", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) +// continue +// } +// +// msg.ReceivedAt = time.Now().UTC() +// if err := msg.Validate(); err != nil { +// _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) +// wsg.logger.Warn("bot message validation failed", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) +// continue +// } +// +// // Handle the message (forward to mod, enrich, etc.) +// if err := wsg.botHandler.Handle(conn, msg); err != nil { +// _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) +// wsg.logger.Error("bot handler error", "bot_id", meta.ID, "err", err) +// continue +// } +// +// _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) +// } +//} + +func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMetaData, _type string) { defer func() { - wsg.unregisterConn(conn, meta, "mod") - wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) - }() - - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - go func() { - for range ticker.C { - wsg.sendWebsocketPing(conn) - } - }() - - for { - typ, data, err := conn.ReadMessage() - - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("Mod-Client unexpectedly closed the connection.", "err", err) - } - return - } - - if typ != websocket.TextMessage && typ != websocket.BinaryMessage { - continue - } - - var msg GatewayModMessageIn - if err := json.Unmarshal(data, &msg); err != nil { - _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) - wsg.logger.Warn("invalid json from mod", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) - continue - } - - msg.ReceivedAt = time.Now().UTC() - if err := msg.Validate(); err != nil { - _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) - wsg.logger.Warn("mod message validation failed", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) - continue - } - - // Handle the message (forward to bot, enrich, etc.) - if err := wsg.modHandler.Handle(conn, msg); err != nil { - _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) - wsg.logger.Error("mod handler error", "server_id", meta.ID, "err", err) - continue - } - - _ = writeJSONSafe(conn, map[string]string{"status": "completed"}) // or "queued" - } -} - -func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { - defer func() { - wsg.unregisterConn(conn, meta, "bot") - wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) - }() - - pingTicker := time.NewTicker(30 * time.Second) - defer pingTicker.Stop() - - // Send pings in a separate goroutine - go func() { - for range pingTicker.C { - _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - wsg.logger.Debug("write ping failed", "bot_id", meta.ID, "err", err) - return - } - wsg.logger.Debug("sent ping to bot", "bot_id", meta.ID) - } - }() - - for { - typ, data, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - wsg.logger.Warn("unexpected bot close", "bot_id", meta.ID, "err", err) - } else { - wsg.logger.Debug("bot read error", "bot_id", meta.ID, "err", err) - } - return - } - - if typ != websocket.TextMessage && typ != websocket.BinaryMessage { - continue - } - - var msg GatewayBotMessageIn - if err := json.Unmarshal(data, &msg); err != nil { - _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) - wsg.logger.Warn("invalid json from bot", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) - continue - } - - msg.ReceivedAt = time.Now().UTC() - if err := msg.Validate(); err != nil { - _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) - wsg.logger.Warn("bot message validation failed", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) - continue - } - - // Handle the message (forward to mod, enrich, etc.) - if err := wsg.botHandler.Handle(conn, msg); err != nil { - _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) - wsg.logger.Error("bot handler error", "bot_id", meta.ID, "err", err) - continue - } - - _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) - } -} - -func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) { - defer func() { - wsg.unregisterConn(conn, meta, typ) + wsg.unregisterConn(conn, meta, _type) wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String()) }() ticker := time.NewTicker(30 * time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer ticker.Stop() + defer cancel() go func() { - for range ticker.C { - wsg.sendWebsocketPing(conn) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + wsg.sendWebsocketPing(conn) + } } }() - switch typ { - case "mod": + for { + typ, data, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + wsg.logger.Error("Client unexpectedly closed the connection.", "err", err) + } + return + } - case "bot": + if typ != websocket.TextMessage && typ != websocket.BinaryMessage { + continue + } + + ts := time.Now().UTC() + var message GatewayMessageIn + if err := json.Unmarshal(data, &message); err != nil { + wsg.sendWebsocketError(conn, "Malformed message.", 400, false) + wsg.logger.Warn("Received malformed message json from client.", "remote", conn.RemoteAddr().String(), "err", err) + continue + } + + message.Type = _type + message.ReceivedAt = ts + if err := message.Validate(); err != nil { + wsg.sendWebsocketError(conn, "Malformed message.", 400, false) + wsg.logger.Warn("Received malformed message json from client; validation failed.", "remote", conn.RemoteAddr().String(), "err", err) + continue + } + + var ok bool + var destConn *websocket.Conn + + switch message.Type { + case "mod": + if wsg.bot != nil { + ok = true + destConn = wsg.bot + } else { + ok = false + } + case "bot": + var id string + id, ok = wsg.cache.GetByChannelId(message.ID) + if ok { + var dest *websocket.Conn + dest, _, ok = wsg.conns.GetById(id) + if ok { + destConn = dest + } else { + wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) + wsg.logger.Error("Invalid cache structure.", "remote", conn.RemoteAddr().String(), "id", id) + return + } + } + default: + panic("invalid message type") + } + + if ok { + if destConn == nil { + wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) + wsg.logger.Error("Destination connection unavailable.", "remote", conn.RemoteAddr().String()) + return + } + + err = wsg.sendWebsocketResponse(destConn, GatewayMessageOut{ + Type: message.Type, + ID: message.Destination.ID, + Author: message.Author, + Content: message.Content, + Meta: message.Meta, + Ts: message.Ts, + ReceivedAt: message.ReceivedAt, + ForwardedAt: time.Now().UTC(), + }) + + if err != nil { + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) + wsg.logger.Error("Failed to forward message.", "remote", conn.RemoteAddr().String(), "err", err) + continue + } + + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "completed", Type: message.Type}) + continue + } + + if message.Type == "mod" { + wsg.cache.Set(message.ID, message.Destination.ID) + } + + out := GatewayMessageOut{ + Type: message.Type, + ID: message.Destination.ID, + Author: message.Author, + Content: message.Content, + Meta: message.Meta, + Ts: message.Ts, + ReceivedAt: message.ReceivedAt, + ForwardedAt: time.Now().UTC(), + } + + queued := wsg.queue.Enqueue(out) + if !queued { + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) + wsg.logger.Warn("Failed to queue message.", "remote", conn.RemoteAddr().String()) + continue + } + + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "queued", Type: message.Type}) } } From 3d6774586eff40a135a7820d8595947b20730983 Mon Sep 17 00:00:00 2001 From: Overlord Date: Sun, 7 Dec 2025 17:11:37 +0100 Subject: [PATCH 11/15] Gateway working, beta --- bot.go | 195 ++++++++++++++++++++++++++++++++++ sim.go | 36 +++---- util/cache/cache.go | 91 ---------------- util/cache/conn_cache.go | 88 ---------------- util/cache/structs.go | 37 ------- util/queue/structs.go | 66 ------------ ws/handlers.go | 40 +++---- ws/registry.go | 219 +++++++++++++++++++++++++++++++++++++++ ws/structs.go | 50 +++++++-- ws/temp.go | 47 --------- ws/util.go | 119 ++++++++++++++++----- ws/validate.go | 27 ----- ws/websocket.go | 107 ++++++------------- 13 files changed, 616 insertions(+), 506 deletions(-) create mode 100644 bot.go delete mode 100644 util/cache/cache.go delete mode 100644 util/cache/conn_cache.go delete mode 100644 util/cache/structs.go delete mode 100644 util/queue/structs.go create mode 100644 ws/registry.go delete mode 100644 ws/temp.go delete mode 100644 ws/validate.go diff --git a/bot.go b/bot.go new file mode 100644 index 0000000..b4069ae --- /dev/null +++ b/bot.go @@ -0,0 +1,195 @@ +//go:build simbot + +package main + +import ( + "encoding/json" + "flag" + "log" + "net/url" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gorilla/websocket" +) + +const ( + gatewayURL = "ws://localhost:3333/sync" + apiKey = "gateway" + // must match the mod's channel id used by your mod simulator + channelID = "123456789" +) + +type Handshake struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +type BotHandshake struct { + ChannelId string `json:"channel_id"` // match gateway field exactly +} + +type GatewayAck struct { + Status string `json:"status"` + Type string `json:"type"` +} + +type User struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type Destination struct { + ID string `json:"channel_id,omitempty"` +} + +type GatewayMessageIn struct { + ID string `json:"id"` + MsgID string `json:"msg_id"` + Destination Destination `json:"destination,omitempty"` + Author User `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts time.Time `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` +} + +func main() { + var ( + botID = flag.String("bot", "sim-bot-1", "bot id") + sendAfter = flag.Duration("send-after", 0, "optional: send a bot->mod message after this delay (e.g. 2s)") + sendMsg = flag.String("msg", "Hello from bot!", "optional bot->mod test message content") + ) + flag.Parse() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) + + u, err := url.Parse(gatewayURL) + if err != nil { + log.Fatalf("Failed to parse URL: %v", err) + } + q := u.Query() + q.Set("api_key", apiKey) + u.RawQuery = q.Encode() + + log.Printf("Connecting to %s", u.String()) + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + // we intentionally don't defer conn.Close() here immediately; we'll do on shutdown + log.Println("Connected to gateway") + + // handle server pings by replying a Pong (safe) + conn.SetPingHandler(func(appData string) error { + log.Println("Received ping from server, sending pong") + return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) + }) + + // send bot handshake (must at least include bot_id) + bhs := BotHandshake{ChannelId: channelID} + data, err := json.Marshal(bhs) + if err != nil { + _ = conn.Close() + log.Fatalf("Failed to marshal bot handshake: %v", err) + } + hs := Handshake{Type: "bot", Data: data} + if err := conn.WriteJSON(hs); err != nil { + _ = conn.Close() + log.Fatalf("Failed to send handshake: %v", err) + } + log.Println("Handshake sent (bot)") + + // Instead of ReadJSON, read raw message so we can see whatever the server returns (or why it closes). + // This avoids a silent parsing failure if server returns non-JSON or closes immediately. + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + msgType, raw, err := conn.ReadMessage() + if err != nil { + _ = conn.Close() + log.Fatalf("Failed to read handshake response: %v", err) + } + _ = conn.SetReadDeadline(time.Time{}) // clear deadline + + if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { + log.Printf("Raw handshake reply: %s", string(raw)) + var ack GatewayAck + if err := json.Unmarshal(raw, &ack); err != nil { + log.Printf("Handshake reply is not JSON or unmarshal failed: %v", err) + } else { + log.Printf("Parsed ack: status=%q, type=%q", ack.Status, ack.Type) + } + } else { + log.Printf("Handshake reply was control frame or unexpected type=%d", msgType) + } + + // From here, start the normal read loop. + done := make(chan struct{}) + go func() { + defer close(done) + for { + msgType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + log.Printf("WebSocket error: %v", err) + } else { + log.Printf("Connection closed or read error: %v", err) + } + return + } + + if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { + log.Printf("Received from gateway: %s", string(message)) + } + } + }() + + // optionally send a bot->mod message after a delay + if *sendAfter > 0 { + go func() { + time.Sleep(*sendAfter) + msg := GatewayMessageIn{ + MsgID: "bot-msg-001", + ID: channelID, // bot reports channel id as ID + Author: User{ + ID: *botID, + Name: "SimBot", + }, + Content: *sendMsg, + Ts: time.Now().UTC(), + } + // set a write deadline + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send bot->mod test message: %v", err) + return + } + log.Printf("Sent bot->mod test message (channel=%s)", channelID) + _ = conn.SetWriteDeadline(time.Time{}) + }() + } + + log.Println("Bot simulator running. Press Ctrl+C to exit.") + + // Wait for interrupt or read loop done + for { + select { + case <-done: + log.Println("Connection closed by server") + _ = conn.Close() + return + case <-interrupt: + log.Println("Interrupt received, closing connection...") + // politely close + _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + select { + case <-done: + case <-time.After(time.Second): + } + _ = conn.Close() + return + } + } +} diff --git a/sim.go b/sim.go index 6530720..ce2a105 100644 --- a/sim.go +++ b/sim.go @@ -18,6 +18,8 @@ const ( gatewayURL = "ws://localhost:3333/sync" apiKey = "gateway" serverID = "test-server-001" + // THE CHANNEL ID the mod says it serves. Must match gateway expectation. + channelID = "123456789" ) type Handshake struct { @@ -25,8 +27,10 @@ type Handshake struct { Data json.RawMessage `json:"data"` } +// ModHandshake now includes ChannelID type ModHandshake struct { - ServerID string `json:"server_id"` + ServerID string `json:"server_id"` + ChannelID string `json:"channel_id"` } type GatewayAck struct { @@ -58,7 +62,6 @@ func main() { interrupt := make(chan os.Signal, 1) signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) - // Build WebSocket URL with API key u, err := url.Parse(gatewayURL) if err != nil { log.Fatalf("Failed to parse URL: %v", err) @@ -69,7 +72,6 @@ func main() { log.Printf("Connecting to %s", u.String()) - // Connect to WebSocket conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { log.Fatalf("Failed to connect: %v", err) @@ -78,7 +80,8 @@ func main() { log.Println("Connected to gateway") - // Set up ping handler - respond to pings from server + // respond to pings (server ping -> client must pong). Using SetPingHandler is fine, + // but WriteControl for Pong is acceptable too. conn.SetPingHandler(func(appData string) error { log.Println("Received ping from server, sending pong") err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) @@ -89,8 +92,11 @@ func main() { return nil }) - // Send handshake - modHS := ModHandshake{ServerID: serverID} + // Build and send handshake including channel id + modHS := ModHandshake{ + ServerID: serverID, + ChannelID: channelID, + } modHSData, err := json.Marshal(modHS) if err != nil { log.Fatalf("Failed to marshal mod handshake: %v", err) @@ -104,21 +110,17 @@ func main() { if err := conn.WriteJSON(handshake); err != nil { log.Fatalf("Failed to send handshake: %v", err) } - log.Println("Handshake sent") - // Read acknowledgment + // Read acknowledgment (some servers might not reply with JSON; handle errors) var ack GatewayAck if err := conn.ReadJSON(&ack); err != nil { log.Fatalf("Failed to read acknowledgment: %v", err) } + log.Printf("Received acknowledgment: status=%q, type=%q", ack.Status, ack.Type) - log.Printf("Received acknowledgment: status=%s, type=%s", ack.Status, ack.Type) - - // Channel for incoming messages done := make(chan struct{}) - // Read loop - handles incoming messages and processes control frames go func() { defer close(done) for { @@ -131,21 +133,19 @@ func main() { } return } - - // Only log text/binary messages (ping/pong handled by handlers) if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { log.Printf("Received from server: %s", string(message)) } } }() - // Optional: Send a test message after connecting - time.Sleep(2 * time.Second) + // Optional: send a test message after connecting + time.Sleep(1 * time.Second) testMsg := GatewayMessageIn{ MsgID: "test-msg-001", ID: serverID, Destination: Destination{ - ID: "123456789", + ID: channelID, }, Author: User{ ID: "player-uuid-123", @@ -163,7 +163,6 @@ func main() { log.Println("Connection established. Responding to pings. Press Ctrl+C to disconnect.") - // Wait for interrupt or connection close for { select { case <-done: @@ -172,7 +171,6 @@ func main() { case <-interrupt: log.Println("Interrupt received, closing connection...") - // Send close message err := conn.WriteMessage( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), diff --git a/util/cache/cache.go b/util/cache/cache.go deleted file mode 100644 index 57d682c..0000000 --- a/util/cache/cache.go +++ /dev/null @@ -1,91 +0,0 @@ -package cache - -func NewCache() *Cache { - c := &Cache{} - c.state.Store(&mappings{ - s2c: make(map[string]string), - c2s: make(map[string]string), - }) - - return c -} - -func (c *Cache) Set(serverId, channelId string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - next := current.clone() - - if oldCh, ok := next.s2c[serverId]; ok && oldCh != channelId { - delete(next.c2s, oldCh) - } - if oldSrv, ok := next.c2s[channelId]; ok && oldSrv != serverId { - delete(next.s2c, oldSrv) - } - - next.s2c[serverId] = channelId - next.c2s[channelId] = serverId - - c.state.Store(next) -} - -func (c *Cache) GetByServerId(serverId string) (string, bool) { - m := c.state.Load() - val, ok := m.s2c[serverId] - return val, ok -} - -func (c *Cache) GetByChannelId(channelId string) (string, bool) { - m := c.state.Load() - val, ok := m.c2s[channelId] - return val, ok -} - -func (c *Cache) RemoveByServerId(serverId string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - if _, ok := current.s2c[serverId]; !ok { - return - } - - next := current.clone() - if channelId, ok := next.s2c[serverId]; ok { - delete(next.s2c, serverId) - delete(next.c2s, channelId) - } - c.state.Store(next) -} - -func (c *Cache) RemoveByChannelId(channelId string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - if _, ok := current.c2s[channelId]; !ok { - return - } - - next := current.clone() - if serverId, ok := next.c2s[channelId]; ok { - delete(next.c2s, channelId) - delete(next.s2c, serverId) - } - c.state.Store(next) -} - -func (m *mappings) clone() *mappings { - newM := &mappings{ - s2c: make(map[string]string, len(m.s2c)), - c2s: make(map[string]string, len(m.c2s)), - } - for k, v := range m.s2c { - newM.s2c[k] = v - } - for k, v := range m.c2s { - newM.c2s[k] = v - } - return newM -} diff --git a/util/cache/conn_cache.go b/util/cache/conn_cache.go deleted file mode 100644 index a49a6fe..0000000 --- a/util/cache/conn_cache.go +++ /dev/null @@ -1,88 +0,0 @@ -package cache - -import ( - "github.com/gorilla/websocket" -) - -func NewConnectionCache() *ConnectionCache { - c := &ConnectionCache{} - c.state.Store(&connectionMappings{ - id2c: make(map[string]*conn), - }) - - return c -} - -func (c *ConnectionCache) Set(id string, connection *websocket.Conn, meta *ConnectionMetaData) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - next := current.clone() - - next.id2c[id] = &conn{ - connection: connection, - meta: meta, - } - - c.state.Store(next) -} - -func (c *ConnectionCache) GetById(id string) (*websocket.Conn, *ConnectionMetaData, bool) { - m := c.state.Load() - if conn, ok := m.id2c[id]; ok { - return conn.connection, conn.meta, true - } - - return nil, nil, false -} - -func (c *ConnectionCache) RemoveById(id string) { - c.mu.Lock() - defer c.mu.Unlock() - - current := c.state.Load() - if _, ok := current.id2c[id]; !ok { - return - } - - next := current.clone() - delete(next.id2c, id) - - c.state.Store(next) -} - -func (c *ConnectionCache) Range(fn func(id string, connection *websocket.Conn, meta *ConnectionMetaData)) { - m := c.state.Load() - for id, conn := range m.id2c { - fn(id, conn.connection, conn.meta) - } -} - -func (c *ConnectionCache) GetAllConnections() []*websocket.Conn { - m := c.state.Load() - conns := make([]*websocket.Conn, 0, len(m.id2c)) - for _, conn := range m.id2c { - conns = append(conns, conn.connection) - } - return conns -} - -func (c *ConnectionCache) Clear() { - c.mu.Lock() - defer c.mu.Unlock() - - c.state.Store(&connectionMappings{ - id2c: make(map[string]*conn), - }) -} - -func (cm *connectionMappings) clone() *connectionMappings { - newCM := &connectionMappings{ - id2c: make(map[string]*conn, len(cm.id2c)), - } - for k, v := range cm.id2c { - newCM.id2c[k] = v - } - return newCM -} diff --git a/util/cache/structs.go b/util/cache/structs.go deleted file mode 100644 index 37aac37..0000000 --- a/util/cache/structs.go +++ /dev/null @@ -1,37 +0,0 @@ -package cache - -import ( - "sync" - "sync/atomic" - - "github.com/gorilla/websocket" -) - -type Cache struct { - mu sync.Mutex - state atomic.Pointer[mappings] -} - -type ConnectionCache struct { - mu sync.Mutex - state atomic.Pointer[connectionMappings] -} - -type mappings struct { - s2c map[string]string - c2s map[string]string -} - -type connectionMappings struct { - id2c map[string]*conn -} - -type ConnectionMetaData struct { - ConnectionType string // "mod" or "bot" - ID string // server_id or bot_id for logging -} - -type conn struct { - connection *websocket.Conn - meta *ConnectionMetaData -} diff --git a/util/queue/structs.go b/util/queue/structs.go deleted file mode 100644 index 0588606..0000000 --- a/util/queue/structs.go +++ /dev/null @@ -1,66 +0,0 @@ -package queue - -import ( - "sync" -) - -// Queue is a thin wrapper around a buffered channel. -type Queue[T any] struct { - ch chan T - closedMu sync.Mutex - closed bool - closedCh chan struct{} -} - -func NewQueue[T any](capacity int) *Queue[T] { - if capacity <= 0 { - panic("capacity > 0 required") - } - - return &Queue[T]{ - ch: make(chan T, capacity), - closedCh: make(chan struct{}), - } -} - -// Cap returns capacity. -func (q *Queue[T]) Cap() int { return cap(q.ch) } - -// Len returns current length (snapshot). -func (q *Queue[T]) Len() int { return len(q.ch) } - -// Enqueue returns immediately: true if enqueued, false otherwise. -func (q *Queue[T]) Enqueue(v T) bool { - select { - case q.ch <- v: - return true - default: - return false - } -} - -// Dequeue returns immediately (item, true) or (zero, false) if empty. -func (q *Queue[T]) Dequeue() (T, bool) { - var zero T - select { - case v := <-q.ch: - return v, true - default: - return zero, false - } -} - -// Close closes the queue. Further Enqueue attempts return ErrClosed. Consumers drain until channel empty then see ErrClosed. -func (q *Queue[T]) Close() { - q.closedMu.Lock() - - if q.closed { - q.closedMu.Unlock() - return - } - - q.closed = true - close(q.closedCh) - close(q.ch) - q.closedMu.Unlock() -} diff --git a/ws/handlers.go b/ws/handlers.go index d05b927..512bf9e 100644 --- a/ws/handlers.go +++ b/ws/handlers.go @@ -2,7 +2,6 @@ package ws import ( "encoding/json" - "homestead/homestead_gateway/util/cache" "net/http" "time" @@ -47,52 +46,53 @@ func (wsg *WebsocketGateway) handleSync(w http.ResponseWriter, r *http.Request) return } - meta := cache.ConnectionMetaData{ConnectionType: handshake.Type} - switch handshake.Type { case "mod": var mhs ModHandshake - if err := json.Unmarshal(handshake.Data, &mhs); err != nil { wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400, true) wsg.logger.Warn("Malformed mod handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta.ID = mhs.ServerID - - if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}); err != nil { + if mhs.ServerID == "" || mhs.ChannelID == "" { + wsg.sendWebsocketError(conn, "Malformed mod handshake.", 400, true) return } - wsg.registerConn(conn, meta, "mod") - wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String(), "server_id", mhs.ServerID) + if !wsg.registerConn(conn, "mod", mhs.ChannelID, mhs.ServerID) { + wsg.sendWebsocketError(conn, "Failed to register mod.", 500, true) + return + } - go wsg.read(conn, meta, "mod") + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "mod"}) + wsg.registry.FlushChannelWithSender(mhs.ChannelID, wsg.flush) + + wsg.logger.Info("Mod connected via Websocket.", "remote", conn.RemoteAddr().String()) + go wsg.read(conn, "mod", mhs.ChannelID) case "bot": var bhs BotHandshake - if err := json.Unmarshal(handshake.Data, &bhs); err != nil { wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400, true) - wsg.logger.Warn("Malformed bot handshake.", "remote", conn.RemoteAddr().String(), "err", err) return } - meta.ID = bhs.BotID + if bhs.ChannelId == "" { + wsg.sendWebsocketError(conn, "Malformed bot handshake.", 400, true) + return + } - if ok := wsg.registerConn(conn, meta, "bot"); !ok { + if !wsg.registerConn(conn, "bot", bhs.ChannelId, "") { wsg.sendWebsocketError(conn, "Bot already connected.", 409, true) return } - if err = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}); err != nil { - return - } + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "connected", Type: "bot"}) + wsg.registry.FlushAllToBotWithSender(wsg.flush) - wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String(), "bot_id", bhs.BotID) - - go wsg.read(conn, meta, "bot") + wsg.logger.Info("Bot connected via Websocket.", "remote", conn.RemoteAddr().String()) + go wsg.read(conn, "bot", bhs.ChannelId) default: wsg.sendWebsocketError(conn, "Unknown handshake.", 400, true) diff --git a/ws/registry.go b/ws/registry.go new file mode 100644 index 0000000..64be288 --- /dev/null +++ b/ws/registry.go @@ -0,0 +1,219 @@ +package ws + +import ( + "fmt" + "time" + + "github.com/gorilla/websocket" +) + +func (q *BoundedQueue) Enqueue(m GatewayMessageOut) bool { + q.mu.Lock() + defer q.mu.Unlock() + if q.capacity == 0 { + return false + } + if q.length < q.capacity { + q.buf[(q.start+q.length)%q.capacity] = m + q.length++ + return true + } + // overwrite oldest + q.buf[q.start] = m + q.start = (q.start + 1) % q.capacity + return true +} + +func (q *BoundedQueue) PopAll() []GatewayMessageOut { + q.mu.Lock() + defer q.mu.Unlock() + if q.length == 0 { + return nil + } + out := make([]GatewayMessageOut, 0, q.length) + for i := 0; i < q.length; i++ { + out = append(out, q.buf[(q.start+i)%q.capacity]) + } + q.start = 0 + q.length = 0 + return out +} + +func (q *BoundedQueue) Len() int { + q.mu.Lock() + defer q.mu.Unlock() + return q.length +} + +// + +func (r *Registry) getOrCreate(channel string) *ChannelEntry { + r.mu.RLock() + e := r.entries[channel] + r.mu.RUnlock() + if e != nil { + return e + } + + r.mu.Lock() + defer r.mu.Unlock() + if e = r.entries[channel]; e == nil { + e = newChannelEntry(channel, r.queueCap) + r.entries[channel] = e + } + return e +} + +// + +// RegisterMod : map channel_id -> mod conn (serverID) +func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) { + e := r.getOrCreate(channelID) + e.mu.Lock() + defer e.mu.Unlock() + if e.Mod != nil && e.Mod.Conn != nil { + _ = e.Mod.Conn.Close() + } + e.Mod = &ConnWrapper{Conn: conn, ServerID: serverID, LastSeen: time.Now()} + // flush queued bot->mod messages for this channel + // caller should use FlushChannelWithSender to perform actual sends +} + +// UnregisterMod : remove mod for a channel +func (r *Registry) UnregisterMod(channelID string) { + r.mu.RLock() + e := r.entries[channelID] + r.mu.RUnlock() + if e == nil { + return + } + e.mu.Lock() + defer e.mu.Unlock() + if e.Mod != nil && e.Mod.Conn != nil { + _ = e.Mod.Conn.Close() + } + e.Mod = nil +} + +// RegisterBot : single connection for bot. after registration call FlushAllToBotWithSender +func (r *Registry) RegisterBot(conn *websocket.Conn) { + r.botMu.Lock() + if r.bot != nil && r.bot.Conn != nil { + _ = r.bot.Conn.Close() + } + r.bot = &ConnWrapper{Conn: conn, LastSeen: time.Now()} + r.botMu.Unlock() +} + +func (r *Registry) UnregisterBot() { + r.botMu.Lock() + if r.bot != nil && r.bot.Conn != nil { + _ = r.bot.Conn.Close() + } + r.bot = nil + r.botMu.Unlock() +} + +func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn func(*websocket.Conn, GatewayMessageOut) error) (delivered bool, queued bool, err error) { + if out.Type == "mod" { + r.botMu.Lock() + b := r.bot + r.botMu.Unlock() + if b != nil && b.Conn != nil { + if err := sendOverConn(b.Conn, out); err == nil { + return true, false, nil + } + _ = b.Conn.Close() + r.UnregisterBot() + } + + e := r.getOrCreate(channelID) + e.mu.Lock() + enq := e.Queue.Enqueue(out) + e.mu.Unlock() + if !enq { + return false, false, fmt.Errorf("queue disabled") + } + return false, true, nil + } + + e := r.getOrCreate(channelID) + e.mu.Lock() + mod := e.Mod + e.mu.Unlock() + if mod != nil && mod.Conn != nil { + if err := sendOverConn(mod.Conn, out); err == nil { + return true, false, nil + } + _ = mod.Conn.Close() + r.UnregisterMod(channelID) + } + + e.mu.Lock() + enq := e.Queue.Enqueue(out) + e.mu.Unlock() + if !enq { + return false, false, fmt.Errorf("queue disabled") + } + return false, true, nil +} + +// + +func (r *Registry) FlushChannelWithSender(channelID string, sendOverConn func(*websocket.Conn, GatewayMessageOut) error) { + r.mu.RLock() + e := r.entries[channelID] + r.mu.RUnlock() + if e == nil { + return + } + e.mu.Lock() + if e.Mod == nil || e.Mod.Conn == nil { + e.mu.Unlock() + return + } + msgs := e.Queue.PopAll() + modConn := e.Mod.Conn + e.mu.Unlock() + + for _, m := range msgs { + if err := sendOverConn(modConn, m); err != nil { + // if send fails, re-enqueue (best-effort), drop-oldest logic applies + e.mu.Lock() + _ = e.Queue.Enqueue(m) + e.mu.Unlock() + } + } +} + +func (r *Registry) FlushAllToBotWithSender(sendOverConn func(*websocket.Conn, GatewayMessageOut) error) { + r.botMu.Lock() + b := r.bot + r.botMu.Unlock() + if b == nil || b.Conn == nil { + return + } + + r.mu.RLock() + entries := make([]*ChannelEntry, 0, len(r.entries)) + for _, e := range r.entries { + entries = append(entries, e) + } + r.mu.RUnlock() + + for _, e := range entries { + e.mu.Lock() + msgs := e.Queue.PopAll() + e.mu.Unlock() + if len(msgs) == 0 { + continue + } + for _, m := range msgs { + if err := sendOverConn(b.Conn, m); err != nil { + e.mu.Lock() + _ = e.Queue.Enqueue(m) + e.mu.Unlock() + } + } + } +} diff --git a/ws/structs.go b/ws/structs.go index 96ea5b0..4623240 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -2,9 +2,8 @@ package ws import ( "encoding/json" - "homestead/homestead_gateway/util/cache" - "homestead/homestead_gateway/util/queue" "log/slog" + "sync" "time" "github.com/gorilla/websocket" @@ -17,16 +16,46 @@ type WebsocketGateway struct { upgrader websocket.Upgrader - cache *cache.Cache - queue *queue.Queue[GatewayMessageOut] - - bot *websocket.Conn - conns *cache.ConnectionCache + registry *Registry logger *slog.Logger closeFn func() error } +// + +type Registry struct { + mu sync.RWMutex + entries map[string]*ChannelEntry + queueCap int + + botMu sync.Mutex + bot *ConnWrapper +} + +type ConnWrapper struct { + Conn *websocket.Conn + ServerID string // set for mods (the server_id) + LastSeen time.Time +} + +type BoundedQueue struct { + mu sync.Mutex + buf []GatewayMessageOut + start int + length int + capacity int +} + +type ChannelEntry struct { + mu sync.Mutex + Channel string + Mod *ConnWrapper + Queue *BoundedQueue +} + +// + type User struct { ID string `json:"id"` Name string `json:"name"` @@ -64,15 +93,18 @@ type GatewayAck struct { Type string `json:"type"` } +// + type Handshake struct { Type string `json:"type"` // "mod" or "bot" Data json.RawMessage `json:"data"` } type ModHandshake struct { - ServerID string `json:"server_id"` + ServerID string `json:"server_id"` + ChannelID string `json:"channel_id"` } type BotHandshake struct { - BotID string `json:"bot_id"` + ChannelId string `json:"channel_id"` } diff --git a/ws/temp.go b/ws/temp.go deleted file mode 100644 index 86ca50b..0000000 --- a/ws/temp.go +++ /dev/null @@ -1,47 +0,0 @@ -package ws - -//type LoggingModHandler struct { -// logger *slog.Logger -//} -// -//type LoggingBotHandler struct { -// logger *slog.Logger -//} -// -//func NewLoggingModHandler(logger *slog.Logger) *LoggingModHandler { -// return &LoggingModHandler{logger: logger} -//} -// -//func NewLoggingBotHandler(logger *slog.Logger) *LoggingBotHandler { -// return &LoggingBotHandler{logger: logger} -//} -// -//func (h *LoggingModHandler) Handle(conn *websocket.Conn, msg GatewayMessageIn) error { -// // For now, just log and pretend it's being forwarded -// // TODO: Look up channel_id from database using server -// // TODO: Forward to bot connection(s) -// -// fwd := GatewayMessageOut{ -// Type: "mod", -// ChannelID: "TODO", // will come from database lookup -// Author: msg.Author, -// Content: msg.Content, -// Meta: msg.Meta, -// Ts: msg.Ts, -// ReceivedAt: msg.ReceivedAt, -// ForwardedAt: time.Now().UTC(), -// } -// -// _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) -// -// if err := conn.WriteJSON(fwd); err != nil { -// _ = conn.Close() -// return err -// } -// -// b, _ := json.Marshal(fwd) -// h.logger.Info("received mod message", "msg_id", msg.MsgID, "server", msg.Server, "Author", msg.Author.Name, "content", msg.Content) -// h.logger.Debug("forwarding mod message", "msg_id", msg.MsgID, "server", msg.Server, "payload", string(b)) -// -// return nil -//} diff --git a/ws/util.go b/ws/util.go index 6003517..2226d64 100644 --- a/ws/util.go +++ b/ws/util.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "errors" - "homestead/homestead_gateway/util/cache" "log/slog" "net/http" + "strings" "time" "github.com/gorilla/websocket" @@ -36,6 +36,11 @@ func (wsg *WebsocketGateway) deafen(srv *http.Server) { // responses +func (wsg *WebsocketGateway) flush(c *websocket.Conn, m GatewayMessageOut) error { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + return c.WriteJSON(m) +} + func (wsg *WebsocketGateway) sendHttpError(w http.ResponseWriter, message string, code int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) @@ -94,15 +99,6 @@ func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool { return !(apiKey == "" || apiKey != wsg.apiKey) } -func writeJSONSafe(c *websocket.Conn, v interface{}) error { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteJSON(v); err != nil { - // caller handles logging - return err - } - return nil -} - func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -113,43 +109,112 @@ func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { // connections -func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) bool { +func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool { if typ == "bot" { - if wsg.bot != nil { + wsg.registry.botMu.Lock() + if wsg.registry.bot != nil && wsg.registry.bot.Conn != nil { + wsg.registry.botMu.Unlock() return false } + wsg.registry.botMu.Unlock() - wsg.bot = conn + wsg.registry.RegisterBot(conn) return true } - wsg.conns.Set(meta.ID, conn, &meta) + wsg.registry.RegisterMod(channelId, serverId, conn) return true } -func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, meta cache.ConnectionMetaData, typ string) { +func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, typ, channelId string) { if typ == "bot" { - _ = wsg.bot.Close() - wsg.bot = nil + wsg.registry.UnregisterBot() + if conn != nil { + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) + _ = conn.Close() + } return } - wsg.conns.RemoveById(meta.ID) - _ = conn.Close() + wsg.registry.UnregisterMod(channelId) + + if conn != nil { + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) + _ = conn.Close() + } } func (wsg *WebsocketGateway) closeAll() { wsg.logger.Info("Closing all websocket connections.") - if wsg.bot != nil { - _ = wsg.bot.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) - _ = wsg.bot.Close() + wsg.registry.UnregisterBot() + + wsg.registry.mu.RLock() + entries := make([]*ChannelEntry, 0, len(wsg.registry.entries)) + for _, e := range wsg.registry.entries { + entries = append(entries, e) + } + wsg.registry.mu.RUnlock() + + for _, e := range entries { + e.mu.Lock() + modConn := e.Mod + if modConn != nil { + e.Mod = nil + } + e.mu.Unlock() + + if modConn != nil && modConn.Conn != nil { + _ = modConn.Conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) + _ = modConn.Conn.Close() + } + } +} + +// + +func NewRegistry(queueCap int) *Registry { + return &Registry{ + entries: make(map[string]*ChannelEntry), + queueCap: queueCap, + } +} + +func NewBoundedQueue(cap int) *BoundedQueue { + if cap <= 0 { + cap = 128 + } + return &BoundedQueue{buf: make([]GatewayMessageOut, cap), capacity: cap} +} + +func newChannelEntry(channel string, cap int) *ChannelEntry { + return &ChannelEntry{ + Channel: channel, + Queue: NewBoundedQueue(cap), + } +} + +// + +func (m *GatewayMessageIn) Validate() error { + if strings.TrimSpace(m.ID) == "" { + return errors.New("id missing") + } + if strings.TrimSpace(m.MsgID) == "" { + return errors.New("msg_id missing") + } + if strings.TrimSpace(m.Author.ID) == "" { + return errors.New("author.id missing") + } + if strings.TrimSpace(m.Content) == "" { + return errors.New("content missing") } - wsg.conns.Range(func(id string, c *websocket.Conn, meta *cache.ConnectionMetaData) { - _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) - _ = c.Close() - }) + if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" { + return errors.New("destination.channel_id missing") + } - wsg.conns.Clear() + return nil } + +func (c *ConnWrapper) Alive() bool { return c != nil && c.Conn != nil } diff --git a/ws/validate.go b/ws/validate.go deleted file mode 100644 index 1431e47..0000000 --- a/ws/validate.go +++ /dev/null @@ -1,27 +0,0 @@ -package ws - -import ( - "errors" - "strings" -) - -func (m *GatewayMessageIn) Validate() error { - if strings.TrimSpace(m.ID) == "" { - return errors.New("id missing") - } - if strings.TrimSpace(m.MsgID) == "" { - return errors.New("msg_id missing") - } - if strings.TrimSpace(m.Author.ID) == "" { - return errors.New("author.id missing") - } - if strings.TrimSpace(m.Content) == "" { - return errors.New("content missing") - } - - if m.Type == "mod" && strings.TrimSpace(m.Destination.ID) == "" { - return errors.New("destination.channel_id missing") - } - - return nil -} diff --git a/ws/websocket.go b/ws/websocket.go index c9e8d64..e999b22 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -4,9 +4,7 @@ import ( "context" "encoding/json" "fmt" - "homestead/homestead_gateway/util/cache" "homestead/homestead_gateway/util/config" - "homestead/homestead_gateway/util/queue" "log/slog" "net" "net/http" @@ -19,13 +17,11 @@ import ( func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - port: cfg.HttpPort, - apiKey: cfg.Websocket, - cache: cache.NewCache(), - queue: queue.NewQueue[GatewayMessageOut](32), - conns: cache.NewConnectionCache(), + logger: logger, + closeFn: closefn, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + registry: NewRegistry(32), upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -184,9 +180,9 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // } //} -func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMetaData, _type string) { +func (wsg *WebsocketGateway) read(conn *websocket.Conn, _type, channelId string) { defer func() { - wsg.unregisterConn(conn, meta, _type) + wsg.unregisterConn(conn, _type, channelId) wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String()) }() @@ -235,70 +231,16 @@ func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMet continue } - var ok bool - var destConn *websocket.Conn - - switch message.Type { - case "mod": - if wsg.bot != nil { - ok = true - destConn = wsg.bot - } else { - ok = false - } - case "bot": - var id string - id, ok = wsg.cache.GetByChannelId(message.ID) - if ok { - var dest *websocket.Conn - dest, _, ok = wsg.conns.GetById(id) - if ok { - destConn = dest - } else { - wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) - wsg.logger.Error("Invalid cache structure.", "remote", conn.RemoteAddr().String(), "id", id) - return - } - } - default: - panic("invalid message type") - } - - if ok { - if destConn == nil { - wsg.sendWebsocketError(conn, "Internal Server Error", 500, true) - wsg.logger.Error("Destination connection unavailable.", "remote", conn.RemoteAddr().String()) - return - } - - err = wsg.sendWebsocketResponse(destConn, GatewayMessageOut{ - Type: message.Type, - ID: message.Destination.ID, - Author: message.Author, - Content: message.Content, - Meta: message.Meta, - Ts: message.Ts, - ReceivedAt: message.ReceivedAt, - ForwardedAt: time.Now().UTC(), - }) - - if err != nil { - _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) - wsg.logger.Error("Failed to forward message.", "remote", conn.RemoteAddr().String(), "err", err) - continue - } - - _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "completed", Type: message.Type}) - continue - } - - if message.Type == "mod" { - wsg.cache.Set(message.ID, message.Destination.ID) + var outID string + if _type == "mod" { + outID = channelId + } else { + outID = message.ID } out := GatewayMessageOut{ Type: message.Type, - ID: message.Destination.ID, + ID: outID, Author: message.Author, Content: message.Content, Meta: message.Meta, @@ -307,13 +249,28 @@ func (wsg *WebsocketGateway) read(conn *websocket.Conn, meta cache.ConnectionMet ForwardedAt: time.Now().UTC(), } - queued := wsg.queue.Enqueue(out) - if !queued { + delivered, queued, err := wsg.registry.Send(out.ID, out, func(c *websocket.Conn, m GatewayMessageOut) error { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + return c.WriteJSON(m) + }) + + if err != nil { + wsg.logger.Error("registry send error", "err", err) _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) - wsg.logger.Warn("Failed to queue message.", "remote", conn.RemoteAddr().String()) continue } - _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "queued", Type: message.Type}) + if delivered { + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "completed", Type: message.Type}) + continue + } + + if queued { + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "queued", Type: message.Type}) + continue + } + + _ = wsg.sendWebsocketResponse(conn, GatewayAck{Status: "failed", Type: message.Type}) + } } From 8f75e6491f6cfb8d154560ea06813a4aeace912b Mon Sep 17 00:00:00 2001 From: Overlord Date: Sun, 7 Dec 2025 17:21:20 +0100 Subject: [PATCH 12/15] qol change --- ws/registry.go | 54 ++++++++++++++++++++++++++++++++----------------- ws/util.go | 31 ++++++---------------------- ws/websocket.go | 2 +- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/ws/registry.go b/ws/registry.go index 64be288..263fcc4 100644 --- a/ws/registry.go +++ b/ws/registry.go @@ -79,22 +79,6 @@ func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) // caller should use FlushChannelWithSender to perform actual sends } -// UnregisterMod : remove mod for a channel -func (r *Registry) UnregisterMod(channelID string) { - r.mu.RLock() - e := r.entries[channelID] - r.mu.RUnlock() - if e == nil { - return - } - e.mu.Lock() - defer e.mu.Unlock() - if e.Mod != nil && e.Mod.Conn != nil { - _ = e.Mod.Conn.Close() - } - e.Mod = nil -} - // RegisterBot : single connection for bot. after registration call FlushAllToBotWithSender func (r *Registry) RegisterBot(conn *websocket.Conn) { r.botMu.Lock() @@ -105,13 +89,45 @@ func (r *Registry) RegisterBot(conn *websocket.Conn) { r.botMu.Unlock() } +func (r *Registry) UnregisterMod(channelID string) { + r.mu.RLock() + e := r.entries[channelID] + r.mu.RUnlock() + if e == nil { + return + } + + e.mu.Lock() + modConn := e.Mod + e.Mod = nil + e.mu.Unlock() + + if modConn != nil && modConn.Conn != nil { + _ = modConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) + _ = modConn.Conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), + time.Now().Add(time.Second), + ) + _ = modConn.Conn.Close() + } +} + func (r *Registry) UnregisterBot() { r.botMu.Lock() - if r.bot != nil && r.bot.Conn != nil { - _ = r.bot.Conn.Close() - } + botConn := r.bot r.bot = nil r.botMu.Unlock() + + if botConn != nil && botConn.Conn != nil { + _ = botConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) + _ = botConn.Conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), + time.Now().Add(time.Second), + ) + _ = botConn.Conn.Close() + } } func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn func(*websocket.Conn, GatewayMessageOut) error) (delivered bool, queued bool, err error) { diff --git a/ws/util.go b/ws/util.go index 2226d64..38e0ac1 100644 --- a/ws/util.go +++ b/ws/util.go @@ -126,22 +126,13 @@ func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, return true } -func (wsg *WebsocketGateway) unregisterConn(conn *websocket.Conn, typ, channelId string) { +func (wsg *WebsocketGateway) unregisterConn(typ, channelId string) { if typ == "bot" { wsg.registry.UnregisterBot() - if conn != nil { - _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) - _ = conn.Close() - } return } wsg.registry.UnregisterMod(channelId) - - if conn != nil { - _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), time.Now().Add(time.Second)) - _ = conn.Close() - } } func (wsg *WebsocketGateway) closeAll() { @@ -150,24 +141,14 @@ func (wsg *WebsocketGateway) closeAll() { wsg.registry.UnregisterBot() wsg.registry.mu.RLock() - entries := make([]*ChannelEntry, 0, len(wsg.registry.entries)) - for _, e := range wsg.registry.entries { - entries = append(entries, e) + channelIDs := make([]string, 0, len(wsg.registry.entries)) + for channelID := range wsg.registry.entries { + channelIDs = append(channelIDs, channelID) } wsg.registry.mu.RUnlock() - for _, e := range entries { - e.mu.Lock() - modConn := e.Mod - if modConn != nil { - e.Mod = nil - } - e.mu.Unlock() - - if modConn != nil && modConn.Conn != nil { - _ = modConn.Conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Shutting down."), time.Now().Add(time.Second)) - _ = modConn.Conn.Close() - } + for _, channelID := range channelIDs { + wsg.registry.UnregisterMod(channelID) } } diff --git a/ws/websocket.go b/ws/websocket.go index e999b22..f01cf80 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -182,7 +182,7 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error func (wsg *WebsocketGateway) read(conn *websocket.Conn, _type, channelId string) { defer func() { - wsg.unregisterConn(conn, _type, channelId) + wsg.unregisterConn(_type, channelId) wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String()) }() From 786b217f057b3c05503d63b7cf636d7f2055e0a9 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 8 Dec 2025 06:57:54 +0100 Subject: [PATCH 13/15] quality of life, updated config structure --- config.toml | 8 +-- util/config/structs.go | 12 +--- ws/registry.go | 47 +++++++++----- ws/structs.go | 20 +++--- ws/util.go | 36 +++++++---- ws/websocket.go | 135 +++-------------------------------------- 6 files changed, 77 insertions(+), 181 deletions(-) diff --git a/config.toml b/config.toml index b41f343..8850d7b 100644 --- a/config.toml +++ b/config.toml @@ -7,10 +7,4 @@ rotation = 3 # in days http_port = 3333 websocket = "gateway" body_size = 2 # in MB -queue_max = 8192 - -[database] -host_dsn = "" -username = "" -password = "" -database = "" +queue_max = 32 diff --git a/util/config/structs.go b/util/config/structs.go index 4700c84..5437791 100644 --- a/util/config/structs.go +++ b/util/config/structs.go @@ -3,9 +3,8 @@ package config import "log/slog" type Config struct { - Log LogConfig `toml:"log"` - Gateway GatewayConfig `toml:"gateway"` - Database DatabaseConfig `toml:"database"` + Log LogConfig `toml:"log"` + Gateway GatewayConfig `toml:"gateway"` } type GatewayConfig struct { @@ -20,10 +19,3 @@ type LogConfig struct { Directory string `toml:"directory"` Rotation int `toml:"rotation"` } - -type DatabaseConfig struct { - HostDSN string `toml:"host_dsn"` - Username string `toml:"username"` - Password string `toml:"password"` - Database string `toml:"database"` -} diff --git a/ws/registry.go b/ws/registry.go index 263fcc4..be8d6a5 100644 --- a/ws/registry.go +++ b/ws/registry.go @@ -13,14 +13,17 @@ func (q *BoundedQueue) Enqueue(m GatewayMessageOut) bool { if q.capacity == 0 { return false } + if q.length < q.capacity { q.buf[(q.start+q.length)%q.capacity] = m q.length++ return true } + // overwrite oldest q.buf[q.start] = m q.start = (q.start + 1) % q.capacity + return true } @@ -30,12 +33,15 @@ func (q *BoundedQueue) PopAll() []GatewayMessageOut { if q.length == 0 { return nil } + out := make([]GatewayMessageOut, 0, q.length) for i := 0; i < q.length; i++ { out = append(out, q.buf[(q.start+i)%q.capacity]) } + q.start = 0 q.length = 0 + return out } @@ -64,6 +70,19 @@ func (r *Registry) getOrCreate(channel string) *ChannelEntry { return e } +func (r *Registry) ForEach(cb func(channelID string)) { + r.mu.RLock() + ids := make([]string, 0, len(r.entries)) + for id := range r.entries { + ids = append(ids, id) + } + r.mu.RUnlock() + + for _, id := range ids { + cb(id) + } +} + // // RegisterMod : map channel_id -> mod conn (serverID) @@ -71,9 +90,12 @@ func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) e := r.getOrCreate(channelID) e.mu.Lock() defer e.mu.Unlock() + if e.Mod != nil && e.Mod.Conn != nil { + _ = e.Mod.Conn.Close() } + e.Mod = &ConnWrapper{Conn: conn, ServerID: serverID, LastSeen: time.Now()} // flush queued bot->mod messages for this channel // caller should use FlushChannelWithSender to perform actual sends @@ -82,9 +104,11 @@ func (r *Registry) RegisterMod(channelID, serverID string, conn *websocket.Conn) // RegisterBot : single connection for bot. after registration call FlushAllToBotWithSender func (r *Registry) RegisterBot(conn *websocket.Conn) { r.botMu.Lock() + if r.bot != nil && r.bot.Conn != nil { _ = r.bot.Conn.Close() } + r.bot = &ConnWrapper{Conn: conn, LastSeen: time.Now()} r.botMu.Unlock() } @@ -103,13 +127,7 @@ func (r *Registry) UnregisterMod(channelID string) { e.mu.Unlock() if modConn != nil && modConn.Conn != nil { - _ = modConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) - _ = modConn.Conn.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), - time.Now().Add(time.Second), - ) - _ = modConn.Conn.Close() + closeConn(modConn.Conn) } } @@ -120,13 +138,7 @@ func (r *Registry) UnregisterBot() { r.botMu.Unlock() if botConn != nil && botConn.Conn != nil { - _ = botConn.Conn.SetWriteDeadline(time.Now().Add(time.Second)) - _ = botConn.Conn.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), - time.Now().Add(time.Second), - ) - _ = botConn.Conn.Close() + closeConn(botConn.Conn) } } @@ -135,6 +147,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu r.botMu.Lock() b := r.bot r.botMu.Unlock() + if b != nil && b.Conn != nil { if err := sendOverConn(b.Conn, out); err == nil { return true, false, nil @@ -147,6 +160,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu e.mu.Lock() enq := e.Queue.Enqueue(out) e.mu.Unlock() + if !enq { return false, false, fmt.Errorf("queue disabled") } @@ -157,6 +171,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu e.mu.Lock() mod := e.Mod e.mu.Unlock() + if mod != nil && mod.Conn != nil { if err := sendOverConn(mod.Conn, out); err == nil { return true, false, nil @@ -168,6 +183,7 @@ func (r *Registry) Send(channelID string, out GatewayMessageOut, sendOverConn fu e.mu.Lock() enq := e.Queue.Enqueue(out) e.mu.Unlock() + if !enq { return false, false, fmt.Errorf("queue disabled") } @@ -183,11 +199,13 @@ func (r *Registry) FlushChannelWithSender(channelID string, sendOverConn func(*w if e == nil { return } + e.mu.Lock() if e.Mod == nil || e.Mod.Conn == nil { e.mu.Unlock() return } + msgs := e.Queue.PopAll() modConn := e.Mod.Conn e.mu.Unlock() @@ -221,6 +239,7 @@ func (r *Registry) FlushAllToBotWithSender(sendOverConn func(*websocket.Conn, Ga e.mu.Lock() msgs := e.Queue.PopAll() e.mu.Unlock() + if len(msgs) == 0 { continue } diff --git a/ws/structs.go b/ws/structs.go index 4623240..f601775 100644 --- a/ws/structs.go +++ b/ws/structs.go @@ -67,19 +67,19 @@ type Destination struct { type GatewayMessageIn struct { Type string - ID string `json:"id"` // where am I from (channel_id or server_id) - MsgID string `json:"msg_id"` // msg id - Destination Destination `json:"destination,omitempty"` // where do I wanna go (channel_id or empty if from Bot) - Author User `json:"author"` // who sent the message - Content string `json:"content"` // message content - Meta map[string]interface{} `json:"meta,omitempty"` // additional metadata - Ts time.Time `json:"ts,omitempty"` // timestamp - ReceivedAt time.Time `json:"-"` // ReceivedAt is populated by gateway (not from mod) + ID string `json:"id"` + MsgID string `json:"msg_id"` + Destination Destination `json:"destination,omitempty"` + Author User `json:"author"` + Content string `json:"content"` + Meta map[string]interface{} `json:"meta,omitempty"` + Ts time.Time `json:"ts,omitempty"` + ReceivedAt time.Time `json:"-"` } type GatewayMessageOut struct { - Type string `json:"type"` // "mod"|"bot" - ID string `json:"channel_id,omitempty"` // message.Destination.ID + Type string `json:"type"` + ID string `json:"channel_id,omitempty"` Author User `json:"author"` Content string `json:"content"` Meta map[string]interface{} `json:"meta,omitempty"` diff --git a/ws/util.go b/ws/util.go index 38e0ac1..3e9a8b7 100644 --- a/ws/util.go +++ b/ws/util.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "log/slog" "net/http" "strings" "time" @@ -99,16 +98,26 @@ func (wsg *WebsocketGateway) validateApiKey(r *http.Request) bool { return !(apiKey == "" || apiKey != wsg.apiKey) } -func loggingMiddleware(logger *slog.Logger, next http.Handler) http.Handler { +func (wsg *WebsocketGateway) loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) - logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start)) + wsg.logger.Info("Incoming HTTP request.", "remote", r.RemoteAddr, "path", r.URL.Path, "duration", time.Since(start)) }) } // connections +func closeConn(conn *websocket.Conn) { + _ = conn.SetWriteDeadline(time.Now().Add(time.Second)) + _ = conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Disconnecting."), + time.Now().Add(time.Second), + ) + _ = conn.Close() +} + func (wsg *WebsocketGateway) registerConn(conn *websocket.Conn, typ, channelId, serverId string) bool { if typ == "bot" { wsg.registry.botMu.Lock() @@ -140,20 +149,23 @@ func (wsg *WebsocketGateway) closeAll() { wsg.registry.UnregisterBot() - wsg.registry.mu.RLock() - channelIDs := make([]string, 0, len(wsg.registry.entries)) - for channelID := range wsg.registry.entries { - channelIDs = append(channelIDs, channelID) - } - wsg.registry.mu.RUnlock() - - for _, channelID := range channelIDs { + wsg.registry.ForEach(func(channelID string) { wsg.registry.UnregisterMod(channelID) - } + }) } // +func NewUpgrader() websocket.Upgrader { + return websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // local by default; change for production + }, + } +} + func NewRegistry(queueCap int) *Registry { return &Registry{ entries: make(map[string]*ChannelEntry), diff --git a/ws/websocket.go b/ws/websocket.go index f01cf80..d392fac 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -17,18 +17,12 @@ import ( func NewWebsocketGateway(cfg config.GatewayConfig, logger *slog.Logger, closefn func() error) *WebsocketGateway { return &WebsocketGateway{ - logger: logger, - closeFn: closefn, - port: cfg.HttpPort, - apiKey: cfg.Websocket, - registry: NewRegistry(32), - upgrader: websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true // local by default; change for production - }, - }, + logger: logger, + closeFn: closefn, + port: cfg.HttpPort, + apiKey: cfg.Websocket, + upgrader: NewUpgrader(), + registry: NewRegistry(cfg.QueueSize), bodySizeBytes: int64(cfg.BodySize) * 1024 * 1024, } } @@ -47,7 +41,7 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error srv := &http.Server{ Addr: listenAddr, - Handler: loggingMiddleware(wsg.logger, mux), + Handler: wsg.loggingMiddleware(mux), BaseContext: func(l net.Listener) context.Context { return ctx }, } errCh := make(chan error, 1) @@ -65,121 +59,6 @@ func (wsg *WebsocketGateway) Serve(ctx context.Context, listenAddr string) error // -//func (wsg *WebsocketGateway) modReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { -// defer func() { -// wsg.unregisterConn(conn, meta, "mod") -// wsg.logger.Info("Client disconnected.", "remote", conn.RemoteAddr().String(), "server_id", meta.ID) -// }() -// -// ticker := time.NewTicker(30 * time.Second) -// defer ticker.Stop() -// -// go func() { -// for range ticker.C { -// wsg.sendWebsocketPing(conn) -// } -// }() -// -// for { -// typ, data, err := conn.ReadMessage() -// -// if err != nil { -// if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { -// wsg.logger.Warn("Mod-Client unexpectedly closed the connection.", "err", err) -// } -// return -// } -// -// if typ != websocket.TextMessage && typ != websocket.BinaryMessage { -// continue -// } -// -// var msg GatewayModMessageIn -// if err := json.Unmarshal(data, &msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) -// wsg.logger.Warn("invalid json from mod", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// msg.ReceivedAt = time.Now().UTC() -// if err := msg.Validate(); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) -// wsg.logger.Warn("mod message validation failed", "server_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// // Handle the message (forward to bot, enrich, etc.) -// if err := wsg.modHandler.Handle(conn, msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) -// wsg.logger.Error("mod handler error", "server_id", meta.ID, "err", err) -// continue -// } -// -// _ = writeJSONSafe(conn, map[string]string{"status": "completed"}) // or "queued" -// } -//} -// -//func (wsg *WebsocketGateway) botReadLoop(conn *websocket.Conn, meta cache.ConnectionMetaData) { -// defer func() { -// wsg.unregisterConn(conn, meta, "bot") -// wsg.logger.Info("bot disconnected", "bot_id", meta.ID, "remote", conn.RemoteAddr().String()) -// }() -// -// pingTicker := time.NewTicker(30 * time.Second) -// defer pingTicker.Stop() -// -// // Send pings in a separate goroutine -// go func() { -// for range pingTicker.C { -// _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) -// if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { -// wsg.logger.Debug("write ping failed", "bot_id", meta.ID, "err", err) -// return -// } -// wsg.logger.Debug("sent ping to bot", "bot_id", meta.ID) -// } -// }() -// -// for { -// typ, data, err := conn.ReadMessage() -// if err != nil { -// if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { -// wsg.logger.Warn("unexpected bot close", "bot_id", meta.ID, "err", err) -// } else { -// wsg.logger.Debug("bot read error", "bot_id", meta.ID, "err", err) -// } -// return -// } -// -// if typ != websocket.TextMessage && typ != websocket.BinaryMessage { -// continue -// } -// -// var msg GatewayBotMessageIn -// if err := json.Unmarshal(data, &msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "invalid json: " + err.Error()}) -// wsg.logger.Warn("invalid json from bot", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// msg.ReceivedAt = time.Now().UTC() -// if err := msg.Validate(); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": err.Error()}) -// wsg.logger.Warn("bot message validation failed", "bot_id", meta.ID, "remote", conn.RemoteAddr().String(), "err", err) -// continue -// } -// -// // Handle the message (forward to mod, enrich, etc.) -// if err := wsg.botHandler.Handle(conn, msg); err != nil { -// _ = writeJSONSafe(conn, map[string]string{"error": "handler error: " + err.Error()}) -// wsg.logger.Error("bot handler error", "bot_id", meta.ID, "err", err) -// continue -// } -// -// _ = writeJSONSafe(conn, map[string]string{"status": "ok"}) -// } -//} - func (wsg *WebsocketGateway) read(conn *websocket.Conn, _type, channelId string) { defer func() { wsg.unregisterConn(_type, channelId) From 8e84523d4bfd7a07f083e35910fe81ef91d46868 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 8 Dec 2025 06:58:43 +0100 Subject: [PATCH 14/15] continuous send for simulators --- bot.go | 120 ++++++++++++++++++++++++++++++++++++++-------------- sim.go | 131 +++++++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 188 insertions(+), 63 deletions(-) diff --git a/bot.go b/bot.go index b4069ae..8dd5130 100644 --- a/bot.go +++ b/bot.go @@ -5,10 +5,14 @@ package main import ( "encoding/json" "flag" + "fmt" "log" + "math/rand" "net/url" "os" "os/signal" + "sync" + "sync/atomic" "syscall" "time" @@ -18,8 +22,12 @@ import ( const ( gatewayURL = "ws://localhost:3333/sync" apiKey = "gateway" - // must match the mod's channel id used by your mod simulator - channelID = "123456789" + channelID = "123456789" +) + +var ( + minInterval = 500 * time.Millisecond + maxInterval = 5 * time.Second ) type Handshake struct { @@ -28,7 +36,7 @@ type Handshake struct { } type BotHandshake struct { - ChannelId string `json:"channel_id"` // match gateway field exactly + ChannelId string `json:"channel_id"` } type GatewayAck struct { @@ -56,12 +64,21 @@ type GatewayMessageIn struct { ReceivedAt time.Time `json:"-"` } +func randomDuration(min, max time.Duration) time.Duration { + if max <= min { + return min + } + diff := int64(max - min) + n := rand.Int63n(diff) + return min + time.Duration(n) +} + func main() { - var ( - botID = flag.String("bot", "sim-bot-1", "bot id") - sendAfter = flag.Duration("send-after", 0, "optional: send a bot->mod message after this delay (e.g. 2s)") - sendMsg = flag.String("msg", "Hello from bot!", "optional bot->mod test message content") - ) + rand.Seed(time.Now().UnixNano()) + + botID := flag.String("bot", "sim-bot-1", "bot id") + sendAfter := flag.Duration("send-after", 0, "optional: send a bot->mod message after this delay (e.g. 2s)") + sendMsg := flag.String("msg", "Hello from bot!", "optional bot->mod test message content") flag.Parse() interrupt := make(chan os.Signal, 1) @@ -75,21 +92,25 @@ func main() { q.Set("api_key", apiKey) u.RawQuery = q.Encode() - log.Printf("Connecting to %s", u.String()) conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { log.Fatalf("Failed to connect: %v", err) } - // we intentionally don't defer conn.Close() here immediately; we'll do on shutdown - log.Println("Connected to gateway") + defer conn.Close() + + var writeMu sync.Mutex - // handle server pings by replying a Pong (safe) conn.SetPingHandler(func(appData string) error { - log.Println("Received ping from server, sending pong") - return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) + writeMu.Lock() + err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) + writeMu.Unlock() + if err != nil { + log.Printf("Failed to send pong: %v", err) + return err + } + return nil }) - // send bot handshake (must at least include bot_id) bhs := BotHandshake{ChannelId: channelID} data, err := json.Marshal(bhs) if err != nil { @@ -97,21 +118,22 @@ func main() { log.Fatalf("Failed to marshal bot handshake: %v", err) } hs := Handshake{Type: "bot", Data: data} + + writeMu.Lock() if err := conn.WriteJSON(hs); err != nil { + writeMu.Unlock() _ = conn.Close() log.Fatalf("Failed to send handshake: %v", err) } - log.Println("Handshake sent (bot)") + writeMu.Unlock() - // Instead of ReadJSON, read raw message so we can see whatever the server returns (or why it closes). - // This avoids a silent parsing failure if server returns non-JSON or closes immediately. conn.SetReadDeadline(time.Now().Add(10 * time.Second)) msgType, raw, err := conn.ReadMessage() if err != nil { _ = conn.Close() log.Fatalf("Failed to read handshake response: %v", err) } - _ = conn.SetReadDeadline(time.Time{}) // clear deadline + _ = conn.SetReadDeadline(time.Time{}) if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { log.Printf("Raw handshake reply: %s", string(raw)) @@ -122,11 +144,11 @@ func main() { log.Printf("Parsed ack: status=%q, type=%q", ack.Status, ack.Type) } } else { - log.Printf("Handshake reply was control frame or unexpected type=%d", msgType) + log.Printf("Handshake reply type=%d", msgType) } - // From here, start the normal read loop. done := make(chan struct{}) + go func() { defer close(done) for { @@ -139,20 +161,20 @@ func main() { } return } - if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { log.Printf("Received from gateway: %s", string(message)) } } }() - // optionally send a bot->mod message after a delay + var msgCounter uint64 = 1 + if *sendAfter > 0 { go func() { time.Sleep(*sendAfter) msg := GatewayMessageIn{ - MsgID: "bot-msg-001", - ID: channelID, // bot reports channel id as ID + MsgID: fmt.Sprintf("bot-msg-%06d", atomic.AddUint64(&msgCounter, 1)), + ID: channelID, Author: User{ ID: *botID, Name: "SimBot", @@ -160,20 +182,55 @@ func main() { Content: *sendMsg, Ts: time.Now().UTC(), } - // set a write deadline + writeMu.Lock() _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := conn.WriteJSON(msg); err != nil { log.Printf("Failed to send bot->mod test message: %v", err) - return + } else { + log.Printf("Sent bot->mod test message (channel=%s)", channelID) } - log.Printf("Sent bot->mod test message (channel=%s)", channelID) _ = conn.SetWriteDeadline(time.Time{}) + writeMu.Unlock() }() } - log.Println("Bot simulator running. Press Ctrl+C to exit.") + go func() { + for { + select { + case <-done: + return + default: + } + d := randomDuration(minInterval, maxInterval) + select { + case <-done: + return + case <-time.After(d): + msgNum := atomic.AddUint64(&msgCounter, 1) + msg := GatewayMessageIn{ + MsgID: fmt.Sprintf("sim-bot-msg-%06d", msgNum), + ID: channelID, + Author: User{ + ID: fmt.Sprintf("%s-%d", *botID, msgNum%1000), + Name: fmt.Sprintf("SimBot%d", msgNum%1000), + }, + Content: fmt.Sprintf("Automated bot message #%d (delay %s)", msgNum, d), + Ts: time.Now().UTC(), + } + writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := conn.WriteJSON(msg); err != nil { + writeMu.Unlock() + log.Printf("Failed to send automated bot message: %v", err) + return + } + _ = conn.SetWriteDeadline(time.Time{}) + writeMu.Unlock() + log.Printf("Sent automated bot message %s", msg.MsgID) + } + } + }() - // Wait for interrupt or read loop done for { select { case <-done: @@ -182,8 +239,9 @@ func main() { return case <-interrupt: log.Println("Interrupt received, closing connection...") - // politely close + writeMu.Lock() _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + writeMu.Unlock() select { case <-done: case <-time.After(time.Second): diff --git a/sim.go b/sim.go index ce2a105..ccc4aeb 100644 --- a/sim.go +++ b/sim.go @@ -4,10 +4,14 @@ package main import ( "encoding/json" + "fmt" "log" + "math/rand" "net/url" "os" "os/signal" + "sync" + "sync/atomic" "syscall" "time" @@ -22,6 +26,12 @@ const ( channelID = "123456789" ) +// send interval range (random between minInterval and maxInterval) +var ( + minInterval = 500 * time.Millisecond + maxInterval = 5 * time.Second +) + type Handshake struct { Type string `json:"type"` Data json.RawMessage `json:"data"` @@ -59,6 +69,8 @@ type GatewayMessageIn struct { } func main() { + rand.Seed(time.Now().UnixNano()) + interrupt := make(chan os.Signal, 1) signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) @@ -80,19 +92,19 @@ func main() { log.Println("Connected to gateway") - // respond to pings (server ping -> client must pong). Using SetPingHandler is fine, - // but WriteControl for Pong is acceptable too. + var writeMu sync.Mutex + conn.SetPingHandler(func(appData string) error { log.Println("Received ping from server, sending pong") - err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)) - if err != nil { + writeMu.Lock() + defer writeMu.Unlock() + if err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(5*time.Second)); err != nil { log.Printf("Failed to send pong: %v", err) return err } return nil }) - // Build and send handshake including channel id modHS := ModHandshake{ ServerID: serverID, ChannelID: channelID, @@ -107,12 +119,14 @@ func main() { Data: modHSData, } + writeMu.Lock() if err := conn.WriteJSON(handshake); err != nil { + writeMu.Unlock() log.Fatalf("Failed to send handshake: %v", err) } + writeMu.Unlock() log.Println("Handshake sent") - // Read acknowledgment (some servers might not reply with JSON; handle errors) var ack GatewayAck if err := conn.ReadJSON(&ack); err != nil { log.Fatalf("Failed to read acknowledgment: %v", err) @@ -129,62 +143,115 @@ func main() { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { log.Printf("WebSocket error: %v", err) } else { - log.Printf("Connection closed: %v", err) + log.Printf("Connection closed/read error: %v", err) } return } - if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { + switch messageType { + case websocket.TextMessage, websocket.BinaryMessage: log.Printf("Received from server: %s", string(message)) + default: } } }() - // Optional: send a test message after connecting - time.Sleep(1 * time.Second) - testMsg := GatewayMessageIn{ - MsgID: "test-msg-001", - ID: serverID, - Destination: Destination{ - ID: channelID, - }, - Author: User{ - ID: "player-uuid-123", - Name: "TestPlayer", - }, - Content: "Hello from simulated mod!", - Ts: time.Now().UTC(), - } + var msgCounter uint64 = 1 - if err := conn.WriteJSON(testMsg); err != nil { - log.Printf("Failed to send test message: %v", err) - } else { - log.Println("Sent test message to gateway") - } + func() { + testMsg := GatewayMessageIn{ + MsgID: fmt.Sprintf("test-msg-%06d", atomic.AddUint64(&msgCounter, 1)), + ID: serverID, + Destination: Destination{ + ID: channelID, + }, + Author: User{ + ID: "player-uuid-123", + Name: "TestPlayer", + }, + Content: "Hello from simulated mod!", + Ts: time.Now().UTC(), + } - log.Println("Connection established. Responding to pings. Press Ctrl+C to disconnect.") + writeMu.Lock() + if err := conn.WriteJSON(testMsg); err != nil { + log.Printf("Failed to send test message: %v", err) + } else { + log.Println("Sent initial test message to gateway") + } + writeMu.Unlock() + }() + + go func() { + for { + d := randomDuration(minInterval, maxInterval) + + select { + case <-done: + return + case <-time.After(d): + // build message + msgNum := atomic.AddUint64(&msgCounter, 1) + msg := GatewayMessageIn{ + MsgID: fmt.Sprintf("sim-msg-%06d", msgNum), + ID: serverID, + Destination: Destination{ + ID: channelID, + }, + Author: User{ + ID: fmt.Sprintf("sim-user-%d", msgNum%1000), + Name: fmt.Sprintf("SimUser%d", msgNum%1000), + }, + Content: fmt.Sprintf("Random interval message #%d (delay %s)", msgNum, d), + Ts: time.Now().UTC(), + } + + writeMu.Lock() + if err := conn.WriteJSON(msg); err != nil { + log.Printf("Failed to send simulated message: %v", err) + writeMu.Unlock() + return + } + writeMu.Unlock() + + log.Printf("Sent simulated message %s (next wait up to %s)", msg.MsgID, maxInterval) + } + } + }() + + log.Println("Connection established. Sending simulated messages at random intervals. Press Ctrl+C to disconnect.") for { select { case <-done: - log.Println("Connection closed") + log.Println("Connection read loop closed, exiting") return case <-interrupt: log.Println("Interrupt received, closing connection...") + writeMu.Lock() err := conn.WriteMessage( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), ) + writeMu.Unlock() if err != nil { log.Printf("Write close error: %v", err) - return } select { case <-done: - case <-time.After(time.Second): + case <-time.After(1 * time.Second): } return } } } + +func randomDuration(min, max time.Duration) time.Duration { + if max <= min { + return min + } + diff := int64(max - min) + n := rand.Int63n(diff) + return min + time.Duration(n) +} From e198fc4b3f98ce7e1338d9f8968d9602dbb9d542 Mon Sep 17 00:00:00 2001 From: Overlord Date: Mon, 8 Dec 2025 07:01:05 +0100 Subject: [PATCH 15/15] updated .gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index bf9d731..7c68b5e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +* +!*/ +!*.* + *.exe *.exe~ *.dll