From f53a4b4461dba53c123db634f51370cc0094c532 Mon Sep 17 00:00:00 2001
From: kris
Date: Mon, 6 Apr 2026 11:29:12 +0800
Subject: [PATCH] fix: fallback windows asr to cpu when gpu runtime is missing

In auto device mode, retry the transcription on cpu + int8 when the GPU
attempt fails with a cublas/cudnn/cuda error. The segment generator is
consumed inside the guarded call so that failures raised while decoding
(not only while calling transcribe()) also trigger the fallback.
---
Note: the blob ids on the `index` lines below are carried over unchanged
and were not recomputed for this revision of the patch.

 CHANGELOG.md                              |  5 +
 deploy/storyforge-windows-asr-http/app.py | 80 +++++++++++++++++-----
 tests/test_windows_asr_http.py            | 84 +++++++++++++++++++++++
 3 files changed, 152 insertions(+), 17 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9f6ddb3..652d0f2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,11 @@
 
 ## 2026-04-06
 
+### Windows ASR GPU 失败时自动回退 CPU
+
+- Windows `ASR HTTP` 现在在 `auto` 模式下仍会优先尝试 `cuda + int8_float16`,但如果在真正推理阶段命中 `cublas/cudnn/cuda` 运行库缺失,会自动切回 `cpu + int8` 重试,不再把整次转写卡死在 GPU 路径。
+- 这让“默认优先用 GPU、但当前机器 CUDA 运行库不完整”的场景也能稳定返回结果,同时保留混合中英文自动识别。
+
 ### Windows ASR 默认改成 GPU 优先与自动语言识别
 
 - Windows `ASR HTTP` 现在默认不再强锁 `zh + cpu + int8`,而是改成:
diff --git a/deploy/storyforge-windows-asr-http/app.py b/deploy/storyforge-windows-asr-http/app.py
index ed4996f..a3b4b28 100644
--- a/deploy/storyforge-windows-asr-http/app.py
+++ b/deploy/storyforge-windows-asr-http/app.py
@@ -40,6 +40,10 @@ def describe_compute_mode() -> str:
 
 
 def build_runtime_profiles() -> list[tuple[str, str]]:
+    # Once activate_cpu_fallback() has set a fallback profile, offer only that profile.
+    fallback_profile = getattr(app.state, "runtime_fallback_profile", None)
+    if fallback_profile:
+        return [fallback_profile]
     device = describe_device_mode()
     compute = describe_compute_mode()
     if device != "auto":
@@ -49,6 +53,22 @@ def build_runtime_profiles() -> list[tuple[str, str]]:
     return [("cuda", "int8_float16"), ("cpu", "int8")]
 
 
+def should_retry_on_cpu(exc: Exception) -> bool:
+    """Return True if device mode is ``auto`` and *exc* mentions cublas/cudnn/cuda."""
+    if describe_device_mode() != "auto":
+        return False
+    message = str(exc).lower()
+    return any(token in message for token in ("cublas", "cudnn", "cuda"))
+
+
+def activate_cpu_fallback() -> None:
+    """Pin the runtime to ``cpu + int8`` and clear the cached model so it is rebuilt."""
+    app.state.runtime_fallback_profile = ("cpu", "int8")
+    app.state.runtime_device = "cpu"
+    app.state.runtime_compute_type = "int8"
+    get_model.cache_clear()
+
+
 @lru_cache(maxsize=1)
 def get_model():
     from faster_whisper import WhisperModel
@@ -93,6 +113,48 @@ def root() -> dict[str, str]:
     return {"service": "storyforge-windows-asr", "docs": "/docs"}
 
 
+def _run_transcription(model, temp_path: Path) -> tuple[str, object]:
+    """Run one transcription pass on *model* and return ``(text, info)``.
+
+    faster-whisper hands back a lazy segment generator, so decoding -- and any
+    missing CUDA runtime error -- only happens while it is consumed. The text is
+    therefore joined here, so the caller's ``try`` block covers decoding too.
+    """
+    segments, info = model.transcribe(
+        str(temp_path),
+        language=resolve_language(),
+        beam_size=max(1, BEAM_SIZE),
+        vad_filter=VAD_FILTER,
+    )
+    text = "".join(segment.text for segment in segments).strip()
+    return text, info
+
+
+def transcribe_file(temp_path: Path, started: float) -> dict[str, object]:
+    """Transcribe *temp_path*, retrying once on CPU if the GPU runtime is missing.
+
+    *started* is the ``time.perf_counter()`` value used to compute ``duration_ms``.
+    Errors that do not qualify for the CPU retry are re-raised unchanged.
+    """
+    model = get_model()
+    try:
+        text, info = _run_transcription(model, temp_path)
+    except Exception as exc:
+        if not should_retry_on_cpu(exc):
+            raise
+        activate_cpu_fallback()
+        text, info = _run_transcription(get_model(), temp_path)
+    duration_ms = int((time.perf_counter() - started) * 1000)
+    return {
+        "text": text,
+        "success": bool(text),
+        "duration_ms": duration_ms,
+        "detected_language": getattr(info, "language", None),
+        "detected_language_probability": getattr(info, "language_probability", None),
+        "error_message": None if text else "empty transcription",
+    }
+
+
 @app.post("/transcribe", response_model=None)
 async def transcribe(wav: UploadFile = File(...)):
     started = time.perf_counter()
@@ -102,23 +164,7 @@ async def transcribe(wav: UploadFile = File(...)):
         handle.write(await wav.read())
 
     try:
-        model = get_model()
-        segments, info = model.transcribe(
-            str(temp_path),
-            language=resolve_language(),
-            beam_size=max(1, BEAM_SIZE),
-            vad_filter=VAD_FILTER,
-        )
-        text = "".join(segment.text for segment in segments).strip()
-        duration_ms = int((time.perf_counter() - started) * 1000)
-        return {
-            "text": text,
-            "success": bool(text),
-            "duration_ms": duration_ms,
-            "detected_language": getattr(info, "language", None),
-            "detected_language_probability": getattr(info, "language_probability", None),
-            "error_message": None if text else "empty transcription",
-        }
+        return transcribe_file(temp_path, started)
     except Exception as exc:
         return JSONResponse(
             status_code=500,
diff --git a/tests/test_windows_asr_http.py b/tests/test_windows_asr_http.py
index 685d6b9..be09519 100644
--- a/tests/test_windows_asr_http.py
+++ b/tests/test_windows_asr_http.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import importlib.util
 import os
+import tempfile
 import unittest
 from pathlib import Path
 
@@ -69,3 +70,86 @@ class WindowsAsrHttpTests(unittest.TestCase):
         self.assertEqual(module.resolve_language(), "zh")
         self.assertEqual(module.describe_language_mode(), "zh")
         self.assertEqual(module.build_runtime_profiles(), [("cpu", "int8")])
+
+    def test_auto_runtime_falls_back_to_cpu_when_cuda_runtime_is_missing(self) -> None:
+        os.environ.pop("WHISPER_LANGUAGE", None)
+        os.environ.pop("WHISPER_DEVICE", None)
+        os.environ.pop("WHISPER_COMPUTE_TYPE", None)
+
+        module = load_windows_asr_app()
+
+        class FakeInfo:
+            language = "zh"
+            language_probability = 0.88
+
+        class FakeSegment:
+            text = "中英混合 hello world"
+
+        class BrokenGpuModel:
+            def transcribe(self, *_args, **_kwargs):
+                raise RuntimeError("Library cublas64_12.dll is not found or cannot be loaded")
+
+        class CpuModel:
+            def transcribe(self, *_args, **_kwargs):
+                return iter([FakeSegment()]), FakeInfo()
+
+        models = [BrokenGpuModel(), CpuModel()]
+        cache_cleared = []
+
+        def fake_get_model():
+            return models.pop(0)
+
+        fake_get_model.cache_clear = lambda: cache_cleared.append(True)
+        module.get_model = fake_get_model
+        module.app.state.runtime_device = "cuda"
+        module.app.state.runtime_compute_type = "int8_float16"
+
+        with tempfile.NamedTemporaryFile(suffix=".wav") as handle:
+            payload = module.transcribe_file(Path(handle.name), started=0.0)
+
+        self.assertTrue(payload["success"])
+        self.assertEqual(payload["text"], "中英混合 hello world")
+        self.assertEqual(payload["detected_language"], "zh")
+        self.assertEqual(module.app.state.runtime_fallback_profile, ("cpu", "int8"))
+        self.assertTrue(cache_cleared)
+
+    def test_auto_runtime_falls_back_to_cpu_when_gpu_fails_while_decoding(self) -> None:
+        os.environ.pop("WHISPER_LANGUAGE", None)
+        os.environ.pop("WHISPER_DEVICE", None)
+        os.environ.pop("WHISPER_COMPUTE_TYPE", None)
+
+        module = load_windows_asr_app()
+
+        class FakeInfo:
+            language = "zh"
+            language_probability = 0.88
+
+        class FakeSegment:
+            text = "中英混合 hello world"
+
+        class LazyBrokenGpuModel:
+            def transcribe(self, *_args, **_kwargs):
+                def failing_segments():
+                    raise RuntimeError("Library cublas64_12.dll is not found or cannot be loaded")
+                    yield  # unreachable; only makes this function a lazy generator
+
+                return failing_segments(), FakeInfo()
+
+        class CpuModel:
+            def transcribe(self, *_args, **_kwargs):
+                return iter([FakeSegment()]), FakeInfo()
+
+        models = [LazyBrokenGpuModel(), CpuModel()]
+
+        def fake_get_model():
+            return models.pop(0)
+
+        fake_get_model.cache_clear = lambda: None
+        module.get_model = fake_get_model
+
+        with tempfile.NamedTemporaryFile(suffix=".wav") as handle:
+            payload = module.transcribe_file(Path(handle.name), started=0.0)
+
+        self.assertTrue(payload["success"])
+        self.assertEqual(payload["text"], "中英混合 hello world")
+        self.assertEqual(module.app.state.runtime_fallback_profile, ("cpu", "int8"))