fix: fallback windows asr to cpu when gpu runtime is missing
This commit is contained in:
@@ -4,6 +4,11 @@
|
|||||||
|
|
||||||
## 2026-04-06
|
## 2026-04-06
|
||||||
|
|
||||||
|
### Windows ASR GPU 失败时自动回退 CPU
|
||||||
|
|
||||||
|
- Windows `ASR HTTP` 现在在 `auto` 模式下仍会优先尝试 `cuda + int8_float16`,但如果在真正推理阶段命中 `cublas/cudnn/cuda` 运行库缺失,会自动切回 `cpu + int8` 重试,不再把整次转写卡死在 GPU 路径。
|
||||||
|
- 这让“默认优先用 GPU、但当前机器 CUDA 运行库不完整”的场景也能稳定返回结果,同时保留混合中英文自动识别。
|
||||||
|
|
||||||
### Windows ASR 默认改成 GPU 优先与自动语言识别
|
### Windows ASR 默认改成 GPU 优先与自动语言识别
|
||||||
|
|
||||||
- Windows `ASR HTTP` 现在默认不再强锁 `zh + cpu + int8`,而是改成:
|
- Windows `ASR HTTP` 现在默认不再强锁 `zh + cpu + int8`,而是改成:
|
||||||
|
|||||||
@@ -40,6 +40,9 @@ def describe_compute_mode() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def build_runtime_profiles() -> list[tuple[str, str]]:
|
def build_runtime_profiles() -> list[tuple[str, str]]:
|
||||||
|
fallback_profile = getattr(app.state, "runtime_fallback_profile", None)
|
||||||
|
if fallback_profile:
|
||||||
|
return [fallback_profile]
|
||||||
device = describe_device_mode()
|
device = describe_device_mode()
|
||||||
compute = describe_compute_mode()
|
compute = describe_compute_mode()
|
||||||
if device != "auto":
|
if device != "auto":
|
||||||
@@ -49,6 +52,20 @@ def build_runtime_profiles() -> list[tuple[str, str]]:
|
|||||||
return [("cuda", "int8_float16"), ("cpu", "int8")]
|
return [("cuda", "int8_float16"), ("cpu", "int8")]
|
||||||
|
|
||||||
|
|
||||||
|
def should_retry_on_cpu(exc: Exception) -> bool:
|
||||||
|
if describe_device_mode() != "auto":
|
||||||
|
return False
|
||||||
|
message = str(exc).lower()
|
||||||
|
return any(token in message for token in ("cublas", "cudnn", "cuda"))
|
||||||
|
|
||||||
|
|
||||||
|
def activate_cpu_fallback() -> None:
|
||||||
|
app.state.runtime_fallback_profile = ("cpu", "int8")
|
||||||
|
app.state.runtime_device = "cpu"
|
||||||
|
app.state.runtime_compute_type = "int8"
|
||||||
|
get_model.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_model():
|
def get_model():
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
@@ -93,6 +110,38 @@ def root() -> dict[str, str]:
|
|||||||
return {"service": "storyforge-windows-asr", "docs": "/docs"}
|
return {"service": "storyforge-windows-asr", "docs": "/docs"}
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe_file(temp_path: Path, started: float) -> dict[str, object]:
|
||||||
|
model = get_model()
|
||||||
|
try:
|
||||||
|
segments, info = model.transcribe(
|
||||||
|
str(temp_path),
|
||||||
|
language=resolve_language(),
|
||||||
|
beam_size=max(1, BEAM_SIZE),
|
||||||
|
vad_filter=VAD_FILTER,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if not should_retry_on_cpu(exc):
|
||||||
|
raise
|
||||||
|
activate_cpu_fallback()
|
||||||
|
model = get_model()
|
||||||
|
segments, info = model.transcribe(
|
||||||
|
str(temp_path),
|
||||||
|
language=resolve_language(),
|
||||||
|
beam_size=max(1, BEAM_SIZE),
|
||||||
|
vad_filter=VAD_FILTER,
|
||||||
|
)
|
||||||
|
text = "".join(segment.text for segment in segments).strip()
|
||||||
|
duration_ms = int((time.perf_counter() - started) * 1000)
|
||||||
|
return {
|
||||||
|
"text": text,
|
||||||
|
"success": bool(text),
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
"detected_language": getattr(info, "language", None),
|
||||||
|
"detected_language_probability": getattr(info, "language_probability", None),
|
||||||
|
"error_message": None if text else "empty transcription",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/transcribe", response_model=None)
|
@app.post("/transcribe", response_model=None)
|
||||||
async def transcribe(wav: UploadFile = File(...)):
|
async def transcribe(wav: UploadFile = File(...)):
|
||||||
started = time.perf_counter()
|
started = time.perf_counter()
|
||||||
@@ -102,23 +151,7 @@ async def transcribe(wav: UploadFile = File(...)):
|
|||||||
handle.write(await wav.read())
|
handle.write(await wav.read())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = get_model()
|
return transcribe_file(temp_path, started)
|
||||||
segments, info = model.transcribe(
|
|
||||||
str(temp_path),
|
|
||||||
language=resolve_language(),
|
|
||||||
beam_size=max(1, BEAM_SIZE),
|
|
||||||
vad_filter=VAD_FILTER,
|
|
||||||
)
|
|
||||||
text = "".join(segment.text for segment in segments).strip()
|
|
||||||
duration_ms = int((time.perf_counter() - started) * 1000)
|
|
||||||
return {
|
|
||||||
"text": text,
|
|
||||||
"success": bool(text),
|
|
||||||
"duration_ms": duration_ms,
|
|
||||||
"detected_language": getattr(info, "language", None),
|
|
||||||
"detected_language_probability": getattr(info, "language_probability", None),
|
|
||||||
"error_message": None if text else "empty transcription",
|
|
||||||
}
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -69,3 +70,45 @@ class WindowsAsrHttpTests(unittest.TestCase):
|
|||||||
self.assertEqual(module.resolve_language(), "zh")
|
self.assertEqual(module.resolve_language(), "zh")
|
||||||
self.assertEqual(module.describe_language_mode(), "zh")
|
self.assertEqual(module.describe_language_mode(), "zh")
|
||||||
self.assertEqual(module.build_runtime_profiles(), [("cpu", "int8")])
|
self.assertEqual(module.build_runtime_profiles(), [("cpu", "int8")])
|
||||||
|
|
||||||
|
def test_auto_runtime_falls_back_to_cpu_when_cuda_runtime_is_missing(self) -> None:
|
||||||
|
os.environ.pop("WHISPER_LANGUAGE", None)
|
||||||
|
os.environ.pop("WHISPER_DEVICE", None)
|
||||||
|
os.environ.pop("WHISPER_COMPUTE_TYPE", None)
|
||||||
|
|
||||||
|
module = load_windows_asr_app()
|
||||||
|
|
||||||
|
class FakeInfo:
|
||||||
|
language = "zh"
|
||||||
|
language_probability = 0.88
|
||||||
|
|
||||||
|
class FakeSegment:
|
||||||
|
text = "中英混合 hello world"
|
||||||
|
|
||||||
|
class BrokenGpuModel:
|
||||||
|
def transcribe(self, *_args, **_kwargs):
|
||||||
|
raise RuntimeError("Library cublas64_12.dll is not found or cannot be loaded")
|
||||||
|
|
||||||
|
class CpuModel:
|
||||||
|
def transcribe(self, *_args, **_kwargs):
|
||||||
|
return iter([FakeSegment()]), FakeInfo()
|
||||||
|
|
||||||
|
models = [BrokenGpuModel(), CpuModel()]
|
||||||
|
cache_cleared = []
|
||||||
|
|
||||||
|
def fake_get_model():
|
||||||
|
return models.pop(0)
|
||||||
|
|
||||||
|
fake_get_model.cache_clear = lambda: cache_cleared.append(True)
|
||||||
|
module.get_model = fake_get_model
|
||||||
|
module.app.state.runtime_device = "cuda"
|
||||||
|
module.app.state.runtime_compute_type = "int8_float16"
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav") as handle:
|
||||||
|
payload = module.transcribe_file(Path(handle.name), started=0.0)
|
||||||
|
|
||||||
|
self.assertTrue(payload["success"])
|
||||||
|
self.assertEqual(payload["text"], "中英混合 hello world")
|
||||||
|
self.assertEqual(payload["detected_language"], "zh")
|
||||||
|
self.assertEqual(module.app.state.runtime_fallback_profile, ("cpu", "int8"))
|
||||||
|
self.assertTrue(cache_cleared)
|
||||||
|
|||||||
Reference in New Issue
Block a user