fix: restore windows asr gpu runtime
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import sysconfig
|
||||
import tempfile
|
||||
import time
|
||||
from functools import lru_cache
|
||||
@@ -15,6 +17,7 @@ VAD_FILTER = os.getenv("WHISPER_VAD_FILTER", "1").strip().lower() not in {"0", "
|
||||
# Root directory read from the WHISPER_DOWNLOAD_ROOT environment variable;
# defaults to a "models-cache" folder next to this file.
DOWNLOAD_ROOT = Path(os.getenv("WHISPER_DOWNLOAD_ROOT", str(Path(__file__).resolve().parent / "models-cache")))

# FastAPI application object for this service.
app = FastAPI(title="storyforge-windows-asr", version="1.0.0")
# Objects returned by os.add_dll_directory() are appended here by
# configure_windows_cuda_runtime().
_dll_handles: list[object] = []
|
||||
|
||||
|
||||
def describe_language_mode() -> str:
|
||||
@@ -66,8 +69,44 @@ def activate_cpu_fallback() -> None:
|
||||
get_model.cache_clear()
|
||||
|
||||
|
||||
def find_windows_cuda_runtime_dirs(site_packages_root: Path | None = None) -> list[Path]:
    """Return the NVIDIA runtime ``bin`` directories present under site-packages.

    Args:
        site_packages_root: Directory to search. When omitted, the
            interpreter's ``purelib`` path reported by ``sysconfig`` is used.

    Returns:
        The directories that exist among ``nvidia/cublas/bin``,
        ``nvidia/cuda_runtime/bin``, ``nvidia/cuda_nvrtc/bin`` and
        ``nvidia/cudnn/bin`` (in that order). Empty when none are present.
    """
    root = site_packages_root or Path(sysconfig.get_paths()["purelib"])
    candidates = (
        "nvidia/cublas/bin",
        "nvidia/cuda_runtime/bin",
        "nvidia/cuda_nvrtc/bin",
        "nvidia/cudnn/bin",
    )
    # is_dir() rather than exists(): a stray regular file named "bin" must not
    # be reported as a DLL directory.
    return [path for path in (root / rel for rel in candidates) if path.is_dir()]
|
||||
|
||||
|
||||
def configure_windows_cuda_runtime() -> None:
    """Expose the pip-installed NVIDIA runtime directories to the DLL loader.

    Does nothing off Windows, or when ``app.state.windows_cuda_runtime_dirs``
    has already been set by an earlier call. Otherwise the directories found
    by ``find_windows_cuda_runtime_dirs()`` are recorded on ``app.state`` (as
    strings), prepended to ``PATH`` when missing from it, and passed to
    ``os.add_dll_directory`` when that function is available.
    """
    if sys.platform != "win32" or getattr(app.state, "windows_cuda_runtime_dirs", None) is not None:
        return

    cuda_dirs = [str(found) for found in find_windows_cuda_runtime_dirs()]
    # Recorded even when empty, so the lookup is not repeated on later calls.
    app.state.windows_cuda_runtime_dirs = cuda_dirs
    if not cuda_dirs:
        return

    search_path = os.environ.get("PATH", "").split(os.pathsep)
    can_register = hasattr(os, "add_dll_directory")
    for entry in cuda_dirs:
        if entry not in search_path:
            search_path.insert(0, entry)
        if can_register:
            # The returned object is stored in the module-level list.
            _dll_handles.append(os.add_dll_directory(entry))
    os.environ["PATH"] = os.pathsep.join(search_path)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_model():
|
||||
configure_windows_cuda_runtime()
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
DOWNLOAD_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
@@ -91,6 +130,7 @@ def get_model():
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict[str, object]:
|
||||
configure_windows_cuda_runtime()
|
||||
return {
|
||||
"status": "ok",
|
||||
"service": "storyforge-windows-asr",
|
||||
@@ -102,6 +142,7 @@ def health() -> dict[str, object]:
|
||||
"active_compute_type": getattr(app.state, "runtime_compute_type", ""),
|
||||
"download_root": str(DOWNLOAD_ROOT),
|
||||
"model_loaded": get_model.cache_info().currsize > 0,
|
||||
"windows_cuda_runtime_dirs": getattr(app.state, "windows_cuda_runtime_dirs", []),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -2,3 +2,6 @@ fastapi==0.116.1
|
||||
uvicorn[standard]==0.35.0
|
||||
python-multipart==0.0.20
|
||||
faster-whisper>=1.1,<2
|
||||
nvidia-cublas-cu12; platform_system == "Windows"
|
||||
nvidia-cuda-runtime-cu12; platform_system == "Windows"
|
||||
nvidia-cudnn-cu12; platform_system == "Windows"
|
||||
|
||||
Reference in New Issue
Block a user