Files
storyforge/deploy/storyforge-windows-asr-http/app.py
kris a048bd26b1
Some checks failed
StoryForge CI / Baseline checks (push) Has been cancelled
StoryForge CI / Backend tests (push) Has been cancelled
StoryForge CI / Web tests (push) Has been cancelled
feat: move asr to windows and disable local model
2026-04-06 10:20:39 +08:00

91 lines
2.8 KiB
Python

from __future__ import annotations

import logging
import os
import tempfile
import threading
import time
from functools import lru_cache
from pathlib import Path

from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
# Spellings (after strip + lower-casing) that switch a boolean env flag off.
_FALSE_ENV_VALUES = frozenset({"0", "false", "no", "off"})


def _env_flag(name: str, default: str) -> bool:
    """Return the boolean value of environment variable *name*.

    The raw value (or *default* when the variable is unset) is stripped and
    lower-cased. It is False only when it is one of ``_FALSE_ENV_VALUES``;
    every other value, including an empty string, counts as True.
    """
    return os.getenv(name, default).strip().lower() not in _FALSE_ENV_VALUES


# Model identifier handed to WhisperModel in get_model().
MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
# Language passed to model.transcribe(); an empty value is sent as None.
LANGUAGE = os.getenv("WHISPER_LANGUAGE", "zh")
# WhisperModel `device` and `compute_type` arguments.
DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
# Beam width for decoding; parsed eagerly, so a non-integer value fails at import time.
BEAM_SIZE = int(os.getenv("WHISPER_BEAM_SIZE", "5"))
# Value for model.transcribe(vad_filter=...); "0"/"false"/"no"/"off" disable it.
VAD_FILTER = _env_flag("WHISPER_VAD_FILTER", "1")
# Passed as WhisperModel(download_root=...); defaults to "models-cache" next to this file.
DOWNLOAD_ROOT = Path(os.getenv("WHISPER_DOWNLOAD_ROOT", str(Path(__file__).resolve().parent / "models-cache")))
# ASGI application object; the route handlers below register on it via decorators.
app = FastAPI(
    version="1.0.0",
    title="storyforge-windows-asr",
)
@lru_cache(maxsize=1)
def get_model():
    """Build the faster-whisper model on first call and return the cached instance.

    The faster_whisper import is deferred to this call, so importing the module
    (and serving /health) does not pull in the library. The download directory
    is created if missing before the model is constructed.
    """
    from faster_whisper import WhisperModel

    cache_dir = DOWNLOAD_ROOT
    cache_dir.mkdir(exist_ok=True, parents=True)
    options = {
        "device": DEVICE,
        "compute_type": COMPUTE_TYPE,
        "download_root": str(cache_dir),
    }
    return WhisperModel(MODEL_NAME, **options)
@app.get("/health")
def health() -> dict[str, object]:
    """Report the service's configuration and whether the model is loaded.

    ``model_loaded`` is read from get_model's cache statistics, so this endpoint
    never triggers a model load itself.
    """
    loaded = get_model.cache_info().currsize > 0
    return dict(
        status="ok",
        service="storyforge-windows-asr",
        model_name=MODEL_NAME,
        language=LANGUAGE,
        device=DEVICE,
        compute_type=COMPUTE_TYPE,
        download_root=str(DOWNLOAD_ROOT),
        model_loaded=loaded,
    )
@app.get("/")
def root() -> dict[str, str]:
    """Return the service name and the path of the API docs page."""
    return dict(service="storyforge-windows-asr", docs="/docs")
# Serialises calls to get_model(): lru_cache alone does not stop two worker
# threads that miss the cache at the same moment from each building a model.
_MODEL_LOAD_LOCK = threading.Lock()


def _elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since the perf_counter() reading *started*."""
    return int((time.perf_counter() - started) * 1000)


def _transcribe_bytes(audio: bytes, suffix: str) -> str:
    """Transcribe *audio* with the shared model and return the stripped text.

    The bytes are written to a temporary file (named with *suffix*) and its path
    is handed to the model. The file is deleted on every exit path, including a
    failure while writing it.

    Blocking: this may load the model and runs the full decode, so call it from
    a worker thread, not from the event loop.
    """
    handle = tempfile.NamedTemporaryFile(prefix="storyforge-asr-", suffix=suffix, delete=False)
    temp_path = Path(handle.name)
    try:
        # Close the handle before transcribing so the model opens a fully written file.
        with handle:
            handle.write(audio)
        with _MODEL_LOAD_LOCK:
            model = get_model()
        segments, _info = model.transcribe(
            str(temp_path),
            language=LANGUAGE or None,
            beam_size=max(1, BEAM_SIZE),
            vad_filter=VAD_FILTER,
        )
        # Consume the segments here as well: if the model yields them lazily, the
        # decoding happens during this iteration and must stay off the event loop.
        return "".join(segment.text for segment in segments).strip()
    finally:
        temp_path.unlink(missing_ok=True)


@app.post("/transcribe", response_model=None)
async def transcribe(wav: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Returns a dict with ``text``, ``success``, ``duration_ms`` and
    ``error_message``. An empty transcription is returned with
    ``success=False`` and ``error_message="empty transcription"``. Any
    exception (reading the upload, loading the model, decoding) becomes an
    HTTP 500 with the same keys and the exception text as ``error_message``.
    """
    started = time.perf_counter()
    # Keep the uploaded file's extension on the temp file; fall back to ".wav".
    suffix = Path(wav.filename or "segment.wav").suffix or ".wav"
    try:
        audio = await wav.read()
        # Model loading and decoding are blocking calls. Running them in the
        # thread pool keeps the event loop free to serve other requests.
        text = await run_in_threadpool(_transcribe_bytes, audio, suffix)
    except Exception as exc:
        # Service boundary: report the failure to the client, keep the traceback in the logs.
        logging.getLogger(__name__).exception("transcription failed for upload %r", wav.filename)
        return JSONResponse(
            status_code=500,
            content={
                "text": "",
                "success": False,
                "duration_ms": _elapsed_ms(started),
                # Some exceptions stringify to "", which would hide the failure cause.
                "error_message": str(exc) or type(exc).__name__,
            },
        )
    return {
        "text": text,
        "success": bool(text),
        "duration_ms": _elapsed_ms(started),
        "error_message": None if text else "empty transcription",
    }