diff --git a/modules/web.py b/modules/web.py index dffa568..45534ba 100644 --- a/modules/web.py +++ b/modules/web.py @@ -1,154 +1,268 @@ -import multiprocessing -import json, signal +# websocket_module.py +import multiprocessing, os +import json import threading, uuid, time -from functools import partial -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn +import asyncio +import websockets +from websockets import ServerConnection + +# keep these imports/names so this file can be used in the same place as your original from . import Track, PlayerModule, Path MAIN_PATH_DIR = Path("/home/user/mixes") -class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): - """Handle requests in a separate thread.""" +# ---------- WebSocket server process ---------- +# This runs in a separate process. It uses the manager.dict() (shared) for reads +# and uses imc_q to send control messages to the main process. +# +# The Module will place broadcast messages onto ws_q (a manager.Queue) which +# this server reads and forwards to all connected clients. -class APIHandler(BaseHTTPRequestHandler): - def __init__(self, data, imc_q, *args, **kwargs): - self.data = data - self.imc_q = imc_q - super().__init__(*args, **kwargs) +async def ws_handler(websocket: ServerConnection, shared_data: dict, imc_q: multiprocessing.Queue, ws_q: multiprocessing.Queue): + """ + Per-connection handler. Accepts JSON messages from clients and responds. + Also sends initial state on connect. + """ + # send initial state + try: + # shared_data stores JSON strings like before; try to send parsed objects + initial = { + "playlist": json.loads(shared_data.get("playlist", "[]")), + "track": json.loads(shared_data.get("track", "{}")), + "progress": json.loads(shared_data.get("progress", "{}")), + } + except Exception: + initial = {"playlist": [], "track": {}, "progress": {}} + await websocket.send(json.dumps({"event": "initial_state", "data": initial})) - def do_GET(self): - code = 200 + async for raw in websocket: + try: + msg = json.loads(raw) + except Exception: + await websocket.send(json.dumps({"error": "invalid json"})) + continue - if self.path == "/api/playlist": rdata = json.loads(self.data.get("playlist", "[]")) - elif self.path == "/api/track": rdata = json.loads(self.data.get("track", "{}")) - elif self.path == "/api/progress": rdata = json.loads(self.data.get("progress", "{}")) - elif self.path == "/api/put": - id = str(uuid.uuid4()) - self.imc_q.put({"name": "activemod", "data": {"action": "get_toplay"}, "key": id}) - start_time = time.monotonic() - response_json = None - while time.monotonic() - start_time < 2: - if id in self.data: - response_json = self.data.pop(id) - break - time.sleep(0.05) - if response_json: - try: - rdata = response_json - if "error" in repr(rdata): code = 500 - except TypeError: - rdata = {"error": "Invalid data format from module"} - code = 500 + # Simple control actions + action = msg.get("action") + if action == "skip": + imc_q.put({"name": "procman", "data": {"op": 2}}) + await websocket.send(json.dumps({"status": "ok", "action": "skip_requested"})) + elif action == "add_to_toplay": + songs = msg.get("songs") + if not isinstance(songs, list): + await websocket.send(json.dumps({"error": "songs must be a list"})) else: - rdata = {"error": "Request to active module timed out"} - code = 504 # Gateway Timeout - elif self.path == "/api/dirs": - rdata = {"base": str(MAIN_PATH_DIR), "files": [i.name for i in list(MAIN_PATH_DIR.iterdir())]} - elif self.path.startswith("/api/dir/"): - rdata = [i.name for i in (MAIN_PATH_DIR / self.path.removeprefix("/api/dir/").removesuffix("/")).iterdir() if i.is_file()] - else: rdata = {"error": "not found"} + imc_q.put({"name": "activemod", "data": {"action": "add_to_toplay", "songs": songs}}) + await websocket.send(json.dumps({"status": "ok", "message": f"{len(songs)} song(s) queued"})) + elif action == "get_toplay": + # replicate the previous behavior: send request to activemod and wait for keyed response + key = str(uuid.uuid4()) + imc_q.put({"name": "activemod", "data": {"action": "get_toplay"}, "key": key}) + # wait up to 2 seconds for shared_data[key] to appear + start = time.monotonic() + result = None + while time.monotonic() - start < 2: + if key in shared_data: + result = shared_data.pop(key) + break + await asyncio.sleep(0.05) + if result is None: + await websocket.send(json.dumps({"error": "timeout", "code": 504})) + else: + await websocket.send(json.dumps({"status": "ok", "response": result})) + elif action == "request_state": + # supports requesting specific parts if provided + what = msg.get("what") + try: + if what == "playlist": + payload = json.loads(shared_data.get("playlist", "[]")) + elif what == "track": + payload = json.loads(shared_data.get("track", "{}")) + elif what == "progress": + payload = json.loads(shared_data.get("progress", "{}")) + else: + payload = { + "playlist": json.loads(shared_data.get("playlist", "[]")), + "track": json.loads(shared_data.get("track", "{}")), + "progress": json.loads(shared_data.get("progress", "{}")), + } + except Exception: + payload = {} + await websocket.send(json.dumps({"event": "state", "data": payload})) + else: + await websocket.send(json.dumps({"error": "unknown action"})) - self.send_response(code) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps(rdata).encode('utf-8')) - def do_POST(self): - response = {"error": "not found"} - code = 404 +async def broadcast_worker(shared_data: dict, ws_q: multiprocessing.Queue, clients: set): + """ + Reads messages from ws_q (a blocking multiprocessing.Queue) using run_in_executor + and broadcasts them to all connected clients. + """ + loop = asyncio.get_event_loop() + while True: + # blocking get executed in default threadpool so we don't block the event loop + msg = await loop.run_in_executor(None, ws_q.get) + if msg is None: + # sentinel to shut down + break + # msg expected to be serializable (e.g. {"event": "playlist", "data": ...}) + payload = json.dumps(msg) + # send concurrently; ignore per-client errors (client may disconnect) + if clients: + coros = [] + for ws in list(clients): + coros.append(_safe_send(ws, payload, clients)) + await asyncio.gather(*coros) - if self.path == "/api/skip": - self.imc_q.put({"name": "procman", "data": {"op": 2}}) - response = {"status": "ok", "action": "skip requested"} - code = 200 - elif self.path == "/api/put": - try: - body = json.loads(self.rfile.read(int(self.headers['Content-Length']))) - - songs = body.get("songs") - if songs is None or not isinstance(songs, list): raise ValueError("Request body must be a JSON object with a 'songs' key containing a list of strings.") - self.imc_q.put({"name": "activemod", "data": {"action": "add_to_toplay", "songs": songs}}) - - response = {"status": "ok", "message": f"{len(songs)} song(s) were added to the high-priority queue."} - code = 200 - except json.JSONDecodeError: - response = {"error": "Invalid JSON in request body."} - code = 400 - except (ValueError, KeyError, TypeError) as e: - response = {"error": f"Invalid request format: {e}"} - code = 400 - except Exception as e: - response = {"error": f"An unexpected server error occurred: {e}"} - code = 500 +async def _safe_send(ws, payload: str, clients: set): + try: + await ws.send(payload) + except Exception: + # remove dead websocket + try: + clients.discard(ws) + except Exception: + pass - self.send_response(code) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps(response).encode('utf-8')) - def send_response(self, code, message=None): - self.send_response_only(code, message) - self.send_header('Server', self.version_string()) - self.send_header('Date', self.date_time_string()) -def web_server_process(data, imc_q): - def signal_handler(sig, frame): pass - signal.signal(signal.SIGINT, signal_handler) - ThreadingHTTPServer(("0.0.0.0", 3001), partial(APIHandler, data, imc_q)).serve_forever() +def websocket_server_process(shared_data: dict, imc_q: multiprocessing.Queue, ws_q: multiprocessing.Queue): + """ + Entrypoint for the separate process that runs the asyncio-based websocket server. + """ + # create the asyncio loop and run server + async def runner(): + clients = set() + + async def handler_wrapper(websocket: ServerConnection): + # register client + clients.add(websocket) + try: + await ws_handler(websocket, shared_data, imc_q, ws_q) + finally: + clients.discard(websocket) + + # start server + server = await websockets.serve(handler_wrapper, "0.0.0.0", 3001) + # background task: broadcast worker + broadcaster = asyncio.create_task(broadcast_worker(shared_data, ws_q, clients)) + # run forever until server closes + await server.wait_closed() + # ensure broadcaster stops + ws_q.put(None) + await broadcaster + + # On SIGINT/SIGTERM, stop gracefully + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(runner()) + except (KeyboardInterrupt, SystemExit): + pass + finally: + loop.close() + +# ---------- Module class (drop-in replacement) ---------- class Module(PlayerModule): def __init__(self): self.manager = multiprocessing.Manager() self.data = self.manager.dict() self.imc_q = self.manager.Queue() + # queue for sending broadcasts to the websocket process + self.ws_q = self.manager.Queue() + # initial state self.data["playlist"] = "[]" self.data["track"] = "{}" self.data["progress"] = "{}" + # ipc thread: listens for responses from other modules (same as before) self.ipc_thread_running = True self.ipc_thread = threading.Thread(target=self._ipc_worker, daemon=True) self.ipc_thread.start() - self.web_process = multiprocessing.Process(target=web_server_process, args=(self.data, self.imc_q)) - self.web_process.start() + # start websocket server process + self.ws_process = multiprocessing.Process( + target=websocket_server_process, args=(self.data, self.imc_q, self.ws_q), daemon=False + ) + self.ws_process.start() + if os.name == "posix": + try: + os.setpgid(self.ws_process.pid, self.ws_process.pid) + except Exception: + pass def _ipc_worker(self): + """ + Listens for messages placed in imc_q by websocket process or other modules, + forwards them to the main IPC layer and stores keyed responses into shared dict. + """ while self.ipc_thread_running: try: message: dict | None = self.imc_q.get() if message is None: break + # send to upper layer (existing player IPC) out = self._imc.send(self, message["name"], message["data"]) - if key := message.get("key", None): self.data[key] = out - except Exception: pass + # if message had a key, store the response for the requester + if key := message.get("key", None): + # store response into shared dict (accessible to ws process) + self.data[key] = out + except Exception: + # swallow errors to avoid killing the ipc thread + pass + # The following functions update the shared_data and also push a broadcast message onto ws_q def on_new_playlist(self, playlist: list[Track]) -> None: api_data = [] for track in playlist: - api_data.append({"path": str(track.path), "fade_out": track.fade_out, "fade_in": track.fade_in, "official": track.official, "args": track.args, "offset": track.offset}) + api_data.append({ + "path": str(track.path), + "fade_out": track.fade_out, + "fade_in": track.fade_in, + "official": track.official, + "args": track.args, + "offset": track.offset + }) self.data["playlist"] = json.dumps(api_data) + # broadcast + try: self.ws_q.put({"event": "playlist", "data": api_data}) + except Exception: pass def on_new_track(self, index: int, track: Track, next_track: Track | None) -> None: track_data = {"path": str(track.path), "fade_out": track.fade_out, "fade_in": track.fade_in, "official": track.official, "args": track.args, "offset": track.offset} - if next_track: - next_track_data = {"path": str(next_track.path), "fade_out": next_track.fade_out, "fade_in": next_track.fade_in, "official": next_track.official, "args": next_track.args, "offset": next_track.offset} + if next_track: next_track_data = {"path": str(next_track.path), "fade_out": next_track.fade_out, "fade_in": next_track.fade_in, "official": next_track.official, "args": next_track.args, "offset": next_track.offset} else: next_track_data = None - self.data["track"] = json.dumps({"index": index, "track": track_data, "next_track": next_track_data}) - + payload = {"index": index, "track": track_data, "next_track": next_track_data} + self.data["track"] = json.dumps(payload) + try: self.ws_q.put({"event": "new_track", "data": payload}) + except Exception: pass + def progress(self, index: int, track: Track, elapsed: float, total: float, real_total: float) -> None: track_data = {"path": str(track.path), "fade_out": track.fade_out, "fade_in": track.fade_in, "official": track.official, "args": track.args, "offset": track.offset} - self.data["progress"] = json.dumps({"index": index, "track": track_data, "elapsed": elapsed, "total": total, "real_total": real_total}) + payload = {"index": index, "track": track_data, "elapsed": elapsed, "total": total, "real_total": real_total} + self.data["progress"] = json.dumps(payload) + # For frequent progress updates you might want to rate-limit; this pushes every call + try: self.ws_q.put({"event": "progress", "data": payload}) + except Exception: pass def shutdown(self): + # stop ipc thread self.ipc_thread_running = False - self.imc_q.put(None) + try: self.imc_q.put(None) + except Exception: pass self.ipc_thread.join(timeout=2) - if self.web_process.is_alive(): - self.web_process.terminate() - self.web_process.join(timeout=2) + # shutdown websocket process by putting sentinel into ws_q and then terminating if needed + try: self.ws_q.put(None) + except Exception: pass - if self.web_process.is_alive(): self.web_process.kill() + if self.ws_process.is_alive(): + self.ws_process.terminate() + self.ws_process.join(timeout=2) -module = Module() \ No newline at end of file + if self.ws_process.is_alive(): + try: self.ws_process.kill() + except Exception: pass + +module = Module()