208 lines
6.7 KiB
Python
208 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
import sysconfig
|
|
import tempfile
|
|
import time
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, File, UploadFile
|
|
from fastapi.responses import JSONResponse
|
|
|
|
def env_flag(name: str, default: str = "1") -> bool:
    """Interpret environment variable *name* as an on/off switch.

    The value (or *default* when the variable is unset) is stripped and
    lower-cased; ``0``, ``false``, ``no`` and ``off`` disable the switch and
    anything else, including an empty string, enables it.
    """
    return os.getenv(name, default).strip().lower() not in {"0", "false", "no", "off"}


# Whisper model name (or path) handed to faster_whisper.WhisperModel.
MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
# Beam width used for decoding.
BEAM_SIZE = int(os.getenv("WHISPER_BEAM_SIZE", "5"))
# Voice-activity-detection filter; enabled unless explicitly switched off.
VAD_FILTER = env_flag("WHISPER_VAD_FILTER")
# Model download/cache directory; defaults to "models-cache" next to this file.
DOWNLOAD_ROOT = Path(os.getenv("WHISPER_DOWNLOAD_ROOT", str(Path(__file__).resolve().parent / "models-cache")))
|
|
|
|
# FastAPI application object; every route in this module is registered on it.
app = FastAPI(version="1.0.0", title="storyforge-windows-asr")

# Handles returned by os.add_dll_directory(), collected by
# configure_windows_cuda_runtime() so they are retained for the process lifetime.
_dll_handles: list[object] = list()
|
|
|
|
|
|
def describe_language_mode() -> str:
    """Return the configured Whisper language code, or ``"auto"`` for detection.

    Reads ``WHISPER_LANGUAGE``. An unset or blank value, or ``auto`` /
    ``detect`` in any case, means automatic language detection. Any other
    value is stripped and lower-cased, because faster-whisper validates the
    ``language`` argument against lower-case codes such as ``"en"``.
    """
    value = os.getenv("WHISPER_LANGUAGE", "").strip().lower()
    if value in {"", "auto", "detect"}:
        return "auto"
    return value
|
|
|
|
|
|
def resolve_language() -> str | None:
    """Return the ``language`` argument for faster-whisper.

    ``None`` is returned when the configured mode is ``"auto"`` so the model
    detects the language itself; otherwise the configured code is returned.
    """
    mode = describe_language_mode()
    if mode == "auto":
        return None
    return mode
|
|
|
|
|
|
def describe_device_mode() -> str:
    """Return the requested inference device from ``WHISPER_DEVICE``.

    The value is stripped and lower-cased; an unset or blank variable yields
    ``"auto"``.
    """
    requested = os.environ.get("WHISPER_DEVICE", "").strip().lower()
    if not requested:
        return "auto"
    return requested
|
|
|
|
|
|
def describe_compute_mode() -> str:
    """Return the requested CTranslate2 compute type from ``WHISPER_COMPUTE_TYPE``.

    The value is stripped and lower-cased, matching describe_device_mode(), so
    that e.g. ``"AUTO"`` is recognised as automatic selection and ``"INT8"``
    becomes the lower-case name CTranslate2 expects. Unset or blank yields
    ``"auto"``.
    """
    value = os.getenv("WHISPER_COMPUTE_TYPE", "").strip().lower()
    return value or "auto"
|
|
|
|
|
|
def build_runtime_profiles() -> list[tuple[str, str]]:
    """Return the ordered ``(device, compute_type)`` pairs get_model() should try.

    - A fallback profile stored on ``app.state.runtime_fallback_profile``
      (set by activate_cpu_fallback()) takes precedence over configuration.
    - An explicit device yields exactly one profile; its compute type defaults
      to ``"int8"`` when the compute mode is ``"auto"``.
    - With device ``"auto"``, CUDA is tried first and CPU second, using either
      the explicit compute type for both or ``int8_float16`` / ``int8``.
    """
    pinned = getattr(app.state, "runtime_fallback_profile", None)
    if pinned:
        return [pinned]

    device = describe_device_mode()
    compute = describe_compute_mode()
    explicit_compute = compute != "auto"

    if device == "auto":
        if explicit_compute:
            return [("cuda", compute), ("cpu", compute)]
        return [("cuda", "int8_float16"), ("cpu", "int8")]

    chosen_compute = compute if explicit_compute else "int8"
    return [(device, chosen_compute)]
|
|
|
|
|
|
def should_retry_on_cpu(exc: Exception) -> bool:
    """Decide whether *exc* looks like a CUDA runtime failure worth retrying on CPU.

    Only applies when the device mode is ``"auto"``; an explicitly configured
    device is never overridden. The decision is a case-insensitive substring
    match of the exception message against "cublas", "cudnn" and "cuda".
    """
    if describe_device_mode() == "auto":
        lowered = str(exc).lower()
        return "cublas" in lowered or "cudnn" in lowered or "cuda" in lowered
    return False
|
|
|
|
|
|
def activate_cpu_fallback() -> None:
    """Pin the service to the CPU/int8 profile and drop the cached model.

    Records the profile on ``app.state`` (read by build_runtime_profiles() and
    reported by /health), then clears get_model()'s cache so the next call
    reloads the model on CPU.
    """
    device, compute_type = "cpu", "int8"
    state = app.state
    state.runtime_fallback_profile = (device, compute_type)
    state.runtime_device = device
    state.runtime_compute_type = compute_type
    get_model.cache_clear()
|
|
|
|
|
|
def find_windows_cuda_runtime_dirs(site_packages_root: Path | None = None) -> list[Path]:
|
|
root = site_packages_root or Path(sysconfig.get_paths()["purelib"])
|
|
dirs = []
|
|
for rel in (
|
|
"nvidia/cublas/bin",
|
|
"nvidia/cuda_runtime/bin",
|
|
"nvidia/cuda_nvrtc/bin",
|
|
"nvidia/cudnn/bin",
|
|
):
|
|
path = root / rel
|
|
if path.exists():
|
|
dirs.append(path)
|
|
return dirs
|
|
|
|
|
|
def configure_windows_cuda_runtime() -> None:
    """Expose pip-installed NVIDIA CUDA DLL directories to this process (Windows only).

    Does nothing off Windows. On Windows it runs once: the discovered
    directories are cached as strings on ``app.state.windows_cuda_runtime_dirs``
    and later calls return immediately. Each directory is registered with
    ``os.add_dll_directory`` (when available) and prepended to ``PATH`` if it
    is not already there.
    """
    if sys.platform != "win32":
        return
    if getattr(app.state, "windows_cuda_runtime_dirs", None) is not None:
        return

    runtime_dirs = [str(path) for path in find_windows_cuda_runtime_dirs()]
    app.state.windows_cuda_runtime_dirs = runtime_dirs
    if not runtime_dirs:
        return

    # Since Python 3.8, PATH is no longer searched for the DLL dependencies of
    # extension modules, so every directory is registered explicitly, even one
    # that is already on PATH. The handles are kept; calling close() on one
    # would unregister its directory.
    if hasattr(os, "add_dll_directory"):
        _dll_handles.extend(os.add_dll_directory(runtime_dir) for runtime_dir in runtime_dirs)

    # PATH is still consulted by plain LoadLibrary calls and inherited by child
    # processes. Prepend the missing directories in discovery order; an unset
    # or empty PATH must not gain a trailing empty entry.
    current_path = os.environ.get("PATH", "")
    existing = current_path.split(os.pathsep) if current_path else []
    missing = [runtime_dir for runtime_dir in runtime_dirs if runtime_dir not in existing]
    os.environ["PATH"] = os.pathsep.join(missing + existing)
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_model():
    """Load and cache the faster-whisper model using the first profile that works.

    Profiles come from build_runtime_profiles() and are tried in order. The
    profile that succeeds is recorded on ``app.state.runtime_device`` and
    ``app.state.runtime_compute_type``. The cache holds a single model;
    activate_cpu_fallback() clears it to force a reload.

    Raises:
        Exception: the error from the last profile if every profile fails.
        RuntimeError: if build_runtime_profiles() yields no profiles at all.
    """
    configure_windows_cuda_runtime()
    # Imported only after the Windows DLL directories have been configured above.
    from faster_whisper import WhisperModel

    DOWNLOAD_ROOT.mkdir(parents=True, exist_ok=True)
    last_error: Exception | None = None
    for device, compute_type in build_runtime_profiles():
        try:
            model = WhisperModel(
                MODEL_NAME,
                device=device,
                compute_type=compute_type,
                download_root=str(DOWNLOAD_ROOT),
            )
        except Exception as exc:  # pragma: no cover - exercised on real hosts
            # Remember the failure and try the next (e.g. CPU) profile.
            last_error = exc
            continue
        app.state.runtime_device = device
        app.state.runtime_compute_type = compute_type
        return model

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would turn this into `raise None`.
    if last_error is None:
        raise RuntimeError("no runtime profiles available to load the Whisper model")
    raise last_error
|
|
|
|
|
|
@app.get("/health")
def health() -> dict[str, object]:
    """Report configuration, the active runtime profile and model-load state.

    Calls configure_windows_cuda_runtime() first, so on Windows the discovered
    CUDA DLL directories are included. Active device/compute type are empty
    strings until a model has been loaded.
    """
    configure_windows_cuda_runtime()
    state = app.state
    report: dict[str, object] = {
        "status": "ok",
        "service": "storyforge-windows-asr",
        "model_name": MODEL_NAME,
    }
    report["language"] = describe_language_mode()
    report["device"] = describe_device_mode()
    report["compute_type"] = describe_compute_mode()
    report["active_device"] = getattr(state, "runtime_device", "")
    report["active_compute_type"] = getattr(state, "runtime_compute_type", "")
    report["download_root"] = str(DOWNLOAD_ROOT)
    report["model_loaded"] = get_model.cache_info().currsize > 0
    report["windows_cuda_runtime_dirs"] = getattr(state, "windows_cuda_runtime_dirs", [])
    return report
|
|
|
|
|
|
@app.get("/")
def root() -> dict[str, str]:
    """Landing endpoint: names the service and points at the API docs."""
    landing: dict[str, str] = dict(service="storyforge-windows-asr", docs="/docs")
    return landing
|
|
|
|
|
|
def _transcribe_text(model, temp_path: Path) -> tuple[str, object]:
    """Run ``model.transcribe`` on *temp_path* and return ``(text, info)``.

    faster-whisper returns segments as a lazy generator and decodes while it is
    iterated. The generator is therefore fully consumed here, so errors raised
    mid-decode (e.g. a missing cuBLAS DLL) surface inside the caller's try-block.
    """
    segments, info = model.transcribe(
        str(temp_path),
        language=resolve_language(),
        beam_size=max(1, BEAM_SIZE),
        vad_filter=VAD_FILTER,
    )
    text = "".join(segment.text for segment in segments).strip()
    return text, info


def transcribe_file(temp_path: Path, started: float) -> dict[str, object]:
    """Transcribe the audio file at *temp_path* and build the response payload.

    If transcription fails with an error that should_retry_on_cpu() accepts,
    the service switches to the CPU fallback profile and transcribes once more.
    Any other error, or a failure of the retry, propagates to the caller.

    Args:
        temp_path: Audio file to transcribe.
        started: ``time.perf_counter()`` value taken at request start; used for
            ``duration_ms``.

    Returns:
        Dict with ``text``, ``success`` (text is non-empty), ``duration_ms``,
        ``detected_language``, ``detected_language_probability`` and
        ``error_message`` (``"empty transcription"`` when no text was produced).
    """
    model = get_model()
    try:
        text, info = _transcribe_text(model, temp_path)
    except Exception as exc:
        if not should_retry_on_cpu(exc):
            raise
        activate_cpu_fallback()
        text, info = _transcribe_text(get_model(), temp_path)

    duration_ms = int((time.perf_counter() - started) * 1000)
    return {
        "text": text,
        "success": bool(text),
        "duration_ms": duration_ms,
        "detected_language": getattr(info, "language", None),
        "detected_language_probability": getattr(info, "language_probability", None),
        "error_message": None if text else "empty transcription",
    }
|
|
|
|
|
|
@app.post("/transcribe", response_model=None)
async def transcribe(wav: UploadFile = File(...)):
    """Transcribe one uploaded audio file.

    The upload is written to a named temporary file (created with
    ``delete=False`` and closed before use) because transcribe_file() passes a
    filesystem path to the model. The temporary file is always removed.

    Returns:
        The transcribe_file() payload on success. On any failure (upload read,
        disk write or transcription), an HTTP 500 JSONResponse with ``text``,
        ``success``, ``duration_ms`` and ``error_message``.
    """
    started = time.perf_counter()
    suffix = Path(wav.filename or "segment.wav").suffix or ".wav"
    # The suffix comes from the client-supplied filename. Accept only short
    # alphanumeric extensions, so characters such as ":" (which selects an NTFS
    # alternate data stream on Windows) never reach the temp-file name.
    if not (suffix[1:].isalnum() and len(suffix) <= 10):
        suffix = ".wav"
    with tempfile.NamedTemporaryFile(prefix="storyforge-asr-", suffix=suffix, delete=False) as handle:
        temp_path = Path(handle.name)

    try:
        # Read and write inside the try, so a failed upload or full disk still
        # produces the JSON error response and the temp file is cleaned up.
        temp_path.write_bytes(await wav.read())
        # NOTE(review): transcribe_file() does blocking CPU/GPU work on the
        # event-loop thread, so other requests (including /health) wait until it
        # finishes. Offloading it to a worker thread would first require making
        # get_model()/activate_cpu_fallback() safe for concurrent callers.
        return transcribe_file(temp_path, started)
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={
                "text": "",
                "success": False,
                "duration_ms": int((time.perf_counter() - started) * 1000),
                "error_message": str(exc),
            },
        )
    finally:
        temp_path.unlink(missing_ok=True)
|