fix: fall back Windows ASR to CPU when the GPU runtime is missing
This commit is contained in:
@@ -4,6 +4,11 @@
|
||||
|
||||
## 2026-04-06
|
||||
|
||||
### Windows ASR GPU 失败时自动回退 CPU
|
||||
|
||||
- Windows `ASR HTTP` 现在在 `auto` 模式下仍会优先尝试 `cuda + int8_float16`,但如果在真正推理阶段命中 `cublas/cudnn/cuda` 运行库缺失,会自动切回 `cpu + int8` 重试,不再把整次转写卡死在 GPU 路径。
|
||||
- 这让“默认优先用 GPU、但当前机器 CUDA 运行库不完整”的场景也能稳定返回结果,同时保留混合中英文自动识别。
|
||||
|
||||
### Windows ASR 默认改成 GPU 优先与自动语言识别
|
||||
|
||||
- Windows `ASR HTTP` 现在默认不再强锁 `zh + cpu + int8`,而是改成:
|
||||
|
||||
@@ -40,6 +40,9 @@ def describe_compute_mode() -> str:
|
||||
|
||||
|
||||
def build_runtime_profiles() -> list[tuple[str, str]]:
|
||||
fallback_profile = getattr(app.state, "runtime_fallback_profile", None)
|
||||
if fallback_profile:
|
||||
return [fallback_profile]
|
||||
device = describe_device_mode()
|
||||
compute = describe_compute_mode()
|
||||
if device != "auto":
|
||||
@@ -49,6 +52,20 @@ def build_runtime_profiles() -> list[tuple[str, str]]:
|
||||
return [("cuda", "int8_float16"), ("cpu", "int8")]
|
||||
|
||||
|
||||
def should_retry_on_cpu(exc: Exception) -> bool:
    """Decide whether a failed transcription should be retried on the CPU.

    Returns ``True`` only when ``describe_device_mode()`` reports ``"auto"``
    and the lower-cased text of ``exc`` mentions ``cublas``, ``cudnn`` or
    ``cuda``. In every other case the answer is ``False``.
    """
    if describe_device_mode() == "auto":
        lowered = str(exc).lower()
        # Any one of these substrings in the error text is enough.
        for marker in ("cublas", "cudnn", "cuda"):
            if marker in lowered:
                return True
    return False
|
||||
|
||||
|
||||
def activate_cpu_fallback() -> None:
    """Switch the recorded runtime to ``cpu`` / ``int8`` and reset the model cache.

    Stores the ``("cpu", "int8")`` profile on ``app.state`` as the fallback
    profile, mirrors it into the runtime device and compute-type fields, and
    then clears the ``get_model`` cache so the next ``get_model()`` call is
    not served from the previously cached entry.
    """
    cpu_profile = ("cpu", "int8")
    device_name, compute_name = cpu_profile

    state = app.state
    state.runtime_fallback_profile = cpu_profile
    state.runtime_device = device_name
    state.runtime_compute_type = compute_name

    # Clear the cache last, once all state fields describe the CPU profile.
    get_model.cache_clear()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_model():
|
||||
from faster_whisper import WhisperModel
|
||||
@@ -93,15 +110,19 @@ def root() -> dict[str, str]:
|
||||
return {"service": "storyforge-windows-asr", "docs": "/docs"}
|
||||
|
||||
|
||||
@app.post("/transcribe", response_model=None)
|
||||
async def transcribe(wav: UploadFile = File(...)):
|
||||
started = time.perf_counter()
|
||||
suffix = Path(wav.filename or "segment.wav").suffix or ".wav"
|
||||
with tempfile.NamedTemporaryFile(prefix="storyforge-asr-", suffix=suffix, delete=False) as handle:
|
||||
temp_path = Path(handle.name)
|
||||
handle.write(await wav.read())
|
||||
|
||||
def transcribe_file(temp_path: Path, started: float) -> dict[str, object]:
|
||||
model = get_model()
|
||||
try:
|
||||
segments, info = model.transcribe(
|
||||
str(temp_path),
|
||||
language=resolve_language(),
|
||||
beam_size=max(1, BEAM_SIZE),
|
||||
vad_filter=VAD_FILTER,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not should_retry_on_cpu(exc):
|
||||
raise
|
||||
activate_cpu_fallback()
|
||||
model = get_model()
|
||||
segments, info = model.transcribe(
|
||||
str(temp_path),
|
||||
@@ -119,6 +140,18 @@ async def transcribe(wav: UploadFile = File(...)):
|
||||
"detected_language_probability": getattr(info, "language_probability", None),
|
||||
"error_message": None if text else "empty transcription",
|
||||
}
|
||||
|
||||
|
||||
@app.post("/transcribe", response_model=None)
|
||||
async def transcribe(wav: UploadFile = File(...)):
|
||||
started = time.perf_counter()
|
||||
suffix = Path(wav.filename or "segment.wav").suffix or ".wav"
|
||||
with tempfile.NamedTemporaryFile(prefix="storyforge-asr-", suffix=suffix, delete=False) as handle:
|
||||
temp_path = Path(handle.name)
|
||||
handle.write(await wav.read())
|
||||
|
||||
try:
|
||||
return transcribe_file(temp_path, started)
|
||||
except Exception as exc:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -69,3 +70,45 @@ class WindowsAsrHttpTests(unittest.TestCase):
|
||||
self.assertEqual(module.resolve_language(), "zh")
|
||||
self.assertEqual(module.describe_language_mode(), "zh")
|
||||
self.assertEqual(module.build_runtime_profiles(), [("cpu", "int8")])
|
||||
|
||||
def test_auto_runtime_falls_back_to_cpu_when_cuda_runtime_is_missing(self) -> None:
    """Check that a cuBLAS load failure makes ``transcribe_file`` retry and succeed.

    The first fake model raises a ``RuntimeError`` naming a missing cuBLAS
    DLL; the second returns one segment. The test then asserts on the
    returned payload, the fallback profile stored on ``app.state``, and
    that the model cache was cleared.
    """
    # Remove the WHISPER_* overrides before the module is loaded.
    for variable in ("WHISPER_LANGUAGE", "WHISPER_DEVICE", "WHISPER_COMPUTE_TYPE"):
        os.environ.pop(variable, None)

    module = load_windows_asr_app()

    class FakeInfo:
        language = "zh"
        language_probability = 0.88

    class FakeSegment:
        text = "中英混合 hello world"

    class BrokenGpuModel:
        def transcribe(self, *_args, **_kwargs):
            raise RuntimeError("Library cublas64_12.dll is not found or cannot be loaded")

    class CpuModel:
        def transcribe(self, *_args, **_kwargs):
            return iter([FakeSegment()]), FakeInfo()

    # Models are handed out in order: the failing one first, then the working one.
    pending_models = [BrokenGpuModel(), CpuModel()]
    cache_clear_calls = []

    def fake_get_model():
        return pending_models.pop(0)

    # Stand in for the cache_clear attribute that the real get_model exposes.
    fake_get_model.cache_clear = lambda: cache_clear_calls.append(True)
    module.get_model = fake_get_model

    state = module.app.state
    state.runtime_device = "cuda"
    state.runtime_compute_type = "int8_float16"

    with tempfile.NamedTemporaryFile(suffix=".wav") as handle:
        payload = module.transcribe_file(Path(handle.name), started=0.0)

    self.assertTrue(payload["success"])
    self.assertEqual(payload["text"], "中英混合 hello world")
    self.assertEqual(payload["detected_language"], "zh")
    self.assertEqual(module.app.state.runtime_fallback_profile, ("cpu", "int8"))
    self.assertTrue(cache_clear_calls)
|
||||
|
||||
Reference in New Issue
Block a user