from __future__ import annotations import os import tempfile import time from functools import lru_cache from pathlib import Path from fastapi import FastAPI, File, UploadFile from fastapi.responses import JSONResponse MODEL_NAME = os.getenv("WHISPER_MODEL", "base") LANGUAGE = os.getenv("WHISPER_LANGUAGE", "zh") DEVICE = os.getenv("WHISPER_DEVICE", "cpu") COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8") BEAM_SIZE = int(os.getenv("WHISPER_BEAM_SIZE", "5")) VAD_FILTER = os.getenv("WHISPER_VAD_FILTER", "1").strip().lower() not in {"0", "false", "no"} DOWNLOAD_ROOT = Path(os.getenv("WHISPER_DOWNLOAD_ROOT", str(Path(__file__).resolve().parent / "models-cache"))) app = FastAPI(title="storyforge-windows-asr", version="1.0.0") @lru_cache(maxsize=1) def get_model(): from faster_whisper import WhisperModel DOWNLOAD_ROOT.mkdir(parents=True, exist_ok=True) return WhisperModel( MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE, download_root=str(DOWNLOAD_ROOT), ) @app.get("/health") def health() -> dict[str, object]: return { "status": "ok", "service": "storyforge-windows-asr", "model_name": MODEL_NAME, "language": LANGUAGE, "device": DEVICE, "compute_type": COMPUTE_TYPE, "download_root": str(DOWNLOAD_ROOT), "model_loaded": get_model.cache_info().currsize > 0, } @app.get("/") def root() -> dict[str, str]: return {"service": "storyforge-windows-asr", "docs": "/docs"} @app.post("/transcribe", response_model=None) async def transcribe(wav: UploadFile = File(...)): started = time.perf_counter() suffix = Path(wav.filename or "segment.wav").suffix or ".wav" with tempfile.NamedTemporaryFile(prefix="storyforge-asr-", suffix=suffix, delete=False) as handle: temp_path = Path(handle.name) handle.write(await wav.read()) try: model = get_model() segments, _info = model.transcribe( str(temp_path), language=LANGUAGE or None, beam_size=max(1, BEAM_SIZE), vad_filter=VAD_FILTER, ) text = "".join(segment.text for segment in segments).strip() duration_ms = int((time.perf_counter() - started) * 1000) return { "text": text, "success": bool(text), "duration_ms": duration_ms, "error_message": None if text else "empty transcription", } except Exception as exc: return JSONResponse( status_code=500, content={ "text": "", "success": False, "duration_ms": int((time.perf_counter() - started) * 1000), "error_message": str(exc), }, ) finally: temp_path.unlink(missing_ok=True)