summaryrefslogtreecommitdiff
path: root/wss/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'wss/main.go')
-rwxr-xr-xwss/main.go275
1 files changed, 206 insertions, 69 deletions
diff --git a/wss/main.go b/wss/main.go
index a3ea063..5c270ba 100755
--- a/wss/main.go
+++ b/wss/main.go
@@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"log"
+ "math"
"net/http"
"os"
"sync"
@@ -13,11 +14,25 @@ import (
"github.com/gorilla/websocket"
)
+const (
+ // Time allowed to write a message to the peer
+ writeWait = 10 * time.Second
+
+ // Time allowed to read the next pong message from the peer
+ pongWait = 60 * time.Second
+
+ // Send pings to peer with this period (must be less than pongWait)
+ pingPeriod = (pongWait * 9) / 10
+
+ // Maximum message size allowed from peer
+ maxMessageSize = 4096
+)
+
// Configuration for the server
type Config struct {
- Debug bool
- Port int
- StaticDir string
+ Debug bool
+ Port int
+ StaticDir string
}
// Environment variables to pass to frontend
@@ -76,18 +91,39 @@ type Hub struct {
// Message types
const (
- MessageTypeAddNode = "addNode"
- MessageTypeAddEdge = "addEdge"
- MessageTypeRemoveNode = "removeNode"
- MessageTypeFullSync = "fullSync"
+ MessageTypeAddNode = "addNode"
+ MessageTypeAddEdge = "addEdge"
+ MessageTypeRemoveNode = "removeNode"
+ MessageTypeRemoveEdge = "removeEdge"
+ MessageTypeMoveNode = "moveNode"
+ MessageTypeFullSync = "fullSync"
+ MessageTypeConnectionCount = "connectionCount"
+ MessageTypeError = "error"
)
// WebSocket message format
type WebSocketMessage struct {
- Type string `json:"type"`
- Node *Node `json:"node,omitempty"`
- Edge *Edge `json:"edge,omitempty"`
- Graph *Graph `json:"graph,omitempty"`
+ Type string `json:"type"`
+ Node *Node `json:"node,omitempty"`
+ Edge *Edge `json:"edge,omitempty"`
+ Graph *Graph `json:"graph,omitempty"`
+ Count int `json:"count,omitempty"`
+ Message string `json:"message,omitempty"`
+}
+
+// findNodeByID returns the index of a node by ID, or -1 if not found
+func (g *Graph) findNodeByID(id int) int {
+ for i, n := range g.Nodes {
+ if n.ID == id {
+ return i
+ }
+ }
+ return -1
+}
+
+// hasNodeID returns true if a node with the given ID exists
+func (g *Graph) hasNodeID(id int) bool {
+ return g.findNodeByID(id) >= 0
}
var hub = Hub{
@@ -119,6 +155,30 @@ func createInitialGraph() Graph {
return Graph{Nodes: nodes, Edges: edges}
}
+// broadcastToAll sends a pre-marshaled message to all connections
+func (h *Hub) broadcastToAll(data []byte) {
+ for conn := range h.connections {
+ select {
+ case conn.send <- data:
+ default:
+ close(conn.send)
+ delete(h.connections, conn)
+ }
+ }
+}
+
+// broadcastConnectionCount sends the current connection count to all clients
+func (h *Hub) broadcastConnectionCount() {
+ msg := WebSocketMessage{
+ Type: MessageTypeConnectionCount,
+ Count: len(h.connections),
+ }
+ data, err := json.Marshal(msg)
+ if err == nil {
+ h.broadcastToAll(data)
+ }
+}
+
func (h *Hub) run() {
for {
select {
@@ -126,24 +186,28 @@ func (h *Hub) run() {
h.connections[conn] = true
log.Printf("Client connected. Total connections: %d", len(h.connections))
- // Send current graph state to new connection
+ // Send current graph state to new connection (marshal under lock)
h.graphMutex.RLock()
fullSync := WebSocketMessage{
Type: MessageTypeFullSync,
Graph: &h.graph,
}
+ data, err := json.Marshal(fullSync)
h.graphMutex.RUnlock()
- data, err := json.Marshal(fullSync)
if err == nil {
conn.send <- data
}
+ // Notify all clients of connection count
+ h.broadcastConnectionCount()
+
case conn := <-h.unregister:
if _, ok := h.connections[conn]; ok {
delete(h.connections, conn)
close(conn.send)
log.Printf("Client disconnected. Total connections: %d", len(h.connections))
+ h.broadcastConnectionCount()
}
case message := <-h.broadcast:
@@ -154,40 +218,40 @@ func (h *Hub) run() {
continue
}
- response := h.handleOperation(msg)
- if response == nil {
- continue
- }
-
- // Broadcast the response to all connections
- data, err := json.Marshal(response)
- if err != nil {
- log.Printf("Error marshaling response: %v", err)
+ // handleOperation now returns pre-marshaled bytes (serialized under lock)
+ data := h.handleOperation(msg)
+ if data == nil {
continue
}
- for conn := range h.connections {
- select {
- case conn.send <- data:
- default:
- close(conn.send)
- delete(h.connections, conn)
- }
- }
+ h.broadcastToAll(data)
}
}
}
-// handleOperation processes graph operations and returns the message to broadcast
-func (h *Hub) handleOperation(msg WebSocketMessage) *WebSocketMessage {
+// isValidCoord checks that a coordinate is a finite number
+func isValidCoord(v float64) bool {
+ return !math.IsNaN(v) && !math.IsInf(v, 0)
+}
+
+// handleOperation processes graph operations and returns pre-marshaled JSON bytes.
+// Marshaling happens under the graph lock to prevent TOCTOU races.
+func (h *Hub) handleOperation(msg WebSocketMessage) []byte {
h.graphMutex.Lock()
defer h.graphMutex.Unlock()
+ var response *WebSocketMessage
+
switch msg.Type {
case MessageTypeAddNode:
if msg.Node == nil {
return nil
}
+ // Validate coordinates
+ if !isValidCoord(msg.Node.X) || !isValidCoord(msg.Node.Y) || !isValidCoord(msg.Node.Z) {
+ log.Printf("Rejected addNode: invalid coordinates")
+ return nil
+ }
// Assign server-side ID to avoid conflicts
newID := 0
for _, n := range h.graph.Nodes {
@@ -203,16 +267,16 @@ func (h *Hub) handleOperation(msg WebSocketMessage) *WebSocketMessage {
}
h.graph.Nodes = append(h.graph.Nodes, node)
log.Printf("Added node %d at (%.2f, %.2f, %.2f)", newID, node.X, node.Y, node.Z)
- return &WebSocketMessage{Type: MessageTypeAddNode, Node: &node}
+ response = &WebSocketMessage{Type: MessageTypeAddNode, Node: &node}
case MessageTypeAddEdge:
if msg.Edge == nil {
return nil
}
- // Validate edge
- if msg.Edge.From < 0 || msg.Edge.From >= len(h.graph.Nodes) ||
- msg.Edge.To < 0 || msg.Edge.To >= len(h.graph.Nodes) ||
+ // Validate edge endpoints exist by ID (not index)
+ if !h.graph.hasNodeID(msg.Edge.From) || !h.graph.hasNodeID(msg.Edge.To) ||
msg.Edge.From == msg.Edge.To {
+ log.Printf("Rejected addEdge %d->%d: invalid node IDs", msg.Edge.From, msg.Edge.To)
return nil
}
// Check for duplicate
@@ -225,16 +289,21 @@ func (h *Hub) handleOperation(msg WebSocketMessage) *WebSocketMessage {
edge := Edge{From: msg.Edge.From, To: msg.Edge.To}
h.graph.Edges = append(h.graph.Edges, edge)
log.Printf("Added edge %d -> %d", edge.From, edge.To)
- return &WebSocketMessage{Type: MessageTypeAddEdge, Edge: &edge}
+ response = &WebSocketMessage{Type: MessageTypeAddEdge, Edge: &edge}
case MessageTypeRemoveNode:
- if msg.Node == nil || msg.Node.ID < 0 || msg.Node.ID >= len(h.graph.Nodes) {
+ if msg.Node == nil {
return nil
}
- // For simplicity, mark node as removed by setting special coordinates
- // A full implementation would handle ID remapping
nodeID := msg.Node.ID
- // Remove edges connected to this node
+ idx := h.graph.findNodeByID(nodeID)
+ if idx < 0 {
+ log.Printf("Rejected removeNode %d: not found", nodeID)
+ return nil
+ }
+ // Remove the node from the slice
+ h.graph.Nodes = append(h.graph.Nodes[:idx], h.graph.Nodes[idx+1:]...)
+ // Remove all edges connected to this node
newEdges := []Edge{}
for _, e := range h.graph.Edges {
if e.From != nodeID && e.To != nodeID {
@@ -242,17 +311,67 @@ func (h *Hub) handleOperation(msg WebSocketMessage) *WebSocketMessage {
}
}
h.graph.Edges = newEdges
- log.Printf("Removed node %d", nodeID)
+ log.Printf("Removed node %d and its edges", nodeID)
// Send full sync after removal for simplicity
- return &WebSocketMessage{Type: MessageTypeFullSync, Graph: &h.graph}
+ response = &WebSocketMessage{Type: MessageTypeFullSync, Graph: &h.graph}
+
+ case MessageTypeRemoveEdge:
+ if msg.Edge == nil {
+ return nil
+ }
+ // Find and remove the edge
+ found := false
+ newEdges := []Edge{}
+ for _, e := range h.graph.Edges {
+ if (e.From == msg.Edge.From && e.To == msg.Edge.To) ||
+ (e.From == msg.Edge.To && e.To == msg.Edge.From) {
+ found = true
+ continue
+ }
+ newEdges = append(newEdges, e)
+ }
+ if !found {
+ return nil
+ }
+ h.graph.Edges = newEdges
+ log.Printf("Removed edge %d -> %d", msg.Edge.From, msg.Edge.To)
+ response = &WebSocketMessage{Type: MessageTypeRemoveEdge, Edge: msg.Edge}
+
+ case MessageTypeMoveNode:
+ if msg.Node == nil {
+ return nil
+ }
+ if !isValidCoord(msg.Node.X) || !isValidCoord(msg.Node.Y) || !isValidCoord(msg.Node.Z) {
+ return nil
+ }
+ // Find and update the node by ID
+ idx := h.graph.findNodeByID(msg.Node.ID)
+ if idx < 0 {
+ return nil
+ }
+ h.graph.Nodes[idx].X = msg.Node.X
+ h.graph.Nodes[idx].Y = msg.Node.Y
+ h.graph.Nodes[idx].Z = msg.Node.Z
+ movedNode := h.graph.Nodes[idx]
+ response = &WebSocketMessage{Type: MessageTypeMoveNode, Node: &movedNode}
case "reset":
- h.graph = Graph{Nodes: []Node{}, Edges: []Edge{}}
+ h.graph = createInitialGraph()
log.Printf("Graph reset")
- return &WebSocketMessage{Type: MessageTypeFullSync, Graph: &h.graph}
+ response = &WebSocketMessage{Type: MessageTypeFullSync, Graph: &h.graph}
}
- return nil
+ if response == nil {
+ return nil
+ }
+
+ // Marshal while still holding the lock to avoid TOCTOU
+ data, err := json.Marshal(response)
+ if err != nil {
+ log.Printf("Error marshaling response: %v", err)
+ return nil
+ }
+ return data
}
// WebSocket handler
@@ -263,72 +382,76 @@ func serveWebSocket(w http.ResponseWriter, r *http.Request) {
log.Printf("WebSocket upgrade error: %v", err)
return
}
-
+
// Create a new connection
conn := &Connection{
ws: ws,
send: make(chan []byte, 256),
}
-
+
// Register the connection with the hub
hub.register <- conn
-
+
// Start the connection handlers
go conn.writer()
conn.reader()
}
-// Writer goroutine for connection
+// Writer goroutine for connection — sends messages and periodic pings
func (c *Connection) writer() {
- // Ensure clean close of connection
+ ticker := time.NewTicker(pingPeriod)
defer func() {
+ ticker.Stop()
c.ws.Close()
}()
-
+
for {
select {
case message, ok := <-c.send:
+ c.ws.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
c.ws.WriteMessage(websocket.CloseMessage, []byte{})
return
}
-
- // Write the message
if err := c.ws.WriteMessage(websocket.TextMessage, message); err != nil {
return
}
+ case <-ticker.C:
+ c.ws.SetWriteDeadline(time.Now().Add(writeWait))
+ if err := c.ws.WriteMessage(websocket.PingMessage, nil); err != nil {
+ return
+ }
}
}
}
-// Reader goroutine for connection
+// Reader goroutine for connection — reads messages and forwards to hub
func (c *Connection) reader() {
- // Ensure clean close of connection and unregister
defer func() {
hub.unregister <- c
c.ws.Close()
}()
-
- // Set read deadline
- c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
+
+ // Set message size limit
+ c.ws.SetReadLimit(maxMessageSize)
+ c.ws.SetReadDeadline(time.Now().Add(pongWait))
c.ws.SetPongHandler(func(string) error {
- c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
+ c.ws.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
-
- // Read messages from the connection
+
for {
_, message, err := c.ws.ReadMessage()
if err != nil {
+ if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
+ log.Printf("WebSocket error: %v", err)
+ }
break
}
-
- // Broadcast the message to all connections
hub.broadcast <- message
}
}
-
func main() {
// Parse command line flags
config := Config{}
@@ -345,15 +468,29 @@ func main() {
// Start the hub
go hub.run()
- // Create a file server for static files (serve from parent directory)
- fs := http.FileServer(http.Dir(config.StaticDir))
+ // Create a file server for static files (scoped to prevent directory traversal)
+ staticFS := http.Dir(config.StaticDir)
+ fs := http.FileServer(staticFS)
- // Set up routes
+ // Set up routes — only serve allowed static directories
+ allowedPrefixes := []string{"/css/", "/js/", "/doc/"}
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path == "/" {
+ if r.URL.Path == "/" || r.URL.Path == "/index.html" {
http.ServeFile(w, r, config.StaticDir+"/index.html")
return
}
+ // Only serve files from allowed directories
+ allowed := false
+ for _, prefix := range allowedPrefixes {
+ if len(r.URL.Path) >= len(prefix) && r.URL.Path[:len(prefix)] == prefix {
+ allowed = true
+ break
+ }
+ }
+ if !allowed {
+ http.NotFound(w, r)
+ return
+ }
fs.ServeHTTP(w, r)
})