feat: auto-detect language and prefer gpu for windows asr
This commit is contained in:
13
CHANGELOG.md
13
CHANGELOG.md
@@ -4,6 +4,19 @@
|
||||
|
||||
## 2026-04-06
|
||||
|
||||
### Windows ASR 默认改成 GPU 优先与自动语言识别
|
||||
|
||||
- Windows `ASR HTTP` 现在默认不再强锁 `zh + cpu + int8`,而是改成:
|
||||
- `WHISPER_DEVICE=auto`
|
||||
- `WHISPER_LANGUAGE=auto`
|
||||
- `WHISPER_COMPUTE_TYPE=auto`
|
||||
- 运行时会优先尝试 `cuda + int8_float16`,如果当前机器没有可用 GPU,再自动回退到 `cpu + int8`。
|
||||
- 转写请求默认不再强制指定语言,这样一句话里中英混说时,会按模型自动识别而不是强压成中文模式。
|
||||
- 健康接口现在也会明确返回:
|
||||
- 配置层 `language/device/compute_type`
|
||||
- 实际加载后的 `active_device/active_compute_type`
|
||||
便于区分“当前策略”和“本轮真实用到的运行模式”。
|
||||
|
||||
### NAS collector 改走服务器本机的 n8n 与火爆视频
|
||||
|
||||
- 新增 `fnOS -> 公网服务器` 的本地转发隧道,把服务器本机 `127.0.0.1:25670/25678` 分别映射到 NAS 的 `19570/19578`。
|
||||
|
||||
@@ -10,9 +10,6 @@ from fastapi import FastAPI, File, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
|
||||
LANGUAGE = os.getenv("WHISPER_LANGUAGE", "zh")
|
||||
DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
|
||||
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
|
||||
BEAM_SIZE = int(os.getenv("WHISPER_BEAM_SIZE", "5"))
|
||||
VAD_FILTER = os.getenv("WHISPER_VAD_FILTER", "1").strip().lower() not in {"0", "false", "no"}
|
||||
DOWNLOAD_ROOT = Path(os.getenv("WHISPER_DOWNLOAD_ROOT", str(Path(__file__).resolve().parent / "models-cache")))
|
||||
@@ -20,17 +17,59 @@ DOWNLOAD_ROOT = Path(os.getenv("WHISPER_DOWNLOAD_ROOT", str(Path(__file__).resol
|
||||
app = FastAPI(title="storyforge-windows-asr", version="1.0.0")
|
||||
|
||||
|
||||
def describe_language_mode() -> str:
    """Return the configured transcription language mode.

    Reads ``WHISPER_LANGUAGE`` on every call. An unset, blank, ``auto`` or
    ``detect`` value (case-insensitive) yields ``"auto"``. Any other value is
    returned stripped and lower-cased, so ``"ZH"`` and ``"zh"`` are treated
    the same way; Whisper language codes are lowercase.
    """
    value = (os.getenv("WHISPER_LANGUAGE", "") or "").strip().lower()
    if not value or value in {"auto", "detect"}:
        return "auto"
    return value
|
||||
|
||||
|
||||
def resolve_language() -> str | None:
    """Return the language to pass to the transcriber.

    ``None`` is returned when the configured mode is ``"auto"``; otherwise
    the configured mode string is returned unchanged.
    """
    mode = describe_language_mode()
    if mode == "auto":
        return None
    return mode
|
||||
|
||||
|
||||
def describe_device_mode() -> str:
    """Return the configured device mode from ``WHISPER_DEVICE``.

    The value is stripped and lower-cased; an unset or blank value is
    reported as ``"auto"``.
    """
    raw = os.getenv("WHISPER_DEVICE", "") or ""
    normalized = raw.strip().lower()
    if not normalized:
        return "auto"
    return normalized
|
||||
|
||||
|
||||
def describe_compute_mode() -> str:
    """Return the configured compute type from ``WHISPER_COMPUTE_TYPE``.

    The value is stripped and lower-cased, matching how the device mode is
    normalized, so that values such as ``"AUTO"`` compare equal to the
    ``"auto"`` sentinel. An unset or blank value is reported as ``"auto"``.
    """
    value = (os.getenv("WHISPER_COMPUTE_TYPE", "") or "").strip().lower()
    return value or "auto"
|
||||
|
||||
|
||||
def build_runtime_profiles() -> list[tuple[str, str]]:
    """Return the ordered ``(device, compute_type)`` candidates to try.

    - Explicit device: a single profile for that device, using the
      configured compute type or ``"int8"`` when the compute type is auto.
    - Auto device, explicit compute type: ``cuda`` then ``cpu`` with that
      compute type.
    - Both auto: ``cuda + int8_float16`` first, then ``cpu + int8``.
    """
    requested_device = describe_device_mode()
    requested_compute = describe_compute_mode()
    compute_is_auto = requested_compute == "auto"

    if requested_device == "auto":
        if compute_is_auto:
            return [("cuda", "int8_float16"), ("cpu", "int8")]
        return [(candidate, requested_compute) for candidate in ("cuda", "cpu")]

    chosen_compute = "int8" if compute_is_auto else requested_compute
    return [(requested_device, chosen_compute)]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_model():
    """Load the Whisper model, trying each runtime profile in order.

    Profiles come from ``build_runtime_profiles()``. The first profile whose
    ``WhisperModel`` constructor succeeds wins; its device and compute type
    are recorded on ``app.state`` as ``runtime_device`` and
    ``runtime_compute_type``. The successful result is memoized by
    ``lru_cache``; a call that raises is not cached, so a later call retries.

    Returns:
        The loaded ``WhisperModel`` instance.

    Raises:
        Exception: the error raised by the last profile attempted, when every
            profile fails.
        RuntimeError: when there are no profiles to try.
    """
    # Imported lazily so importing this module does not require faster_whisper.
    from faster_whisper import WhisperModel

    DOWNLOAD_ROOT.mkdir(parents=True, exist_ok=True)

    last_error: Exception | None = None
    for device, compute_type in build_runtime_profiles():
        try:
            model = WhisperModel(
                MODEL_NAME,
                device=device,
                compute_type=compute_type,
                download_root=str(DOWNLOAD_ROOT),
            )
        except Exception as exc:  # pragma: no cover - exercised on real hosts
            # Remember the failure and fall through to the next profile.
            last_error = exc
            continue
        app.state.runtime_device = device
        app.state.runtime_compute_type = compute_type
        return model

    # Explicit guard instead of ``assert`` so it still holds under ``-O``.
    if last_error is None:
        raise RuntimeError("no runtime profiles available for the Whisper model")
    raise last_error
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@@ -39,9 +78,11 @@ def health() -> dict[str, object]:
|
||||
"status": "ok",
|
||||
"service": "storyforge-windows-asr",
|
||||
"model_name": MODEL_NAME,
|
||||
"language": LANGUAGE,
|
||||
"device": DEVICE,
|
||||
"compute_type": COMPUTE_TYPE,
|
||||
"language": describe_language_mode(),
|
||||
"device": describe_device_mode(),
|
||||
"compute_type": describe_compute_mode(),
|
||||
"active_device": getattr(app.state, "runtime_device", ""),
|
||||
"active_compute_type": getattr(app.state, "runtime_compute_type", ""),
|
||||
"download_root": str(DOWNLOAD_ROOT),
|
||||
"model_loaded": get_model.cache_info().currsize > 0,
|
||||
}
|
||||
@@ -62,9 +103,9 @@ async def transcribe(wav: UploadFile = File(...)):
|
||||
|
||||
try:
|
||||
model = get_model()
|
||||
segments, _info = model.transcribe(
|
||||
segments, info = model.transcribe(
|
||||
str(temp_path),
|
||||
language=LANGUAGE or None,
|
||||
language=resolve_language(),
|
||||
beam_size=max(1, BEAM_SIZE),
|
||||
vad_filter=VAD_FILTER,
|
||||
)
|
||||
@@ -74,6 +115,8 @@ async def transcribe(wav: UploadFile = File(...)):
|
||||
"text": text,
|
||||
"success": bool(text),
|
||||
"duration_ms": duration_ms,
|
||||
"detected_language": getattr(info, "language", None),
|
||||
"detected_language_probability": getattr(info, "language_probability", None),
|
||||
"error_message": None if text else "empty transcription",
|
||||
}
|
||||
except Exception as exc:
|
||||
|
||||
@@ -13,9 +13,9 @@ $venvPython = Join-Path $venvDir "Scripts\python.exe"
|
||||
& $venvPython -m pip install -r (Join-Path $scriptDir "requirements.txt")
|
||||
|
||||
$env:WHISPER_MODEL = if ($env:WHISPER_MODEL) { $env:WHISPER_MODEL } else { "base" }
|
||||
$env:WHISPER_LANGUAGE = if ($env:WHISPER_LANGUAGE) { $env:WHISPER_LANGUAGE } else { "zh" }
|
||||
$env:WHISPER_DEVICE = if ($env:WHISPER_DEVICE) { $env:WHISPER_DEVICE } else { "cpu" }
|
||||
$env:WHISPER_COMPUTE_TYPE = if ($env:WHISPER_COMPUTE_TYPE) { $env:WHISPER_COMPUTE_TYPE } else { "int8" }
|
||||
$env:WHISPER_LANGUAGE = if ($env:WHISPER_LANGUAGE) { $env:WHISPER_LANGUAGE } else { "" }
|
||||
$env:WHISPER_DEVICE = if ($env:WHISPER_DEVICE) { $env:WHISPER_DEVICE } else { "auto" }
|
||||
$env:WHISPER_COMPUTE_TYPE = if ($env:WHISPER_COMPUTE_TYPE) { $env:WHISPER_COMPUTE_TYPE } else { "" }
|
||||
$env:WHISPER_BEAM_SIZE = if ($env:WHISPER_BEAM_SIZE) { $env:WHISPER_BEAM_SIZE } else { "5" }
|
||||
$env:WHISPER_VAD_FILTER = if ($env:WHISPER_VAD_FILTER) { $env:WHISPER_VAD_FILTER } else { "1" }
|
||||
$env:WHISPER_DOWNLOAD_ROOT = if ($env:WHISPER_DOWNLOAD_ROOT) { $env:WHISPER_DOWNLOAD_ROOT } else { (Join-Path $scriptDir "models-cache") }
|
||||
|
||||
71
tests/test_windows_asr_http.py
Normal file
71
tests/test_windows_asr_http.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
APP_PATH = ROOT / "deploy" / "storyforge-windows-asr-http" / "app.py"
|
||||
|
||||
|
||||
def load_windows_asr_app():
    """Import the Windows ASR ``app.py`` from ``APP_PATH`` as a fresh module.

    Each call executes the file again, so module-level state is rebuilt and
    the functions see whatever environment variables are set at call time.

    Returns:
        The newly executed module object.

    Raises:
        ImportError: when no import spec or loader can be created for
            ``APP_PATH``.
    """
    spec = importlib.util.spec_from_file_location("storyforge_windows_asr_app", APP_PATH)
    # Validate before use: module_from_spec() would fail on a missing spec
    # with a less helpful error.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module spec from {APP_PATH}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
class WindowsAsrHttpTests(unittest.TestCase):
    """Checks language/device/compute resolution of the Windows ASR app."""

    def setUp(self) -> None:
        # Snapshot the variables the tests mutate so tearDown can restore them.
        self._saved = {}
        for name in ("WHISPER_LANGUAGE", "WHISPER_DEVICE", "WHISPER_COMPUTE_TYPE"):
            self._saved[name] = os.environ.get(name)

    def tearDown(self) -> None:
        for name, previous in self._saved.items():
            if previous is not None:
                os.environ[name] = previous
                continue
            os.environ.pop(name, None)

    def test_defaults_prefer_auto_language_and_gpu_first_profiles(self) -> None:
        """With nothing configured, everything reports auto and GPU is tried first."""
        for name in ("WHISPER_LANGUAGE", "WHISPER_DEVICE", "WHISPER_COMPUTE_TYPE"):
            os.environ.pop(name, None)

        asr = load_windows_asr_app()

        self.assertIsNone(asr.resolve_language())
        self.assertEqual(asr.describe_language_mode(), "auto")
        expected_profiles = [("cuda", "int8_float16"), ("cpu", "int8")]
        self.assertEqual(asr.build_runtime_profiles(), expected_profiles)

        client = TestClient(asr.app)
        try:
            response = client.get("/health")
            payload = response.json()
        finally:
            client.close()

        for field in ("language", "device", "compute_type"):
            self.assertEqual(payload[field], "auto")

    def test_explicit_runtime_overrides_are_respected(self) -> None:
        """Explicit environment values yield a single matching profile."""
        os.environ.update(
            {
                "WHISPER_LANGUAGE": "zh",
                "WHISPER_DEVICE": "cpu",
                "WHISPER_COMPUTE_TYPE": "int8",
            }
        )

        asr = load_windows_asr_app()

        self.assertEqual(asr.resolve_language(), "zh")
        self.assertEqual(asr.describe_language_mode(), "zh")
        self.assertEqual(asr.build_runtime_profiles(), [("cpu", "int8")])
|
||||
Reference in New Issue
Block a user