import { WebSocketServer, WebSocket } from "ws";
import http from "node:http";
import { ConnectionManager } from "./connections.js";
import { generateChallenge, verifyAuthResponse } from "./auth.js";
import { verify, buildSignablePayload } from "../crypto/signing.js";
import { PairingStore } from "../pairing/pairing.js";
import {
  type Envelope,
  type MessageType,
  MESSAGE_TYPES,
  PING_INTERVAL_MS,
  PONG_TIMEOUT_MS,
} from "../protocol/spec.js";

export interface RelayServerConfig {
  port: number;
  /** Interface to bind; defaults to "0.0.0.0". */
  host?: string;
}

/** Maximum accepted HTTP request body size, in bytes. */
const MAX_BODY_BYTES = 64 * 1024;

/** How long a WebSocket may stay unauthenticated before it is closed (code 4000). */
const AUTH_TIMEOUT_MS = 10_000;

interface PendingAuth {
  challenge: Buffer;
  createdAt: number;
  /** Timer that closes the socket if the challenge is not answered in time. */
  timeout: ReturnType<typeof setTimeout>;
}

/**
 * CookieBridge Relay Server.
 *
 * HTTP endpoints:
 *   POST /pair        — initiate a pairing session
 *   POST /pair/accept — accept a pairing session
 *   GET  /health      — health check
 *
 * WebSocket:
 *   /ws — authenticated device connection for message relay.
 *
 * WebSocket handshake: on connect the server sends
 * `{ type: "auth_challenge", challenge: <hex> }`. The client must answer with
 * `{ type: "auth_response", deviceId: <hex public key>, sig: <hex> }` within
 * AUTH_TIMEOUT_MS. On success the server replies `{ type: "auth_ok", deviceId }`
 * and starts sending periodic PINGs. Only authenticated sockets may send
 * PING/PONG or relay messages (COOKIE_SYNC, COOKIE_DELETE, ACK).
 */
export class RelayServer {
  private readonly httpServer: http.Server;
  private readonly wss: WebSocketServer;
  private readonly connections: ConnectionManager;
  private readonly pairingStore: PairingStore;
  /** Sockets that have been sent a challenge but have not yet authenticated. */
  private readonly pendingAuths = new Map<WebSocket, PendingAuth>();
  /** Authenticated sockets mapped to their device ID (the hex-encoded public key they proved). */
  private readonly authenticatedDevices = new Map<WebSocket, string>();
  private readonly pingIntervals = new Map<WebSocket, ReturnType<typeof setInterval>>();

  constructor(private config: RelayServerConfig) {
    this.connections = new ConnectionManager();
    this.pairingStore = new PairingStore();
    this.httpServer = http.createServer(this.handleHttp.bind(this));
    this.wss = new WebSocketServer({ server: this.httpServer });
    this.wss.on("connection", this.handleConnection.bind(this));
  }

  /**
   * Start listening on the configured host/port.
   *
   * @returns A promise that resolves once the server is bound. It rejects if
   *   binding fails (e.g. EADDRINUSE).
   */
  start(): Promise<void> {
    return new Promise((resolve, reject) => {
      // The WebSocketServer re-emits errors from the underlying HTTP server.
      // Without a listener here, a listen failure would be an unhandled 'error'
      // event rather than a rejection.
      const onError = (err: Error): void => reject(err);
      this.wss.once("error", onError);
      this.httpServer.listen(this.config.port, this.config.host ?? "0.0.0.0", () => {
        this.wss.off("error", onError);
        resolve();
      });
    });
  }

  /** Stop accepting connections, drop every client and release all timers. */
  stop(): Promise<void> {
    return new Promise((resolve) => {
      for (const interval of this.pingIntervals.values()) {
        clearInterval(interval);
      }
      this.pingIntervals.clear();
      for (const pending of this.pendingAuths.values()) {
        clearTimeout(pending.timeout);
      }
      this.pendingAuths.clear();
      // Since ws v8, WebSocketServer.close() no longer closes existing clients,
      // and its callback waits for them. Terminate them so stop() cannot hang.
      for (const client of this.wss.clients) {
        client.terminate();
      }
      this.wss.close(() => {
        this.httpServer.close(() => resolve());
      });
    });
  }

  /** The bound port (useful when configured with port 0), else the configured port. */
  get port(): number {
    const addr = this.httpServer.address();
    if (addr && typeof addr === "object") return addr.port;
    return this.config.port;
  }

  // --- HTTP ---

  private handleHttp(req: http.IncomingMessage, res: http.ServerResponse): void {
    // Route on the path only, so a query string does not turn a known route into a 404.
    const path = (req.url ?? "").split("?", 1)[0] ?? "";
    switch (`${req.method ?? ""} ${path}`) {
      case "GET /health":
        this.sendJson(res, 200, { status: "ok", connections: this.connections.connectedCount });
        return;
      case "POST /pair":
        this.handlePairCreate(req, res);
        return;
      case "POST /pair/accept":
        this.handlePairAccept(req, res);
        return;
      default:
        res.writeHead(404);
        res.end("Not found");
    }
  }

  /**
   * POST /pair. The body is `{ deviceId, x25519PubKey }`.
   *
   * Responds with 201 and `{ pairingCode, expiresAt }` from the pairing store.
   * Responds with 400 on missing fields or unparsable input.
   */
  private handlePairCreate(req: http.IncomingMessage, res: http.ServerResponse): void {
    this.readBody(req, res, (body) => {
      try {
        const { deviceId, x25519PubKey } = JSON.parse(body) as Record<string, unknown>;
        if (!deviceId || !x25519PubKey) {
          this.sendJson(res, 400, { error: "Missing deviceId or x25519PubKey" });
          return;
        }
        const session = this.pairingStore.create(deviceId, x25519PubKey);
        this.sendJson(res, 201, { pairingCode: session.pairingCode, expiresAt: session.expiresAt });
      } catch {
        this.sendJson(res, 400, { error: "Invalid JSON" });
      }
    });
  }

  /**
   * POST /pair/accept. The body is `{ deviceId, x25519PubKey, pairingCode }`.
   *
   * Consumes the pairing session. On success it responds with 200 and both
   * peers' `{ deviceId, x25519PubKey }`. Responds with 404 for an unknown or
   * expired code, and 400 on missing fields or unparsable input.
   */
  private handlePairAccept(req: http.IncomingMessage, res: http.ServerResponse): void {
    this.readBody(req, res, (body) => {
      try {
        const { deviceId, x25519PubKey, pairingCode } = JSON.parse(body) as Record<string, unknown>;
        if (!deviceId || !x25519PubKey || !pairingCode) {
          this.sendJson(res, 400, { error: "Missing required fields" });
          return;
        }
        const session = this.pairingStore.consume(pairingCode);
        if (!session) {
          this.sendJson(res, 404, { error: "Invalid or expired pairing code" });
          return;
        }
        // Return both peers' info so each side can derive the shared key.
        this.sendJson(res, 200, {
          initiator: { deviceId: session.deviceId, x25519PubKey: session.x25519PubKey },
          acceptor: { deviceId, x25519PubKey },
        });
      } catch {
        this.sendJson(res, 400, { error: "Invalid JSON" });
      }
    });
  }

  /** Write a JSON response with the given status code. */
  private sendJson(res: http.ServerResponse, status: number, payload: unknown): void {
    res.writeHead(status, { "Content-Type": "application/json" });
    res.end(JSON.stringify(payload));
  }

  /**
   * Buffer the request body and pass it to `cb` decoded as UTF-8.
   *
   * Bodies larger than MAX_BODY_BYTES get a 413 response, after which the
   * request is destroyed and `cb` is never called. Chunks are concatenated
   * before decoding so a multi-byte character split across chunks is not corrupted.
   */
  private readBody(
    req: http.IncomingMessage,
    res: http.ServerResponse,
    cb: (body: string) => void,
  ): void {
    const chunks: Buffer[] = [];
    let received = 0;
    let aborted = false;

    req.on("data", (chunk: Buffer) => {
      if (aborted) return;
      received += chunk.length;
      if (received > MAX_BODY_BYTES) {
        aborted = true;
        res.writeHead(413, { "Content-Type": "application/json", Connection: "close" });
        // Destroy only after the response is flushed, so the client actually receives the 413.
        res.end(JSON.stringify({ error: "Request body too large" }), () => req.destroy());
        return;
      }
      chunks.push(chunk);
    });
    req.on("end", () => {
      if (!aborted) cb(Buffer.concat(chunks).toString("utf8"));
    });
    // The client went away mid-upload. There is no one to respond to, so just stop.
    req.on("error", () => {
      aborted = true;
    });
  }

  // --- WebSocket ---

  /** Issue an auth challenge to a new socket and arm the auth timeout. */
  private handleConnection(ws: WebSocket): void {
    const challenge = generateChallenge();
    const timeout = setTimeout(() => {
      if (this.pendingAuths.has(ws)) {
        ws.close(4000, "Auth timeout");
        this.pendingAuths.delete(ws);
      }
    }, AUTH_TIMEOUT_MS);
    this.pendingAuths.set(ws, { challenge, createdAt: Date.now(), timeout });
    this.sendWs(ws, { type: "auth_challenge", challenge: challenge.toString("hex") });

    ws.on("message", (data: Buffer) => {
      this.handleMessage(ws, data);
    });
    ws.on("close", () => {
      this.handleDisconnect(ws);
    });
    ws.on("error", () => {
      this.handleDisconnect(ws);
    });
  }

  /** Serialize and send a control/error message over a socket. */
  private sendWs(ws: WebSocket, message: Record<string, unknown>): void {
    ws.send(JSON.stringify(message));
  }

  /** Parse one incoming frame and dispatch it by `type`. */
  private handleMessage(ws: WebSocket, data: Buffer): void {
    let parsed: unknown;
    try {
      parsed = JSON.parse(data.toString());
    } catch {
      this.sendWs(ws, { type: "error", error: "Invalid JSON" });
      return;
    }
    // Valid JSON such as `null` or `42` would otherwise throw on `.type` below
    // and crash the process from inside the ws event listener.
    if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
      this.sendWs(ws, { type: "error", error: "Expected a JSON object" });
      return;
    }
    const msg = parsed as Record<string, unknown>;

    if (msg.type === "auth_response") {
      this.handleAuthResponse(ws, msg);
      return;
    }

    // All other messages require authentication.
    const deviceId = this.authenticatedDevices.get(ws);
    if (!deviceId) {
      this.sendWs(ws, { type: "error", error: "Not authenticated" });
      return;
    }

    switch (msg.type) {
      case MESSAGE_TYPES.PING:
        this.sendWs(ws, { type: MESSAGE_TYPES.PONG });
        return;
      case MESSAGE_TYPES.PONG:
        // Reply to the server's own keep-alive PING. This is not an error, and nothing needs doing.
        return;
      case MESSAGE_TYPES.COOKIE_SYNC:
      case MESSAGE_TYPES.COOKIE_DELETE:
      case MESSAGE_TYPES.ACK:
        this.handleRelayMessage(ws, deviceId, msg as unknown as Envelope);
        return;
      default:
        this.sendWs(ws, { type: "error", error: "Unknown message type" });
    }
  }

  /**
   * Check an `auth_response` against the socket's pending challenge.
   *
   * On success the socket is registered under `deviceId` and PINGs begin.
   * Malformed responses are closed with 4002 and failed verification with 4003.
   */
  private handleAuthResponse(ws: WebSocket, msg: Record<string, unknown>): void {
    const pending = this.pendingAuths.get(ws);
    if (!pending) {
      this.sendWs(ws, { type: "error", error: "No pending auth challenge" });
      return;
    }

    const { deviceId, sig } = msg;
    // Require non-empty strings: Buffer.from() throws a TypeError for values such as numbers.
    if (typeof deviceId !== "string" || typeof sig !== "string" || !deviceId || !sig) {
      ws.close(4002, "Invalid auth response");
      return;
    }

    let authenticated: boolean;
    try {
      authenticated = verifyAuthResponse(
        pending.challenge,
        Buffer.from(sig, "hex"),
        Buffer.from(deviceId, "hex"),
      );
    } catch {
      // Treat any verification error (e.g. malformed key bytes) as a failed auth.
      authenticated = false;
    }
    if (!authenticated) {
      ws.close(4003, "Auth failed");
      this.pendingAuths.delete(ws);
      return;
    }

    // Authenticated.
    clearTimeout(pending.timeout);
    this.pendingAuths.delete(ws);
    this.authenticatedDevices.set(ws, deviceId);
    this.connections.register(deviceId, ws);
    this.sendWs(ws, { type: "auth_ok", deviceId });

    const interval = setInterval(() => {
      if (ws.readyState === WebSocket.OPEN) {
        this.sendWs(ws, { type: MESSAGE_TYPES.PING });
      }
    }, PING_INTERVAL_MS);
    this.pingIntervals.set(ws, interval);
  }

  /**
   * Verify and forward a signed envelope from an authenticated device.
   *
   * The sender always gets an ACK `{ ref: nonce, delivered }`, where `delivered`
   * is whatever ConnectionManager.send returned for the recipient.
   */
  private handleRelayMessage(ws: WebSocket, fromDeviceId: string, envelope: Envelope): void {
    // The 'from' field must match the device this socket authenticated as.
    if (envelope.from !== fromDeviceId) {
      this.sendWs(ws, { type: "error", error: "Sender mismatch" });
      return;
    }
    // The envelope is unvalidated client JSON. Check the fields we decode and
    // route on, since Buffer.from() throws on non-string input.
    if (typeof envelope.sig !== "string" || typeof envelope.to !== "string") {
      this.sendWs(ws, { type: "error", error: "Malformed envelope" });
      return;
    }

    const signable = buildSignablePayload({
      type: envelope.type,
      from: envelope.from,
      to: envelope.to,
      nonce: envelope.nonce,
      payload: envelope.payload,
      timestamp: envelope.timestamp,
    });
    let signatureOk: boolean;
    try {
      signatureOk = verify(signable, Buffer.from(envelope.sig, "hex"), Buffer.from(fromDeviceId, "hex"));
    } catch {
      // Treat any verification error (e.g. malformed signature bytes) as an invalid signature.
      signatureOk = false;
    }
    if (!signatureOk) {
      this.sendWs(ws, { type: "error", error: "Invalid signature" });
      return;
    }

    const delivered = this.connections.send(envelope.to, envelope);
    this.sendWs(ws, { type: MESSAGE_TYPES.ACK, ref: envelope.nonce, delivered });
  }

  /** Release all state held for a socket. Safe to call more than once (both 'error' and 'close' call it). */
  private handleDisconnect(ws: WebSocket): void {
    const deviceId = this.authenticatedDevices.get(ws);
    if (deviceId !== undefined) {
      this.connections.remove(deviceId, ws);
      this.authenticatedDevices.delete(ws);
    }

    const pending = this.pendingAuths.get(ws);
    if (pending) {
      clearTimeout(pending.timeout);
      this.pendingAuths.delete(ws);
    }

    const interval = this.pingIntervals.get(ws);
    if (interval !== undefined) {
      clearInterval(interval);
      this.pingIntervals.delete(ws);
    }
  }
}