fix: fallback windows asr to cpu when gpu runtime is missing
Some checks failed
StoryForge CI / Baseline checks (push) Has been cancelled
StoryForge CI / Backend tests (push) Has been cancelled
StoryForge CI / Web tests (push) Has been cancelled

This commit is contained in:
kris
2026-04-06 11:29:12 +08:00
parent 4ff7efb61c
commit f53a4b4461
3 changed files with 98 additions and 17 deletions

View File

@@ -4,6 +4,11 @@
## 2026-04-06
### Windows ASR GPU 失败时自动回退 CPU
- Windows `ASR HTTP` 现在在 `auto` 模式下仍会优先尝试 `cuda + int8_float16`,但如果在真正推理阶段命中 `cublas/cudnn/cuda` 运行库缺失,会自动切回 `cpu + int8` 重试,不再把整次转写卡死在 GPU 路径。
- 这让“默认优先用 GPU、但当前机器 CUDA 运行库不完整”的场景也能稳定返回结果,同时保留混合中英文自动识别。
### Windows ASR 默认改成 GPU 优先与自动语言识别
- Windows `ASR HTTP` 现在默认不再强锁 `zh + cpu + int8`,而是改成:

View File

@@ -40,6 +40,9 @@ def describe_compute_mode() -> str:
def build_runtime_profiles() -> list[tuple[str, str]]:
fallback_profile = getattr(app.state, "runtime_fallback_profile", None)
if fallback_profile:
return [fallback_profile]
device = describe_device_mode()
compute = describe_compute_mode()
if device != "auto":
@@ -49,6 +52,20 @@ def build_runtime_profiles() -> list[tuple[str, str]]:
return [("cuda", "int8_float16"), ("cpu", "int8")]
def should_retry_on_cpu(exc: Exception) -> bool:
    """Decide whether a failed transcription should be retried on the CPU.

    Args:
        exc: The exception raised by the transcription attempt.

    Returns:
        ``True`` only when the device mode is ``"auto"`` and the lower-cased
        exception message mentions ``cublas``, ``cudnn`` or ``cuda``.
    """
    # An explicitly configured device is respected: never switch it silently.
    if describe_device_mode() != "auto":
        return False
    lowered = str(exc).lower()
    for marker in ("cublas", "cudnn", "cuda"):
        if marker in lowered:
            return True
    return False
def activate_cpu_fallback() -> None:
    """Switch the service to the ``("cpu", "int8")`` runtime profile.

    Stores the profile on ``app.state.runtime_fallback_profile``, updates
    ``runtime_device`` / ``runtime_compute_type`` to match, and clears the
    ``get_model`` cache so the next ``get_model()`` call runs again instead of
    returning the previously cached model.
    """
    cpu_device, cpu_compute_type = "cpu", "int8"
    state = app.state
    state.runtime_fallback_profile = (cpu_device, cpu_compute_type)
    state.runtime_device = cpu_device
    state.runtime_compute_type = cpu_compute_type
    # Drop the cached model; presumably the reload picks up the fallback
    # profile via build_runtime_profiles() -- verify against get_model().
    get_model.cache_clear()
@lru_cache(maxsize=1)
def get_model():
from faster_whisper import WhisperModel
@@ -93,6 +110,38 @@ def root() -> dict[str, str]:
return {"service": "storyforge-windows-asr", "docs": "/docs"}
def _collect_transcription(model, temp_path: Path) -> tuple[str, object]:
    """Run one transcription pass on ``model`` and return ``(text, info)``.

    The segments are joined inside this helper on purpose: faster-whisper
    returns them as a lazy generator, so decoding (and any GPU runtime
    failure) can happen while iterating rather than inside ``transcribe()``.
    """
    segments, info = model.transcribe(
        str(temp_path),
        language=resolve_language(),
        beam_size=max(1, BEAM_SIZE),
        vad_filter=VAD_FILTER,
    )
    text = "".join(segment.text for segment in segments).strip()
    return text, info


def transcribe_file(temp_path: Path, started: float) -> dict[str, object]:
    """Transcribe an audio file, retrying once on CPU if the GPU runtime fails.

    Args:
        temp_path: Path of the audio file to transcribe.
        started: ``time.perf_counter()`` value taken when the request began;
            used to compute ``duration_ms``.

    Returns:
        A payload with ``text``, ``success``, ``duration_ms``,
        ``detected_language``, ``detected_language_probability`` and
        ``error_message`` (``"empty transcription"`` when no text came back).

    Raises:
        Exception: Any transcription error that ``should_retry_on_cpu``
            rejects, or any error raised by the CPU retry itself.
    """
    model = get_model()
    try:
        # Segment consumption happens inside the guarded call so that errors
        # raised during lazy decoding also reach the CPU fallback.
        text, info = _collect_transcription(model, temp_path)
    except Exception as exc:
        if not should_retry_on_cpu(exc):
            raise
        activate_cpu_fallback()
        # The model cache was just cleared, so this fetches a fresh model.
        text, info = _collect_transcription(get_model(), temp_path)
    duration_ms = int((time.perf_counter() - started) * 1000)
    return {
        "text": text,
        "success": bool(text),
        "duration_ms": duration_ms,
        "detected_language": getattr(info, "language", None),
        "detected_language_probability": getattr(info, "language_probability", None),
        "error_message": None if text else "empty transcription",
    }
@app.post("/transcribe", response_model=None)
async def transcribe(wav: UploadFile = File(...)):
started = time.perf_counter()
@@ -102,23 +151,7 @@ async def transcribe(wav: UploadFile = File(...)):
handle.write(await wav.read())
try:
model = get_model()
segments, info = model.transcribe(
str(temp_path),
language=resolve_language(),
beam_size=max(1, BEAM_SIZE),
vad_filter=VAD_FILTER,
)
text = "".join(segment.text for segment in segments).strip()
duration_ms = int((time.perf_counter() - started) * 1000)
return {
"text": text,
"success": bool(text),
"duration_ms": duration_ms,
"detected_language": getattr(info, "language", None),
"detected_language_probability": getattr(info, "language_probability", None),
"error_message": None if text else "empty transcription",
}
return transcribe_file(temp_path, started)
except Exception as exc:
return JSONResponse(
status_code=500,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import importlib.util
import os
import tempfile
import unittest
from pathlib import Path
@@ -69,3 +70,45 @@ class WindowsAsrHttpTests(unittest.TestCase):
self.assertEqual(module.resolve_language(), "zh")
self.assertEqual(module.describe_language_mode(), "zh")
self.assertEqual(module.build_runtime_profiles(), [("cpu", "int8")])
def test_auto_runtime_falls_back_to_cpu_when_cuda_runtime_is_missing(self) -> None:
    """A cuBLAS-missing error from the first model is retried on the second."""
    # Remove overrides so the module is loaded without explicit settings.
    for env_name in ("WHISPER_LANGUAGE", "WHISPER_DEVICE", "WHISPER_COMPUTE_TYPE"):
        os.environ.pop(env_name, None)
    module = load_windows_asr_app()

    class StubInfo:
        language = "zh"
        language_probability = 0.88

    class StubSegment:
        text = "中英混合 hello world"

    class FailingGpuModel:
        def transcribe(self, *_args, **_kwargs):
            raise RuntimeError("Library cublas64_12.dll is not found or cannot be loaded")

    class WorkingCpuModel:
        def transcribe(self, *_args, **_kwargs):
            return iter([StubSegment()]), StubInfo()

    # Handed out in order: the first get_model() call gets the broken GPU
    # model, the second (after the fallback) gets the working CPU model.
    pending_models = [FailingGpuModel(), WorkingCpuModel()]
    cache_clear_calls = []

    def fake_get_model():
        return pending_models.pop(0)

    def fake_cache_clear() -> None:
        cache_clear_calls.append(True)

    fake_get_model.cache_clear = fake_cache_clear
    module.get_model = fake_get_model
    state = module.app.state
    state.runtime_device = "cuda"
    state.runtime_compute_type = "int8_float16"

    with tempfile.NamedTemporaryFile(suffix=".wav") as handle:
        wav_path = Path(handle.name)
        payload = module.transcribe_file(wav_path, started=0.0)

    self.assertTrue(payload["success"])
    self.assertEqual(payload["text"], "中英混合 hello world")
    self.assertEqual(payload["detected_language"], "zh")
    self.assertEqual(module.app.state.runtime_fallback_profile, ("cpu", "int8"))
    self.assertTrue(cache_clear_calls)